Hướng dẫn giải của Bedao OI Contest 6 - Đếm cây


Chỉ dùng lời giải này khi không có ý tưởng, và đừng copy-paste code từ lời giải này. Hãy tôn trọng người ra đề và người viết lời giải.
Nộp một lời giải chính thức trước khi tự giải là một hành động có thể bị ban.
#include <bits/stdc++.h>
using namespace std;

#define fi first
#define se second
#define _size(x) (int)x.size()
#define BIT(i, x) ((x >> i) & 1)
#define MASK(n) ((1 << n) - 1)
#define REP(i, n) for (int i = 0, _n = (n); i < _n; ++i)
#define FOR(i, a, b) for (int i = a, _b = (b); i <= _b; ++i)
#define FORD(i, a, b) for (int i = a, _b = (b); i >= _b; --i)
#define FORB1(i, mask) for (int i = mask; i > 0; i ^= i & -i)
#define FORB0(n, i, mask) for (int i = ((1 << n) - 1) ^ mask; i > 0; i ^= i & - i)
#define FORALL(i, a) for (auto i: a)
#define fastio ios_base::sync_with_stdio(0); cin.tie(0);

const int base = 202481;
const int mod = 1e9 + 7;

struct segNode {

    int l, r;
};

vector<int> pw;

struct segTree {

    struct STNode {

        int sum, sumR, sz = 1;

        STNode operator + (STNode o) const {

            return {
                (1LL * sum * pw[o.sz] + o.sum) % mod,
                (1LL * o.sumR * pw[sz] + sumR) % mod,
                sz + o.sz
            };
        }
    };

    int n;
    vector<int> L, R;
    vector<STNode> ST;

    void init(int sz) {

        n = sz;
        ST.resize(2 * n + 4);
        L.resize(2 * n + 4), R.resize(2 * n + 4);
        FOR(i, n + 1, 2 * n) L[i] = R[i] = i - n;
        FORD(i, n - 1, 1) {
            ST[i].sz = ST[i << 1].sz + ST[i << 1 | 1].sz;
            L[i] = min(L[i << 1], L[i << 1 | 1]);
            R[i] = max(R[i << 1], R[i << 1 | 1]);
        }
    }

    void modify(int p, int value) {

        for (ST[p += n] = {value, value}; p > 1; p >>= 1) {
            if (L[p ^ 1] < L[p]) ST[p >> 1] = ST[p ^ 1] + ST[p];
            else ST[p >> 1] = ST[p] + ST[p ^ 1];
        }
    }

    int get(int l, int r) {

        int curR = r;
        long long res = 0;
        for (l += n, r += n; l < r; l >>= 1, r >>= 1) {
            if (l & 1) res = (res + 1LL * ST[l++].sum * pw[curR - R[l - 1] - 1]) % mod;
            if (r & 1) res = (res + 1LL * ST[--r].sum * pw[curR - R[r] - 1]) % mod;
        }
        return res;
    }

    int getR(int l, int r) {

        int curL = l;
        long long res = 0;
        for (l += n, r += n; l < r; l >>= 1, r >>= 1) {
            if (l & 1) res = (res + 1LL * ST[l++].sumR * pw[L[l - 1] - curL]) % mod;
            if (r & 1) res = (res + 1LL * ST[--r].sumR * pw[L[r] - curL]) % mod;
        }
        return res;
    }
} st;

int n, nChain = 0, cntHLD = 0, ans = 1, comps;
vector<int> LOG, depth, child, chainHead, chainInd, posInBase, root, maxL, minR, neg(2), tr;
vector<vector<int>> adj, p, lst;
vector<vector<segNode>> seg(2);

void dfs(int u) {

    FORALL(v, adj[u]) {
        p[0][v] = u;
        depth[v] = depth[u] + 1;
        dfs(v);
        child[u] += child[v];
    }
}

void hld(int u) {

    if (chainHead[nChain] == 0) chainHead[nChain] = u;
    chainInd[u] = nChain;
    posInBase[u] = ++cntHLD;
    tr[cntHLD] = u;
    int mtVtx = - 1;
    FORALL(v, adj[u]) if (mtVtx == - 1 || child[mtVtx] < child[v]) mtVtx = v;
    if (mtVtx != - 1) hld(mtVtx);
    FORALL(v, adj[u]) {
        if (v == mtVtx) continue;
        nChain++;
        hld(v);
    }
}

int lca(int u, int v) {

    if (depth[u] > depth[v]) swap(u, v);
    int k = LOG[depth[v]];
    FORD(i, k, 0) if (depth[v] - (1 << i) >= depth[u]) v = p[i][v];
    if (u == v) return u;
    FORD(i, k, 0) if (p[i][u] != p[i][v]) {
        u = p[i][u];
        v = p[i][v];
    }
    return p[0][u];
}

int jump(int u, int l) {

    FORB1(mask, l) u = p[__builtin_ctz(mask)][u];
    return u;
}

void getSeg(int id, int u, int a, int type) {

    int cnt = _size(seg[id]);
    int uChain = chainInd[u], aChain = chainInd[a];
    while (uChain != aChain) {
        if (!type) seg[id].push_back({posInBase[u], posInBase[chainHead[uChain]]});
        else seg[id].push_back({posInBase[chainHead[uChain]], posInBase[u]});
        u = p[0][chainHead[uChain]];
        uChain = chainInd[u];
    }
    if (!type) seg[id].push_back({posInBase[u], posInBase[a]});
    else seg[id].push_back({posInBase[a], posInBase[u]});
    if (type) reverse(seg[id].begin() + cnt, seg[id].end());
}

