Hướng dẫn giải của Vertex Set Path Composite


Chỉ dùng lời giải này khi không có ý tưởng, và đừng copy-paste code từ lời giải này. Hãy tôn trọng người ra đề và người viết lời giải.
Nộp một lời giải chính thức trước khi tự giải là một hành động có thể bị ban.

Nhận xét:

  • Ta thấy, hàm hợp cần tính không có tính chất giao hoán. Tức là đáp án cho truy vấn ~1~ ~u~ ~v~ ~x~ khác với đáp án cho truy vấn ~1~ ~v~ ~u~ ~x~.

  • Trước hết ta xem định nghĩa ký hiệu hàm hợp sau: ~g(f(x))~ = ~(g \circ f)(x)~. Nhận thấy hàm hợp cần tính có tính chất kết hợp. Ví dụ ~f_u(f_v(f_t(x)))~ = ~(f_u \circ f_v)(f_t(x))~ = ~a_u * a_v * f_t(x) + a_u * b_t + b_u~.

  • Tổng quát hơn ~f_k(...f_2(f_1(x)))~ = ~(f_k \circ f_{k - 1} \circ \cdots \circ f_i)~ ~(f_{i - 1} \circ \cdots \circ f_2 \circ f_1)~ ~(x)~. Do đó, ta có thể dùng cây phân đoạn để tính hàm hợp này.

Ta giải quyết bài toán bằng HLD, chúng ta có thể chia đường dẫn từ nút ~u~ đến nút ~v~ thành hai đường dẫn: từ nút ~u~ đến LCA~(u,v)~ và từ LCA~(u,v)~ đến nút ~v~.

Sau đó chúng ta có thể duy trì hai cây phân đoạn. Một cây tính toán hàm hợp tương đương khi di chuyển từ nút ~u~ đến LCA~(u,v)~. Cây còn lại tính toán hàm hợp tương đương khi di chuyển từ LCA~(u,v)~ tới nút ~v~. Cuối cùng, chúng ta tính cả hàm hợp với giá trị ~x~.

Cần lưu ý rằng hàm tại nút LCA~(u,v)~ chỉ nên được tính một lần.

Độ phức tạp về thời gian: ~\mathcal{O}(N \log^2 N)~.

#include <bits/stdc++.h>
#define fi first
#define se second
#define mp make_pair
using namespace std;
template <typename T1, typename T2> bool mini(T1 &a, T2 b) {
    if (a > b) {a = b; return true;} return false;
}
template <typename T1, typename T2> bool maxi(T1 &a, T2 b) {
    if (a < b) {a = b; return true;} return false;
}
const int N = 2e5 + 5; 
const int oo = 1e9;
const int mod = 998244353;

pair <int, int> operator + (pair <int, int> a, pair <int, int> b) {
    pair <int, int> res;
    res.fi = (long long) a.fi * b.fi % mod;
    res.se = ((long long) a.fi * b.se + a.se) % mod;
    return res;
}

struct segtree {
    int n;  
    vector <pair <int, int>> st;

    segtree(int _n = 0) {
        n = _n;
        st.assign((n << 2) + 5, mp(0, 0));
    }

    void update(int id, int l, int r, int pos, pair <int, int> val) {
        if (l > pos || r < pos)
            return;
        if (l == r) 
            return (void) (st[id] = val);
        int mid = (l + r) >> 1;
        update(id << 1, l, mid, pos, val);
        update(id << 1 | 1, mid + 1, r, pos, val);
        st[id] = st[id << 1] + st[id << 1 | 1];
    }

    void update(int pos, pair <int, int> val) {
        update(1, 1, n, pos, val);
    }

    pair <int, int> getsum(int id, int l, int r, int u, int v) {
        if (l > v || r < u)
            return mp(1, 0);
        if (l >= u && r <= v)
            return st[id];
        int mid = (l + r) >> 1;
        pair <int, int> lnode = getsum(id << 1, l, mid, u, v);
        pair <int, int> rnode = getsum(id << 1 | 1, mid + 1, r, u, v);
        return lnode + rnode;
    }

    pair <int, int> getsum(int l, int r) {
        if (l > r)
            swap(l, r);
        return getsum(1, 1, n, l, r);
    }
};  

segtree st, rev;

vector <int> adj[N];

pair <int, int> a[N];

int revpos[N];
int par[N];
int pos[N];
int sz[N];
int d[N];
int r[N];
int n,q,timer;

void dfs(int u, int p = -1) {
    par[u] = p;
    sz[u] = 1;
    for (int v : adj[u]) if (v != p) {
        d[v] = d[u] + 1;
        dfs(v, u);
        sz[u] += sz[v];
    }
}

void build(int u, int root, int p = -1) {
    pos[u] = ++timer;
    r[u] = root;
    int hv = 0;
    for (int v : adj[u]) 
        if (v != p && sz[hv] < sz[v])
            hv = v;
    if (hv == 0)
        return;
    build(hv, root, u);
    for (int v : adj[u])
        if (v != p && v != hv)
            build(v, v, u);
}   

int getlca(int u, int v) {
    while (r[u] != r[v]) {
        if (d[r[u]] < d[r[v]])
            swap(u, v);
        u = par[r[u]];
    }
    if (d[u] > d[v])
        swap(u, v);
    return u;
}

int main() {
    ios_base::sync_with_stdio(0);
    cin.tie(0);
    cin >> n >> q;
    for (int i = 1; i <= n; i++)
        cin >> a[i].fi >> a[i].se;

    for (int i = 1; i < n; i++) {
        int u,v; cin >> u >> v;
        adj[u].push_back(v);
        adj[v].push_back(u);
    }
    dfs(1);
    build(1, 1);
    st = rev = segtree(n);

    for (int i = 1; i <= n; i++) {
        revpos[i] = n - pos[i] + 1;
        st.update(pos[i], a[i]);
        rev.update(revpos[i], a[i]);
    }
    while (q--) {
        int op;
        cin >> op;
        if (op == 0) {
            int u,x,y;
            cin >> u >> x >> y;
            st.update(pos[u], mp(x, y));
            rev.update(revpos[u], mp(x, y)); 
        } else {
            int u,v,x;
            cin >> u >> v >> x;
            int p = getlca(u, v);
            vector <pair <int, int>> f, g;
            while (d[r[u]] > d[r[p]]) {
                g.push_back(st.getsum(pos[r[u]], pos[u]));
                u = par[r[u]];
            }
            g.push_back(st.getsum(pos[p], pos[u]));
            reverse(g.begin(), g.end());

            while (d[r[v]] > d[r[p]]) {
                f.push_back(rev.getsum(revpos[r[v]], revpos[v]));
                v = par[r[v]];
            }

            if (v != p) 
                f.push_back(rev.getsum(revpos[p] - 1, revpos[v]));

            pair <int, int> res = mp(1, 0);
            for (pair <int, int> x : f) 
                res = res + x;

            for (pair <int, int> x : g) 
                res = res + x;


            cout << ((long long) res.fi * x + res.se) % mod << "\n";
        }
    }
    return 0;
}

Bình luận

Hãy đọc nội quy trước khi bình luận.


Không có bình luận tại thời điểm này.