Hướng dẫn giải của Lại là bài truy vấn đường đi
Nộp một lời giải chính thức trước khi tự giải là một hành động có thể bị ban.
Chúng ta sẽ sử dụng LCA, HLD + segment tree (cây phân đoạn) để làm bài này.
Các bạn nếu chưa biết HLD có thể tìm đọc tại link sau đây: Link
Với mỗi truy vấn ~2~, ta gọi ~c = lca(a, b)~ là tổ tiên chung nhỏ nhất của ~a~ và ~b~. Dễ dàng chứng minh rằng đáp án là ~max(f(a, c), f(b, c))~ với f(x, y) là giá trị lớn nhất của một đỉnh trong đường đi từ ~x~ đến một tổ tiên ~y~ của nó.
Để tính được giá trị ~f~, ta sẽ sử dụng HLD để chia cây thành các đường đi sao cho với mỗi đỉnh cần tối đa ~O(log_2(n))~ đường đi để lên được gốc.
Với mỗi đường đi, ta lưu segment tree có chức năng update một đỉnh ở đường đi và truy vấn một đoạn trên đường đi đó.
Khi truy vấn ~f(x, y)~, ta làm như sau:
Nếu ~x~ và ~y~ nằm trên những đường đi khác nhau, ta lấy cập nhật giá trị lớn nhất của một đỉnh trên đường đi từ ~x~ đến ~z~ - đỉnh đầu tiên của đường đi chứa đỉnh ~x~ - vào đáp án, và nhảy ~x~ lên đến cha của ~z~.
Nếu ~x~ và ~y~ nằm trên những đường đi giống nhau, ta cập nhật giá trị lớn nhất của một đỉnh trên đường đi từ ~x~ đến ~y~ vào đáp án và kết thúc.
Tổng độ phức tạp: ~O(q * log_2(n)^2)~ - vì
Với truy vấn ~1~, ta chỉ cần update trên duy nhất một đường đi, và độ phức tạp là ~O(log_2(n))~.
Với truy vấn ~2~, ta cần truy cập tối đa ~O(log_2(n))~ đường đi (tính chất của HLD), và mỗi truy vấn ta mất ~O(log_2(n))~ để truy vấn.
#include<bits/stdc++.h> using namespace std; #define int long long #define fi first #define se second #define pb push_back typedef pair<int, int> ii; typedef pair<ii, int> iii; typedef pair<ii, ii> iiii; const int N = 5e5 + 5; const int oo = 1e18 + 7, mod = 1e9 + 7; int n, q; int v[N]; vector<int> Adj[N]; int sub[N];// size of subtree of nodes int cnt_chain; int no_chain[N], pos_in_chain[N];// for nodes int chain_head[N], sz[N];// for chains vector<int> contents[N];// nodes of hld vector<int> IT[N];// segment tree for each chain int depth[N], anc[N][20];// for LCA -> queries void pre_dfs(int u, int p){ anc[u][0] = p; for(int i = 1; i <= 19; i++) anc[u][i] = anc[anc[u][i - 1]][i - 1]; sub[u] = 1; for(auto v : Adj[u]){ if(v == p) continue; depth[v] = depth[u] + 1; pre_dfs(v, u); sub[u] += sub[v]; } } void hld(int u, int p, bool nw){ if(nw){ cnt_chain++; no_chain[u] = cnt_chain; pos_in_chain[u] = 0; chain_head[cnt_chain] = u; } else{ no_chain[u] = no_chain[p]; pos_in_chain[u] = pos_in_chain[p] + 1; } contents[no_chain[u]].pb(u); sz[no_chain[u]]++; //cout << u << " " << no_chain[u] << " " << pos_in_chain[u] << "\n"; ii mx = {-1, -1}; for(auto v : Adj[u]){ if(v == p) continue; // cout << v << " " << sub[v] << "\n"; mx = max(mx, {sub[v], v}); } if(mx.fi == -1) return; for(auto v : Adj[u]){ if(v == p) continue; hld(v, u, (v != mx.se)); } } // both of this are for fastening the segtree as we have only a second int pos_seg[N];// position in the segtree of the chain, when update we just for from that position upwards vector<int> vis[N];// most of the time, you want to query a prefix except for O(q) time, this will help void prep(int chain, int id, int l, int r){ if(l == r){ vis[contents[chain][l]].pb(id); pos_seg[contents[chain][l]] = id; return; } int mid = (l + r) >> 1; prep(chain, id << 1, l, mid); for(int i = mid + 1; i <= r; i++) vis[contents[chain][i]].pb(id << 1); prep(chain, id << 1 | 1, mid + 1, r); } void upd(int node, int val){ int temp = pos_seg[node]; IT[no_chain[node]][temp] = val; //cout << "OK " << node << " " << no_chain[node] << " " << temp << "\n"; //return; temp >>= 1; for(; temp; temp >>= 1){ IT[no_chain[node]][temp] = max(IT[no_chain[node]][temp << 1], IT[no_chain[node]][temp << 1 | 1]); } } int get(int node){// get from this node to the top of the chain int ans = 0; for(auto it : vis[node]){ ans = max(ans, IT[no_chain[node]][it]); } return ans; } int get2(int chain, int id, int l, int r, int L, int R){ if(l > R || r < L) return 0; if(l >= L && r <= R){ // cout << "OKOK " << chain << ' ' << id << ' ' << L << ' ' << R << ' ' << IT[chain][id] << '\n'; return IT[chain][id]; } int mid = (l + r) >> 1; return max(get2(chain, id << 1, l, mid, L, R), get2(chain, id << 1 | 1, mid + 1, r, L, R)); } int lca(int x, int y){ if(depth[x] > depth[y]) swap(x, y); int diff = depth[y] - depth[x]; for(int i = 19; i >= 0; i--) if(diff & (1LL << i)) y = anc[y][i]; if(x == y) return x; for(int i = 19; i >= 0; i--){ if(anc[x][i] != anc[y][i]){ x = anc[x][i], y = anc[y][i]; } } return anc[x][0]; } int climb(int x, int y){// climb from x to y int ans = 0; while(no_chain[x] != no_chain[y]){ ans = max(ans, get(x)); x = anc[chain_head[no_chain[x]]][0]; } ans = max(ans, get2(no_chain[x], 1, 0, sz[no_chain[x]] - 1, pos_in_chain[y], pos_in_chain[x])); return ans; } void process(){ cin >> n >> q; for(int i = 1; i <= n; i++) cin >> v[i]; for(int i = 1; i < n; i++){ int x, y; cin >> x >> y; Adj[x].pb(y); Adj[y].pb(x); } pre_dfs(1, 1); // return; hld(1, 1, 1); // return; for(int i = 1; i <= cnt_chain; i++){ IT[i].resize((sz[i] << 2) + 5); prep(i, 1, 0, sz[i] - 1); } // return; for(int i = 1; i <= n; i++) upd(i, v[i]); //return; while(q--){ int que; cin >> que; if(que == 1){ int node, val; cin >> node >> val; v[node] = val; upd(node, v[node]); } else{ int a, b; cin >> a >> b; int c = lca(a, b); // cout << c << "\n"; cout << max(climb(a, c), climb(b, c)) << " "; } } } signed main(){ ios_base::sync_with_stdio(0); cin.tie(0); process(); }
Bình luận