int _pow(int a, int x) {

    if (!x) return 1;
    int t = _pow(a, x / 2);
    t = (1LL * t * t) % mod;
    if (x % 2) t = (1LL * t * a) % mod;
    return t;
}

int getroot(int u) {

    return root[u] == u ? u : (root[u] = getroot(root[u]));
}

void unite(int u, int v) {

    u = tr[u], v = tr[v];
    u = getroot(u), v = getroot(v);
    if (_size(lst[u]) < _size(lst[v])) swap(u, v);
    if (ans) ans = 1LL * ans * _pow((minR[u] - maxL[u] + 1), mod - 2) % mod;
    if (ans) ans = 1LL * ans * _pow((minR[v] - maxL[v] + 1), mod - 2) % mod;
    FORALL(value, lst[v]) {
        lst[u].push_back(value);
        st.modify(posInBase[value], u);
    }
    comps--;
    root[v] = u;
    maxL[u] = max(maxL[u], maxL[v]);
    minR[u] = min(minR[u], minR[v]);
    ans = 1LL * ans * max(0, minR[u] - maxL[u] + 1) % mod;
}

void reRoot(int a, int b, int c, int d) {

    int cntAll = abs(a - b) + 1, cnt = 1;
    while (cnt <= cntAll) {
        int l = cnt, r = cntAll + 1;
        int value1 = (a < b ? st.get(a + l - 1, a + cntAll) : st.getR(a - cntAll + 1, a - l + 2));
        int value2 = (c < d ? st.get(c + l - 1, c + cntAll) : st.getR(c - cntAll + 1, c - l + 2));
        if (value1 == value2) break;
        while (l < r) {
            int mid = (l + r) / 2;
            value1 = (a < b ? st.get(a + l - 1, a + mid) : st.getR(a - mid + 1, a - l + 2));
            value2 = (c < d ? st.get(c + l - 1, c + mid) : st.getR(c - mid + 1, c - l + 2));
            if (value1 == value2) l = mid + 1;
            else r = mid;
        }
        if (l > cntAll) break;
        unite(a < b ? a + l - 1 : a - l + 1, c < d ? c + l - 1 : c - l + 1);
        cnt = l + 1;
    }
}

void queryCondition(int &a, int &b, int &c, int &d) {

    REP(i, 2) seg[i].clear();
    int rab = lca(a, b), rcd = lca(c, d);
    getSeg(0, a, rab, 0);
    if (b != rab) getSeg(0, b, jump(b, depth[b] - depth[rab] - 1), 1);
    getSeg(1, c, rcd, 0);
    if (d != rcd) getSeg(1, d, jump(d, depth[d] - depth[rcd] - 1), 1);
    int i = 0, cur = 0;
    FORALL(value, seg[0]) {
        if (value.l < value.r) neg[0] = 1; else neg[0] = - 1;
        int cnt = abs(value.r - value.l) + 1;
        while (cnt) {
            if (seg[1][i].l < seg[1][i].r) neg[1] = 1; else neg[1] = - 1;
            int num = min(cnt, abs(seg[1][i].r - seg[1][i].l) + 1 - cur);
            reRoot(value.r - neg[0] * (cnt - 1), value.r - neg[0] * (cnt - num), seg[1][i].l + neg[1] * cur, seg[1][i].l + neg[1] * (cur + num - 1));
            cnt -= num;
            cur += num;
            if (cur == abs(seg[1][i].r - seg[1][i].l) + 1) i++, cur = 0;
        }
    }
}

int main() {
    fastio;

    cin >> n;
    LOG.resize(n + 1);
    LOG[1] = 0;
    FOR(i, 2, n) {
        LOG[i] = LOG[i - 1];
        if (i >= (1 << (LOG[i] + 1))) LOG[i]++;
    }
    pw.resize(n + 1);
    pw[0] = 1;
    FOR(i, 1, n) pw[i] = 1LL * pw[i - 1] * base % mod;
    adj.resize(n + 1);
    FOR(i, 2, n) {
        int par;
        cin >> par;
        adj[par].push_back(i);
    }
    depth.resize(n + 1);
    child.assign(n + 1, 1);
    p.assign(LOG[n] + 1, vector<int>(n + 1));
    dfs(1);
    FOR(j, 1, LOG[n]) FOR(i, 1, n) p[j][i] = p[j - 1][p[j - 1][i]];
    chainHead.resize(n + 1);
    chainInd.resize(n + 1);
    posInBase.resize(n + 1);
    tr.resize(n + 1);
    hld(1);
    root.resize(n + 1);
    lst.resize(n + 1);
    maxL.resize(n + 1);
    minR.resize(n + 1);
    st.init(n + 1);
    FOR(i, 1, n) {
        root[i] = i;
        lst[i].push_back(i);
        st.modify(posInBase[i], i);
        int l, r;
        cin >> l >> r;
        maxL[i] = l, minR[i] = r;
    }
    FOR(i, 1, n) ans = 1LL * ans * (minR[i] - maxL[i] + 1) % mod;
    int m;
    cin >> m;
    comps = n;
    REP(i, m) {
        int a, b, c, d;
        cin >> a >> b >> c >> d;
        if (ans && comps > 1) queryCondition(a, b, c, d);
        cout << ans << '\n';
    }
    return 0;
}

Bình luận

Hãy đọc nội quy trước khi bình luận.


Không có bình luận tại thời điểm này.