Hướng dẫn giải của Bedao Regular Contest 19 - Triệu tập quân đội


Chỉ dùng lời giải này khi không có ý tưởng, và đừng copy-paste code từ lời giải này. Hãy tôn trọng người ra đề và người viết lời giải.
Nộp một lời giải chính thức trước khi tự giải là một hành động có thể bị ban.

Dễ thấy công thức quy hoạch động cho subtask ~1~ sẽ là:

  • ~S(u) = a_u \times (1 + S(v_1) + S(v_2) + ... + S(v_k))~, với ~v_i~ là các cấp dưới trực tiếp của ~u~.

Tuy vậy, để làm được bài toán với truy vấn update, ta cần nhìn nhận ra việc hàm ~S(u)~ chính bằng tổng của tích các ~a_i~ thuộc đường đi xuất phát từ ~u~.

Nói cách khác, gọi ~path(u,v)~ là tập các đỉnh trên đường đi từ ~u~ tới ~v~, gọi ~t(u,v) = \prod_{x\in path(u,v)} a_x~. Ta có:

~S(u) = \sum_{v \in subtree(u)} t(u,v)~

~\rightarrow t(1,p_u) \times S(u) = t(1,p_u) \times \sum_{v \in subtree(u)} t(u,v)~

Mà ~t(1,p_u) \times t(u,v) = t(1, v)~ với ~v \in subtree(u)~

~\rightarrow t(1,p_u) \times S(u) = \sum_{v \in subtree(u)} t(1,v)~

~\rightarrow S(u) = \frac{\sum_{v \in subtree(u)} t(1,v)}{t(1,p_u)}~

Hàm ~t(1,v)~ có thể viết lại đơn giản thành ~h(v)~ và dễ dàng tính được từ trước.

Vậy ~S(u) = \frac{\sum_{v \in subtree(u)} h(v)}{h(p_u)}~

Để tiến hành các truy vấn update và hỏi giá trị, ta sẽ trải phẳng cây ra theo Euler Tour. Ta sẽ dùng cây Segment Tree kiểm soát tổng các ~h(v)~ trên đoạn và xử lí các truy vấn như sau:

  • ~1~ ~u~ ~x~: Ta có thể dùng lazy segment tree để nhân các ~h(v)~ với ~v \in subtree(u)~ với giá trị ~\frac{x}{a_u}~.

  • ~2~ ~u~: Lấy tổng các ~h(v)~ thuộc ~subtree(u)~ và chia cho ~h(p_u)~.

Lưu ý, ta cần sử dụng thêm nghịch đảo module cho cả hai truy vấn.

#include<bits/stdc++.h>
using namespace std;
#define int long long
const int mod = 1e9+7;
const int N = 2e5+5;
inline void mul (int &a, int b) {
    a *= b;
    a %= mod;
}
inline int cdt (int a, int b) {
    int res = 1;
    while (b) {
        if (b & 1) res = res*a%mod;
        a = a*a%mod;
        b >>=1;
    }
    return res;
}
inline int inv (int a) {
    return cdt(a,mod-2);
}
int n,q;
vector<int> adj[N];
int a[N],h[N];
int ti[N],to[N],pos[N],cnt;
int f[N];
void dfs (int u, int p) {
    h[u] = h[p]*a[u]%mod;
    ti[u] = ++cnt;
    pos[cnt] = u;
    f[u] = p;
    for (auto v : adj[u]) {
        if (v == p) continue;
        dfs(v,u);
    }
    to[u] = cnt;
}
struct segtri {
    int n;
    vector<int> node,lazy;
    segtri () {}
    void build (int id, int l, int r) {
        if (l == r) {
            node[id] = h[pos[l]];
            return;
        }
        int mid = (r+l) >> 1;
        build(2*id,l,mid);
        build(2*id+1,mid+1,r);
        node[id] = (node[2*id]+node[2*id+1])%mod;
    }
    segtri (int n) :n(n) {
        node.resize(4*n+5);
        lazy.assign(4*n+5,1);
        build(1,1,n);
    }
    void down (int id) {
        int val = lazy[id];
        mul(node[2*id],val);
        mul(node[2*id+1],val);
        mul(lazy[2*id],val);
        mul(lazy[2*id+1],val);
        lazy[id] = 1;
    }
    void update (int id, int l, int r, int u, int v, int val) {
        if (l > v || r < u) return;
        if (l>=u&&r<=v) {
            mul(node[id],val);
            mul(lazy[id],val);
            return;
        }
        if (lazy[id] != 1) down(id);
        int mid = (r+l) >> 1;
        update (2*id,l,mid,u,v,val);
        update (2*id+1,mid+1,r,u,v,val);
        node[id] = (node[2*id]+node[2*id+1])%mod;
    }
    void update (int l, int r, int val) {
        update (1,1,n,l,r,val);
    }
    int get (int id, int l, int r, int u, int v) {
        if (l > v || r < u) return 0;
        if (l >= u && r <= v) return node[id];
        if (lazy[id] != 1) down(id);
        int mid = (r+l) >> 1;
        return (get(2*id,l,mid,u,v) + get(2*id+1,mid+1,r,u,v))%mod;
    }
    int get (int l, int r) {
        return get (1,1,n,l,r);
    }
}st;
signed main() {
    ios_base::sync_with_stdio(false);
    cin.tie(nullptr);
    cin >> n >> q;
    for (int i=2; i<=n; i++) {
        int x; cin >> x;
        adj[x].push_back(i);
    }
    for (int i=1; i<=n; i++) {
        cin >> a[i];
    }
    h[0] = 1;
    dfs(1,0);
    st = segtri(n);
    while (q--) {
        int t;
        cin >> t;
        if (t == 1) {
            int u,x; cin >> u >> x;
            st.update(ti[u],to[u],x*inv(a[u])%mod);
            a[u] = x;
        }
        else {
            int u; cin >> u;
            int res = st.get(ti[u],to[u]);
            if (u != 1) res *= inv(st.get(ti[f[u]],ti[f[u]]));
            res %= mod;
            cout << res << endl;
        }
    }
}

Bình luận

Hãy đọc nội quy trước khi bình luận.


Không có bình luận tại thời điểm này.