Editorial for Lại là bài truy vấn đường đi


Remember to use this editorial only when stuck, and not to copy-paste code from it. Please be respectful to the problem author and editorialist.
Submitting an official solution before solving the problem yourself is a bannable offence.

Chúng ta sẽ sử dụng LCA, HLD + segment tree (cây phân đoạn) để làm bài này.

Các bạn nếu chưa biết HLD có thể tìm đọc tại link sau đây: Link

Với mỗi truy vấn ~2~, ta gọi ~c = lca(a, b)~ là tổ tiên chung nhỏ nhất của ~a~ và ~b~. Dễ dàng chứng minh rằng đáp án là ~max(f(a, c), f(b, c))~ với f(x, y) là giá trị lớn nhất của một đỉnh trong đường đi từ ~x~ đến một tổ tiên ~y~ của nó.

Để tính được giá trị ~f~, ta sẽ sử dụng HLD để chia cây thành các đường đi sao cho với mỗi đỉnh cần tối đa ~O(log_2(n))~ đường đi để lên được gốc.

Với mỗi đường đi, ta lưu segment tree có chức năng update một đỉnh ở đường đi và truy vấn một đoạn trên đường đi đó.

Khi truy vấn ~f(x, y)~, ta làm như sau:

  • Nếu ~x~ và ~y~ nằm trên những đường đi khác nhau, ta lấy cập nhật giá trị lớn nhất của một đỉnh trên đường đi từ ~x~ đến ~z~ - đỉnh đầu tiên của đường đi chứa đỉnh ~x~ - vào đáp án, và nhảy ~x~ lên đến cha của ~z~.

  • Nếu ~x~ và ~y~ nằm trên những đường đi giống nhau, ta cập nhật giá trị lớn nhất của một đỉnh trên đường đi từ ~x~ đến ~y~ vào đáp án và kết thúc.

Tổng độ phức tạp: ~O(q * log_2(n)^2)~ - vì

  • Với truy vấn ~1~, ta chỉ cần update trên duy nhất một đường đi, và độ phức tạp là ~O(log_2(n))~.

  • Với truy vấn ~2~, ta cần truy cập tối đa ~O(log_2(n))~ đường đi (tính chất của HLD), và mỗi truy vấn ta mất ~O(log_2(n))~ để truy vấn.

#include<bits/stdc++.h>
using namespace std;

#define int long long
#define fi first
#define se second
#define pb push_back

typedef pair<int, int> ii;
typedef pair<ii, int> iii;
typedef pair<ii, ii> iiii;

const int N = 5e5 + 5;

const int oo = 1e18 + 7, mod = 1e9 + 7;

int n, q;
int v[N];

vector<int> Adj[N];

int sub[N];// size of subtree of nodes

int cnt_chain;
int no_chain[N], pos_in_chain[N];// for nodes
int chain_head[N], sz[N];// for chains

vector<int> contents[N];// nodes of hld
vector<int> IT[N];// segment tree for each chain

int depth[N], anc[N][20];// for LCA -> queries

void pre_dfs(int u, int p){
    anc[u][0] = p;
    for(int i = 1; i <= 19; i++) anc[u][i] = anc[anc[u][i - 1]][i - 1];
    sub[u] = 1;
    for(auto v : Adj[u]){
        if(v == p) continue;
        depth[v] = depth[u] + 1;
        pre_dfs(v, u);
        sub[u] += sub[v];
    }
}

void hld(int u, int p, bool nw){
    if(nw){
        cnt_chain++;
        no_chain[u] = cnt_chain;
        pos_in_chain[u] = 0;
        chain_head[cnt_chain] = u;
    }
    else{
        no_chain[u] = no_chain[p];
        pos_in_chain[u] = pos_in_chain[p] + 1;
    }
    contents[no_chain[u]].pb(u);
    sz[no_chain[u]]++;
    //cout << u << " " << no_chain[u] << " " << pos_in_chain[u] << "\n";
    ii mx = {-1, -1};
    for(auto v : Adj[u]){
        if(v == p) continue;
    //  cout << v << " " << sub[v] << "\n";
        mx = max(mx, {sub[v], v});
    }
    if(mx.fi == -1) return;
    for(auto v : Adj[u]){
        if(v == p) continue;
        hld(v, u, (v != mx.se));
    }
}

