Editorial for Lại là bài truy vấn đường đi
Submitting an official solution before solving the problem yourself is a bannable offence.
Chúng ta sẽ sử dụng LCA, HLD + segment tree (cây phân đoạn) để làm bài này.
Các bạn nếu chưa biết HLD có thể tìm đọc tại link sau đây: Link
Với mỗi truy vấn ~2~, ta gọi ~c = lca(a, b)~ là tổ tiên chung nhỏ nhất của ~a~ và ~b~. Dễ dàng chứng minh rằng đáp án là ~max(f(a, c), f(b, c))~ với f(x, y) là giá trị lớn nhất của một đỉnh trong đường đi từ ~x~ đến một tổ tiên ~y~ của nó.
Để tính được giá trị ~f~, ta sẽ sử dụng HLD để chia cây thành các đường đi sao cho với mỗi đỉnh cần tối đa ~O(log_2(n))~ đường đi để lên được gốc.
Với mỗi đường đi, ta lưu segment tree có chức năng update một đỉnh ở đường đi và truy vấn một đoạn trên đường đi đó.
Khi truy vấn ~f(x, y)~, ta làm như sau:
Nếu ~x~ và ~y~ nằm trên những đường đi khác nhau, ta lấy cập nhật giá trị lớn nhất của một đỉnh trên đường đi từ ~x~ đến ~z~ - đỉnh đầu tiên của đường đi chứa đỉnh ~x~ - vào đáp án, và nhảy ~x~ lên đến cha của ~z~.
Nếu ~x~ và ~y~ nằm trên những đường đi giống nhau, ta cập nhật giá trị lớn nhất của một đỉnh trên đường đi từ ~x~ đến ~y~ vào đáp án và kết thúc.
Tổng độ phức tạp: ~O(q * log_2(n)^2)~ - vì
Với truy vấn ~1~, ta chỉ cần update trên duy nhất một đường đi, và độ phức tạp là ~O(log_2(n))~.
Với truy vấn ~2~, ta cần truy cập tối đa ~O(log_2(n))~ đường đi (tính chất của HLD), và mỗi truy vấn ta mất ~O(log_2(n))~ để truy vấn.
#include<bits/stdc++.h> using namespace std; #define int long long #define fi first #define se second #define pb push_back typedef pair<int, int> ii; typedef pair<ii, int> iii; typedef pair<ii, ii> iiii; const int N = 5e5 + 5; const int oo = 1e18 + 7, mod = 1e9 + 7; int n, q; int v[N]; vector<int> Adj[N]; int sub[N];// size of subtree of nodes int cnt_chain; int no_chain[N], pos_in_chain[N];// for nodes int chain_head[N], sz[N];// for chains vector<int> contents[N];// nodes of hld vector<int> IT[N];// segment tree for each chain int depth[N], anc[N][20];// for LCA -> queries void pre_dfs(int u, int p){ anc[u][0] = p; for(int i = 1; i <= 19; i++) anc[u][i] = anc[anc[u][i - 1]][i - 1]; sub[u] = 1; for(auto v : Adj[u]){ if(v == p) continue; depth[v] = depth[u] + 1; pre_dfs(v, u); sub[u] += sub[v]; } } void hld(int u, int p, bool nw){ if(nw){ cnt_chain++; no_chain[u] = cnt_chain; pos_in_chain[u] = 0; chain_head[cnt_chain] = u; } else{ no_chain[u] = no_chain[p]; pos_in_chain[u] = pos_in_chain[p] + 1; } contents[no_chain[u]].pb(u); sz[no_chain[u]]++; //cout << u << " " << no_chain[u] << " " << pos_in_chain[u] << "\n"; ii mx = {-1, -1}; for(auto v : Adj[u]){ if(v == p) continue; // cout << v << " " << sub[v] << "\n"; mx = max(mx, {sub[v], v}); } if(mx.fi == -1) return; for(auto v : Adj[u]){ if(v == p) continue; hld(v, u, (v != mx.se)); } } // both of this are for fastening the segtree as we have only a second int pos_seg[N];// position in the segtree of the chain, when update we just for from that position upwards vector<int> vis[N];// most of the time, you want to query a prefix except for O(q) time, this will help void prep(int chain, int id, int l, int r){ if(l == r){ vis[contents[chain][l]].pb(id); pos_seg[contents[chain][l]] = id; return; } int mid = (l + r) >> 1; prep(chain, id << 1, l, mid); for(int i = mid + 1; i <= r; i++) vis[contents[chain][i]].pb(id << 1); prep(chain, id << 1 | 1, mid + 1, r); } void upd(int node, int val){ int temp = pos_seg[node]; IT[no_chain[node]][temp] = val; //cout << "OK " << node << " " << no_chain[node] << " " << temp << "\n"; //return; temp >>= 1; for(; temp; temp >>= 1){ IT[no_chain[node]][temp] = max(IT[no_chain[node]][temp << 1], IT[no_chain[node]][temp << 1 | 1]); } } int get(int node){// get from this node to the top of the chain int ans = 0; for(auto it : vis[node]){ ans = max(ans, IT[no_chain[node]][it]); } return ans; } int get2(int chain, int id, int l, int r, int L, int R){ if(l > R || r < L) return 0; if(l >= L && r <= R){ // cout << "OKOK " << chain << ' ' << id << ' ' << L << ' ' << R << ' ' << IT[chain][id] << '\n'; return IT[chain][id]; } int mid = (l + r) >> 1; return max(get2(chain, id << 1, l, mid, L, R), get2(chain, id << 1 | 1, mid + 1, r, L, R)); } int lca(int x, int y){ if(depth[x] > depth[y]) swap(x, y); int diff = depth[y] - depth[x]; for(int i = 19; i >= 0; i--) if(diff & (1LL << i)) y = anc[y][i]; if(x == y) return x; for(int i = 19; i >= 0; i--){ if(anc[x][i] != anc[y][i]){ x = anc[x][i], y = anc[y][i]; } } return anc[x][0]; } int climb(int x, int y){// climb from x to y int ans = 0; while(no_chain[x] != no_chain[y]){ ans = max(ans, get(x)); x = anc[chain_head[no_chain[x]]][0]; } ans = max(ans, get2(no_chain[x], 1, 0, sz[no_chain[x]] - 1, pos_in_chain[y], pos_in_chain[x])); return ans; } void process(){ cin >> n >> q; for(int i = 1; i <= n; i++) cin >> v[i]; for(int i = 1; i < n; i++){ int x, y; cin >> x >> y; Adj[x].pb(y); Adj[y].pb(x); } pre_dfs(1, 1); // return; hld(1, 1, 1); // return; for(int i = 1; i <= cnt_chain; i++){ IT[i].resize((sz[i] << 2) + 5); prep(i, 1, 0, sz[i] - 1); } // return; for(int i = 1; i <= n; i++) upd(i, v[i]); //return; while(q--){ int que; cin >> que; if(que == 1){ int node, val; cin >> node >> val; v[node] = val; upd(node, v[node]); } else{ int a, b; cin >> a >> b; int c = lca(a, b); // cout << c << "\n"; cout << max(climb(a, c), climb(b, c)) << " "; } } } signed main(){ ios_base::sync_with_stdio(0); cin.tie(0); process(); }
Comments