Hướng dẫn giải của Path Queries


Chỉ dùng lời giải này khi không có ý tưởng, và đừng copy-paste code từ lời giải này. Hãy tôn trọng người ra đề và người viết lời giải.
Nộp một lời giải chính thức trước khi tự giải là một hành động có thể bị ban.

Subtask 1

Với các truy vấn loại ~2~, ta chỉ cần dùng các thuật toán DFS, BFS,.. để tìm ra tổng giá trị trên đường đi từ nút ~1~ đến nút ~s~.

Độ phức tạp : ~O(n \cdot q)~.

Subtask 2

Nhận xét : Sau khi thay đổi giá trị của một nút ~s~, ta nhận thấy rằng chỉ có những nút thuộc cây con gốc nút ~s~ mới bị ảnh hưởng tới.

Ta chuẩn bị trước ~f[i]~ ~(1 \le i \le n)~ là tổng giá trị từ nút ~1~ đến nút ~i~.

Gọi ~g[i]~ ~(1 \le i \le n)~ là tổng giá trị ảnh hưởng của các nút khác đến nút thứ i. Ta có thể duy trì mảng này bằng cây phân đoạn.

  • Với mỗi truy vấn ~1~ ~s~ ~x~ : Ta cập nhật các ~g[u]~ với ~u~ là nút thuộc cây con của nút ~s~.

  • Với mỗi truy vấn ~2~ ~s~ : Kết quả của ta sẽ là ~f[s] + g[s]~.

Độ phức tạp : ~O(q \cdot log(n))~

#include <bits/stdc++.h>
using namespace std;

#define int ll
#define vt vector
#define TASK ""
#define pb push_back
#define all(x) begin(x), end(x)
#define sz(x) (int)(x).size()
#define f first
#define s second
#define mp make_pair
#define MASK(i) (1ll << (i))
#define BIT(x, i) (((x) >> (i)) & 1ll)

#define FOR(i, l, r) for (int i = (l); i <= (r); ++i)
#define rep(i, r, l) for (int i = (r); i >= (l); --i)

typedef long long ll;
typedef pair<int,int> ii;
const int INF = 1e9;
const int mod = 1e9 + 7;
const int N = 200005;
const int LOG = 20;

void setIO() {
    ios_base::sync_with_stdio(0); cin.tie(0); cout.tie(0);
}
/***------------------------- END ---------------------------***/
int n, St[N], Ft[N], it[4*N], dt[4*N], a[N], id;
int H[N], Len[N], Par[N][LOG];
vector<int> adj[N];

void dfs(int u)
{
    St[u] = ++id;
    for (auto v : adj[u])
        if (v != Par[u][0])
        {
            Par[v][0] = u;
            H[v] = H[u]+1;
            Len[v] = Len[u] + a[v];
            dfs(v);
        }
    Ft[u] = id;
}

int LCA(int u, int v) {
    if (H[u] < H[v]) return LCA(v, u);
    rep(j, LOG, 0)
        if (H[Par[u][j]] >= H[v]) u = Par[u][j];
    if (u == v) return u;
    rep(j, LOG, 0)
        if (Par[u][j] != Par[v][j]) {
            u = Par[u][j];
            v = Par[v][j];
        }
    return Par[u][0];
}

void Update (int r, int k, int l, int u, int v, int val) {
    if (v < k || u > l)
     return;
    if (u <= k && l <= v) {
        dt[r] += val;
        return;
    }
    int g = (k + l) / 2;
    dt[2 * r] += dt[r];
    dt[2 * r + 1] += dt[r];
    dt[r] = 0;

    Update (2 * r, k, g, u, v, val);
    Update (2 * r + 1, g + 1, l, u, v, val);

    it[r] = (it[2 * r] + (g - k + 1) * dt[2 * r]) +
                (it[2 * r + 1] + (l - g) * dt[2 * r + 1]);
}

int Get (int r, int k, int l, int u, int v) {
    if (v < k || u > l)
        return 0;
    if (u <= k && l <= v)
        return it[r] + (l - k + 1) * dt[r];

    int g = (k + l) / 2;
    dt[2 * r] += dt[r];
    dt[2 * r + 1] += dt[r];
    dt[r] = 0;
    int tL = Get (2 * r, k, g, u, v);
    int tR = Get (2 * r + 1, g + 1, l, u, v);

    it[r] = (it[2 * r] + (g - k + 1) * dt[2 * r]) +
                (it[2 * r + 1] + (l - g) * dt[2 * r + 1]);
    return tL + tR;
}


void test_case(int x) {
    if (x == 1) {
        int u, val;
        cin >> u >> val;
        Update(1, 1, n, St[u], Ft[u], val - a[u]);
        a[u] = val;
        return;
    }
    int u;
    cin >> u;
    cout << Len[u] + Get(1, 1, n, St[u], St[u]) << "\n";
}

signed main() {
    setIO();
    int TC = 1;
    cin >> n >> TC;
    FOR(i, 1, n)
        cin >> a[i];

    FOR(i, 2, n) {
        int u, v;
        cin >> u >> v;
        adj[u].pb(v);
        adj[v].pb(u);
    }
    Len[1] = a[1];
    Par[1][0] = 1;
    dfs(1);
    FOR(j, 1, LOG)
        FOR(i, 1 , n) Par[i][j] = Par[Par[i][j-1]][j-1];

    while (TC--) {
        int x;
        cin >> x;
        test_case(x);
    }

    return 0;
}

Bình luận

Hãy đọc nội quy trước khi bình luận.


Không có bình luận tại thời điểm này.