Hướng dẫn giải của Bedao Grand Contest 12 - ELECTRIC
Nộp một lời giải chính thức trước khi tự giải là một hành động có thể bị ban.
Gọi ~sub[u]~ là tổng nhu cầu của các thành phố được quản lý bởi ~u~. Dễ thấy nếu xây nhà máy điện ở thủ đô ~1~ thì tổng lượng điện thất thoát là ~S_1=\sum_{v=2}^N sub[v]~. Để cập nhật ~sub~ nhanh chóng ta làm như sau: Gọi ~tin[u]~ và ~tout[u]~ lần lượt là thời điểm vào và thoát khi ta thực hiện ~dfs~ xuất phát từ ~1~, gọi một mảng ~b~ sao cho ~b[tin[u]]=a[u]~. Khi đó với mỗi ~u~ thì cây con gốc ~u~ được biểu diễn bởi đoạn liên tiếp từ ~tin[u]~ tới ~tout[u]~ trên mảng ~b~. Các thao tác cập nhật cây ta sẽ sử dụng segment tree cùng với lazy propagation trên các đoạn tương ứng.
Xét một đỉnh ~u~ mà nếu xây nhà máy ở đó thì có lượng điện thất thoát là ~S_u~. Xét ~v~ là các đỉnh được quản lý trực tiếp bởi ~u~ (~u~ kề ~v~ trên cây), dễ thấy rằng ~S_v=S_u+(sub[1]-sub[v]\cdot 2)\cdot d(u,v)~. Ta có ~S_v\lt S_u~ khi và chỉ khi ~sub[v]\cdot 2\gt sub[1]~ và chỉ có nhiều nhất một đỉnh ~v~ thỏa mãn điều kiện này. Từ đây ta có một thuật toán đúng là xuất phát từ đỉnh ~u=1~, lặp lại việc tìm một đỉnh ~v~ được quản lý trực tiếp bởi ~u~ sao cho ~sub[v]\cdot 2\gt sub[1]~ và thay ~u~ bởi ~v~ cho đến không tìm được ~v~ nào thỏa mãn. Đỉnh ~u~ cuối cùng sẽ là vị trí tối ưu.
Để tăng tốc thuật toán này ta có một nhận xét: Gọi ~p~ là vị trí nhỏ nhất sao cho ~\sum_{i=1}^p b[i] \cdot 2 \gt sub[1]~ và ~v~ sao cho ~tin[v]=p~, khi đó ~v~ phải bị quản lý bởi đỉnh ~u~ tối ưu.
Chứng minh: Nếu ~tout[u]\lt p~ hoặc ~tin[u]>p~, rõ ràng khi đó ~sub[u]\cdot 2=2\cdot \sum_{i=tin[u]}^{tout[u]}b[i]\le sub[1]~. Vì vậy ~tin[u]\le p\le tout[u]~, hay ~u~ quản lý ~v~.
Để tìm ~p~ ta sử dụng kĩ thuật tìm kiếm nhị phân trên segment tree. Xuất phát từ ~v~, để tìm ~u~ ta sử dụng kĩ thuật tương tự như bài toán LCA, tìm tổ tiên ~x~ của ~v~ cao nhất mà ~sub[x]\cdot 2\le sub[1]~, khi đó đáp án là ~x~ nếu ~x=1~ hoặc đỉnh cha của ~x~ nếu ~x\ne 1~.
Code mẫu
#include <bits/stdc++.h> using namespace std; using ll = long long; const int MAXN = 200005; const int LOG = 18; struct segmentTree { #define il i * 2 #define ir i * 2 + 1 vector <ll> val, tag; int lo, hi, N; ll res; void init(int N) { this->N = N; val.assign(4 * N, 0); tag.assign(4 * N, 0); } void add(int i, int l, int r, ll v) { if (l >= lo && r <= hi) { val[i] += v * (r - l + 1); tag[i] += v; return; } int m = (l + r) / 2; if (tag[i]) { val[il] += tag[i] * (m - l + 1); tag[il] += tag[i]; val[ir] += tag[i] * (r - m); tag[ir] += tag[i]; tag[i] = 0; } if (m >= lo) add(il, l, m, v); if (m < hi) add(ir, m + 1, r, v); val[i] = val[il] + val[ir]; } void add(int l, int r, ll v) { lo = l; hi = r; add(1, 1, N, v); } void get(int i, int l, int r) { if (l >= lo && r <= hi) return void(res += val[i]); int m = (l + r) / 2; if (tag[i]) { val[il] += tag[i] * (m - l + 1); tag[il] += tag[i]; val[ir] += tag[i] * (r - m); tag[ir] += tag[i]; tag[i] = 0; } if (m >= lo) get(il, l, m); if (m < hi) get(ir, m + 1, r); } ll get(int l, int r) { if (l > r) return 0; lo = l; hi = r; res = 0; get(1, 1, N); return res; } int find(ll v) { if (val[1] < v) return -1; int i = 1, l = 1, r = N; while (l < r) { int m = (l + r) / 2; if (tag[i]) { val[il] += tag[i] * (m - l + 1); tag[il] += tag[i]; val[ir] += tag[i] * (r - m); tag[ir] += tag[i]; tag[i] = 0; } if (val[il] >= v) { i = il; r = m; } else { v -= val[il]; i = ir; l = m + 1; } } return l; } }; int par[LOG][MAXN], timer; int tin[MAXN], tout[MAXN]; int bel[MAXN]; vector <int> adj[MAXN]; segmentTree ST; void DFSpre(int u) { for (int i = 1; i < LOG; i++) par[i][u] = par[i - 1][par[i - 1][u]]; bel[tin[u] = ++timer] = u; for (int v : adj[u]) if (v != par[0][u]) { par[0][v] = u; DFSpre(v); } tout[u] = timer; } ll getSum(int u) { return ST.get(tin[u], tout[u]); } int findCen() { ll half = (getSum(1) + 1) / 2; int u = bel[ST.find(half)]; if (getSum(u) >= half) return u; for (int i = LOG - 1; i >= 0; i--) if (par[i][u] && getSum(par[i][u]) < half) u = par[i][u]; return par[0][u]; } int main() { cin.tie(0)->sync_with_stdio(0); int N, Q; cin >> N >> Q; for (int i = 1; i < N; i++) { int u, v, w; cin >> u >> v >> w; adj[u].push_back(v); adj[v].push_back(u); } DFSpre(1); ST.init(N); for (int i = 1; i <= N; i++) { int c; cin >> c; ST.add(tin[i], tin[i], c); } while (Q--) { int cmd, u, w; cin >> cmd >> u >> w; if (cmd == 1) ST.add(tin[u], tin[u], w); else if (cmd == 2) ST.add(tin[u], tin[u], -w); else if (cmd == 3) ST.add(tin[u], tout[u], w); else if (cmd == 4) ST.add(tin[u], tout[u], -w); cout << findCen() << '\n'; } }
Bình luận