// both of this are for fastening the segtree as we have only a second
int pos_seg[N];// position in the segtree of the chain, when update we just for from that position upwards
vector<int> vis[N];// most of the time, you want to query a prefix except for O(q) time, this will help 

void prep(int chain, int id, int l, int r){
    if(l == r){
        vis[contents[chain][l]].pb(id);
        pos_seg[contents[chain][l]] = id;
        return;
    }
    int mid = (l + r) >> 1;
    prep(chain, id << 1, l, mid);
    for(int i = mid + 1; i <= r; i++) vis[contents[chain][i]].pb(id << 1);
    prep(chain, id << 1 | 1, mid + 1, r);
}

void upd(int node, int val){
    int temp = pos_seg[node];
    IT[no_chain[node]][temp] = val;
//cout << "OK " << node << " " << no_chain[node] << " " << temp << "\n";
    //return;
    temp >>= 1;
    for(; temp; temp >>= 1){
        IT[no_chain[node]][temp] = max(IT[no_chain[node]][temp << 1], IT[no_chain[node]][temp << 1 | 1]);
    }
}

int get(int node){// get from this node to the top of the chain
    int ans = 0;
    for(auto it : vis[node]){
        ans = max(ans, IT[no_chain[node]][it]);
    }
    return ans;
}

int get2(int chain, int id, int l, int r, int L, int R){
    if(l > R || r < L) return 0;
    if(l >= L && r <= R){
    //  cout << "OKOK " << chain << ' ' << id << ' ' << L << ' ' << R << ' ' << IT[chain][id] << '\n';
        return IT[chain][id];
    }
    int mid = (l + r) >> 1;
    return max(get2(chain, id << 1, l, mid, L, R), get2(chain, id << 1 | 1, mid + 1, r, L, R));
}

int lca(int x, int y){
    if(depth[x] > depth[y]) swap(x, y);
    int diff = depth[y] - depth[x];
    for(int i = 19; i >= 0; i--) if(diff & (1LL << i)) y = anc[y][i];
    if(x == y) return x;
    for(int i = 19; i >= 0; i--){
        if(anc[x][i] != anc[y][i]){
            x = anc[x][i], y = anc[y][i];
        }
    }
    return anc[x][0];
}

int climb(int x, int y){// climb from x to y
    int ans = 0;
    while(no_chain[x] != no_chain[y]){
        ans = max(ans, get(x));
        x = anc[chain_head[no_chain[x]]][0];
    }
    ans = max(ans, get2(no_chain[x], 1, 0, sz[no_chain[x]] - 1, pos_in_chain[y], pos_in_chain[x]));
    return ans;
}

void process(){
    cin >> n >> q;
    for(int i = 1; i <= n; i++) cin >> v[i];
    for(int i = 1; i < n; i++){
        int x, y;
        cin >> x >> y;
        Adj[x].pb(y);
        Adj[y].pb(x);
    }
    pre_dfs(1, 1);
//  return;
    hld(1, 1, 1);
//  return;
    for(int i = 1; i <= cnt_chain; i++){
        IT[i].resize((sz[i] << 2) + 5);
        prep(i, 1, 0, sz[i] - 1);
    }
//  return;
    for(int i = 1; i <= n; i++) upd(i, v[i]);
    //return;
    while(q--){
        int que;
        cin >> que;
        if(que == 1){
            int node, val;
            cin >> node >> val;
            v[node] = val;
            upd(node, v[node]);
        }
        else{
            int a, b;
            cin >> a >> b;
            int c = lca(a, b);
    //      cout << c << "\n";
            cout << max(climb(a, c), climb(b, c)) << " ";
        }
    }
}

signed main(){
    ios_base::sync_with_stdio(0);
    cin.tie(0);
    process();
}

Comments

Please read the guidelines before commenting.


There are no comments at the moment.