Hướng dẫn giải của Vertex Set Path Composite
Nộp một lời giải chính thức trước khi tự giải là một hành động có thể bị ban.
Nhận xét:
Ta thấy, hàm hợp cần tính không có tính chất giao hoán. Tức là đáp án cho truy vấn ~1~ ~u~ ~v~ ~x~ khác với đáp án cho truy vấn ~1~ ~v~ ~u~ ~x~.
Trước hết ta xem định nghĩa ký hiệu hàm hợp sau: ~g(f(x))~ = ~(g \circ f)(x)~. Nhận thấy hàm hợp cần tính có tính chất kết hợp. Ví dụ ~f_u(f_v(f_t(x)))~ = ~(f_u \circ f_v)(f_t(x))~ = ~a_u * a_v * f_t(x) + a_u * b_t + b_u~.
Tổng quát hơn ~f_k(...f_2(f_1(x)))~ = ~(f_k \circ f_{k - 1} \circ \cdots \circ f_i)~ ~(f_{i - 1} \circ \cdots \circ f_2 \circ f_1)~ ~(x)~. Do đó, ta có thể dùng cây phân đoạn để tính hàm hợp này.
Ta giải quyết bài toán bằng HLD, chúng ta có thể chia đường dẫn từ nút ~u~ đến nút ~v~ thành hai đường dẫn: từ nút ~u~ đến LCA~(u,v)~ và từ LCA~(u,v)~ đến nút ~v~.
Sau đó chúng ta có thể duy trì hai cây phân đoạn. Một cây tính toán hàm hợp tương đương khi di chuyển từ nút ~u~ đến LCA~(u,v)~. Cây còn lại tính toán hàm hợp tương đương khi di chuyển từ LCA~(u,v)~ tới nút ~v~. Cuối cùng, chúng ta tính cả hàm hợp với giá trị ~x~.
Cần lưu ý rằng hàm tại nút LCA~(u,v)~ chỉ nên được tính một lần.
Độ phức tạp về thời gian: ~\mathcal{O}(N \log^2 N)~.
#include <bits/stdc++.h> #define fi first #define se second #define mp make_pair using namespace std; template <typename T1, typename T2> bool mini(T1 &a, T2 b) { if (a > b) {a = b; return true;} return false; } template <typename T1, typename T2> bool maxi(T1 &a, T2 b) { if (a < b) {a = b; return true;} return false; } const int N = 2e5 + 5; const int oo = 1e9; const int mod = 998244353; pair <int, int> operator + (pair <int, int> a, pair <int, int> b) { pair <int, int> res; res.fi = (long long) a.fi * b.fi % mod; res.se = ((long long) a.fi * b.se + a.se) % mod; return res; } struct segtree { int n; vector <pair <int, int>> st; segtree(int _n = 0) { n = _n; st.assign((n << 2) + 5, mp(0, 0)); } void update(int id, int l, int r, int pos, pair <int, int> val) { if (l > pos || r < pos) return; if (l == r) return (void) (st[id] = val); int mid = (l + r) >> 1; update(id << 1, l, mid, pos, val); update(id << 1 | 1, mid + 1, r, pos, val); st[id] = st[id << 1] + st[id << 1 | 1]; } void update(int pos, pair <int, int> val) { update(1, 1, n, pos, val); } pair <int, int> getsum(int id, int l, int r, int u, int v) { if (l > v || r < u) return mp(1, 0); if (l >= u && r <= v) return st[id]; int mid = (l + r) >> 1; pair <int, int> lnode = getsum(id << 1, l, mid, u, v); pair <int, int> rnode = getsum(id << 1 | 1, mid + 1, r, u, v); return lnode + rnode; } pair <int, int> getsum(int l, int r) { if (l > r) swap(l, r); return getsum(1, 1, n, l, r); } }; segtree st, rev; vector <int> adj[N]; pair <int, int> a[N]; int revpos[N]; int par[N]; int pos[N]; int sz[N]; int d[N]; int r[N]; int n,q,timer; void dfs(int u, int p = -1) { par[u] = p; sz[u] = 1; for (int v : adj[u]) if (v != p) { d[v] = d[u] + 1; dfs(v, u); sz[u] += sz[v]; } } void build(int u, int root, int p = -1) { pos[u] = ++timer; r[u] = root; int hv = 0; for (int v : adj[u]) if (v != p && sz[hv] < sz[v]) hv = v; if (hv == 0) return; build(hv, root, u); for (int v : adj[u]) if (v != p && v != hv) build(v, v, u); } int getlca(int u, int v) { while (r[u] != r[v]) { if (d[r[u]] < d[r[v]]) swap(u, v); u = par[r[u]]; } if (d[u] > d[v]) swap(u, v); return u; } int main() { ios_base::sync_with_stdio(0); cin.tie(0); cin >> n >> q; for (int i = 1; i <= n; i++) cin >> a[i].fi >> a[i].se; for (int i = 1; i < n; i++) { int u,v; cin >> u >> v; adj[u].push_back(v); adj[v].push_back(u); } dfs(1); build(1, 1); st = rev = segtree(n); for (int i = 1; i <= n; i++) { revpos[i] = n - pos[i] + 1; st.update(pos[i], a[i]); rev.update(revpos[i], a[i]); } while (q--) { int op; cin >> op; if (op == 0) { int u,x,y; cin >> u >> x >> y; st.update(pos[u], mp(x, y)); rev.update(revpos[u], mp(x, y)); } else { int u,v,x; cin >> u >> v >> x; int p = getlca(u, v); vector <pair <int, int>> f, g; while (d[r[u]] > d[r[p]]) { g.push_back(st.getsum(pos[r[u]], pos[u])); u = par[r[u]]; } g.push_back(st.getsum(pos[p], pos[u])); reverse(g.begin(), g.end()); while (d[r[v]] > d[r[p]]) { f.push_back(rev.getsum(revpos[r[v]], revpos[v])); v = par[r[v]]; } if (v != p) f.push_back(rev.getsum(revpos[p] - 1, revpos[v])); pair <int, int> res = mp(1, 0); for (pair <int, int> x : f) res = res + x; for (pair <int, int> x : g) res = res + x; cout << ((long long) res.fi * x + res.se) % mod << "\n"; } } return 0; }
Bình luận