Editorial for Bedao Grand Contest 08 - DIVTREE
Remember to use this editorial only when stuck, and not to copy-paste code from it. Please be respectful to the problem author and editorialist.
Submitting an official solution before solving the problem yourself is a bannable offence.
Code mẫu
#include <cstdio>
#include <iostream>
#include <utility>
#include <vector>

using namespace std;

// DIVTREE sample solution.
//
// Every vertex of a tree rooted at vertex 1 carries an integer. For each
// vertex u the program stores in r[u] the value
//     prod over primes p of (d[p] + 1)   (mod MOD),
// where d[p] is the total exponent of p over all values in the subtree of u,
// i.e. the number of divisors of the product of the subtree's values.
// Subtree aggregates are built with the "keep the largest child" technique:
// the state accumulated for the largest child is reused by its parent, the
// smaller children are re-added vertex by vertex.
//
// Capacity assumptions (not validated at run time): n < MAXN, every vertex
// value is < MAXN, and edge endpoints / queries lie in 1..n.
// NOTE(review): init() and calc() recurse once per tree level, so a very deep
// tree (e.g. a path) needs a correspondingly large call stack.

const int MAXN = 200007;      // size of every per-vertex and per-value table
const int MOD = 1000000007;   // prime modulus of the answers

typedef pair<int, int> pii;

int n, q;
int inv[MAXN];            // inv[i] = modular inverse of i, for 1 <= i < MAXN
int mpr[MAXN] = {};       // mpr[x] = smallest prime factor of x (0 for x < 2)
int sz[MAXN] = {};        // sz[u]  = number of vertices in the subtree of u
int ans = 1;              // running prod (d[p] + 1) mod MOD for the current set
int r[MAXN];              // r[u]   = answer for the subtree of u
int cnt = 0;              // DFS timer used by init()
int tin[MAXN];            // entry time of each vertex
int tout[MAXN];           // largest entry time inside the vertex's subtree
int pos[MAXN];            // pos[t] = vertex whose entry time is t
int d[MAXN] = {};         // d[p]   = exponent of prime p in the current set
vector<pii> fact[MAXN];   // fact[u] = (prime, exponent) pairs of u's value
vector<int> adj[MAXN];    // adjacency lists of the tree

// Returns cs^sm modulo MOD by binary exponentiation (sm >= 0).
int pw(int cs, int sm) {
    long long base = cs % MOD;
    long long result = 1;
    while (sm > 0) {
        if (sm & 1) result = result * base % MOD;
        base = base * base % MOD;
        sm >>= 1;
    }
    return static_cast<int>(result);
}

// Fills inv[] with modular inverses and mpr[] with smallest prime factors
// for every index in [2, MAXN).
void sieve() {
    inv[1] = 1;
    for (int i = 2; i < MAXN; ++i) {
        inv[i] = pw(i, MOD - 2);  // Fermat's little theorem, MOD is prime
        if (mpr[i] != 0) continue;  // composite: already marked by a smaller prime
        mpr[i] = i;
        for (int j = i + i; j < MAXN; j += i)
            if (mpr[j] == 0) mpr[j] = i;
    }
}

// Modular inverse of x (1 <= x < MOD).
// d[p] + 1 sums exponents over a whole subtree, so it can grow past the
// precomputed table (for instance many values that are high powers of 2);
// indexing inv[] directly would then read out of bounds. Beyond the table the
// inverse is computed on demand instead.
static int mod_inverse(int x) {
    return x < MAXN ? inv[x] : pw(x, MOD - 2);
}

// DFS from u (parent p): assigns tin/tout/pos and computes subtree sizes, so
// that the subtree of u is exactly pos[tin[u]] .. pos[tout[u]].
void init(int u, int p) {
    tin[u] = ++cnt;
    pos[cnt] = u;
    sz[u] = 1;
    for (int v : adj[u]) {
        if (v == p) continue;
        init(v, u);
        sz[u] += sz[v];
    }
    tout[u] = cnt;
}

// Adds (c = 1) or removes (c = -1) the value of vertex u from the current
// set, keeping ans equal to prod (d[p] + 1) mod MOD.
void upd(int u, int c) {
    for (const pii& factor : fact[u]) {
        const int prime = factor.first;
        const int exponent = factor.second;
        // Divide out the old (d + 1) term, then multiply in the new one.
        ans = 1LL * ans * mod_inverse(d[prime] + 1) % MOD;
        d[prime] += c * exponent;
        ans = 1LL * ans * (d[prime] + 1) % MOD;
    }
}

// Computes r[] for every vertex in the subtree of u (parent p).
// When keep is non-zero the subtree's contribution is left in d[]/ans for the
// caller to reuse; otherwise it is removed again before returning.
void calc(int u, int p, int keep) {
    // Pick the child with the largest subtree; its state will be kept.
    int heavy = -1;
    int heavy_size = 0;
    for (int v : adj[u]) {
        if (v != p && sz[v] > heavy_size) {
            heavy_size = sz[v];
            heavy = v;
        }
    }

    // Solve the other children first and discard their state, then solve the
    // largest child last so that its state is what remains in d[]/ans.
    for (int v : adj[u])
        if (v != p && v != heavy) calc(v, u, 0);
    if (heavy != -1) calc(heavy, u, 1);

    // Re-add every vertex of the smaller subtrees, then u itself.
    for (int v : adj[u]) {
        if (v == p || v == heavy) continue;
        for (int t = tin[v]; t <= tout[v]; ++t) upd(pos[t], 1);
    }
    upd(u, 1);
    r[u] = ans;

    if (!keep) {
        for (int t = tin[u]; t <= tout[u]; ++t) upd(pos[t], -1);
    }
}

// Reads the tree and the queries from DIVTREE.INP and writes the answers,
// separated by spaces, to DIVTREE.OUT.
int main() {
    if (!freopen("DIVTREE.INP", "r", stdin) ||
        !freopen("DIVTREE.OUT", "w", stdout)) {
        fprintf(stderr, "DIVTREE: cannot open DIVTREE.INP / DIVTREE.OUT\n");
        return 1;
    }
    ios_base::sync_with_stdio(false);
    cin.tie(nullptr);

    sieve();

    cin >> n >> q;
    for (int i = 1; i <= n; ++i) {
        int x;
        cin >> x;
        // Factorise x with the smallest-prime-factor table.
        while (x > 1) {
            const int p = mpr[x];
            int exponent = 0;
            do {
                x /= p;
                ++exponent;
            } while (mpr[x] == p);  // mpr[1] == 0, so this stops at x == 1
            fact[i].push_back({p, exponent});
        }
    }

    for (int i = 1; i < n; ++i) {
        int u, v;
        cin >> u >> v;
        adj[u].push_back(v);
        adj[v].push_back(u);
    }

    init(1, 0);
    calc(1, 0, 0);

    while (q--) {
        int x;
        cin >> x;
        cout << r[x] << ' ';
    }
    return 0;
}
Comments