Hướng dẫn giải của Bedao Grand Contest 15 - Rollback


Chỉ dùng lời giải này khi không có ý tưởng, và đừng copy-paste code từ lời giải này. Hãy tôn trọng người ra đề và người viết lời giải.
Nộp một lời giải chính thức trước khi tự giải là một hành động có thể bị ban.

Đối với subtask ~1, 2, 3~ (~n \le 1000~), ta có thể tính trươc mảng ~C(x)~ trong ~O(N^{2})~.

Nhận xét: Do ~|a_i| \le 100000~ nên ~|a_i + a_j| \le 200000~, ta có thể loại các truy vấn ~add/del~ có ~x~ nằm ngoài khoảng này.

Subtask 1: Do ~q \le 1000~ nên tập hợp ~S~ chỉ có tối đa ~1000~ phần tử. Ta có thể lưu lại tập hợp ~S~ dưới dạng vector ở mỗi truy vấn. Các thao tác ~add, del~ có thể sử dụng các hàm push_back(), erase() và thao tác ~roll~ có thể gán ~S~ lại bằng giá trị cũ đã được lưu. Độ phức tạp ~O(N^{2} + N \times Q)~.

Subtask 2: Ta có thể sử dụng Segment tree/BIT để tính tổng (thao tác ~del~ ~x~ thực chất là trừ bớt đi 1 lần ~C(x)~). Độ phức tạp: ~O(N^{2} + QlogN)~.

Subtask 3: Để có thể thực hiện thao tác ~roll~, ta có thể dùng cấu trúc dữ liệu Persistent Segment tree và lưu lại vị trí node gốc của cây trong mỗi truy vấn và version hiện tại. Mỗi khi thực hiện thao tác ~roll~ ~x~ ta chỉ cần gán ~currentVer = x - 1~. Độ phức tạp: ~O(N^{2} + QlogN)~. Ngoài ra ta cũng có thể dựng đồ thị các truy vấn để xử lí.

Subtask 4: Với subtask cuối, việc còn lại của ta chỉ là tính trước mảng C trong ~O(NlogN)~. Việc tính toán có thể bằng thuật toán nhân nhanh đa thức (Fast Fourier Transform):

Xét đa thức: ~B = b_0x^{0} + b_1x^{1} + ... + b_{n-1}x^{n-1}~ (Trong đó ~b_i~ là số phần tử mang giá trị ~i~ trong mảng A)

Xét ~C = B * B = c_0x^{0} + c_1x^{1} + ... + c_{2n-2}x^{2n-2}~, ta có:

~c_i = \sum_{j + k = i} b_j*b_k~ chính là số cặp ~(x, y)~ sao cho ~a_x+a_y=i~ (cần xử lí thêm khi đếm số cặp với ~x < y~).

Từ đó, để tính trước mảng ~C~ ta chỉ cần nhân nhanh đa thức ~B~ và ~B~. Cách cái đặt FFT có thể được tham khảo tại đây.

Độ phức tạp: ~O(NlogN+QlogN)~

#include <bits/stdc++.h>

using namespace std;

#define int long long

// https://cp-algorithms.com/algebra/fft.html
namespace FFT {
    using cd = complex<double>;
    const double PI = acos(-1);

    void fft(vector<cd> & a, bool invert) {
        int n = a.size();

        for (int i = 1, j = 0; i < n; i++) {
            int bit = n >> 1;
            for (; j & bit; bit >>= 1)
                j ^= bit;
            j ^= bit;

            if (i < j)
                swap(a[i], a[j]);
        }

        for (int len = 2; len <= n; len <<= 1) {
            double ang = 2 * PI / len * (invert ? -1 : 1);
            cd wlen(cos(ang), sin(ang));
            for (int i = 0; i < n; i += len) {
                cd w(1);
                for (int j = 0; j < len / 2; j++) {
                    cd u = a[i+j], v = a[i+j+len/2] * w;
                    a[i+j] = u + v;
                    a[i+j+len/2] = u - v;
                    w *= wlen;
                }
            }
        }

        if (invert) {
            for (cd & x : a)
                x /= n;
        }
    }

    vector<int> multiply(vector<int> const& a, vector<int> const& b) {
        vector<cd> fa(a.begin(), a.end()), fb(b.begin(), b.end());
        int n = 1;
        while (n < (int)(a.size() + b.size()))
            n <<= 1;
        fa.resize(n);
        fb.resize(n);

        fft(fa, false);
        fft(fb, false);
        for (int i = 0; i < n; i++)
            fa[i] *= fb[i];
        fft(fa, true);

        vector<int> result(n);
        for (int i = 0; i < n; i++)
            result[i] = round(fa[i].real());
        return result;
    }
}

pair<long long, int>& operator+=(pair<long long, int>& a, const pair<long long, int>& b) {
    a.first += b.first;
    a.second += b.second;
    return a;
}
pair<long long, int> operator+(pair<long long, int> a, const pair<long long, int>& b) {
    a += b;
    return a;
}
// https://cp-algorithms.com/data_structures/segment_tree.html#preserving-the-history-of-its-values-persistent-segment-tree
namespace PersistentSegTree {
    struct Vertex {
        Vertex *l, *r;
        pair<long long, int> sum;

        Vertex(pair<long long, int> val) : l(nullptr), r(nullptr), sum(val) {}
        Vertex(Vertex *l, Vertex *r) : l(l), r(r), sum({0, 0}) {
            if (l) sum += l->sum;
            if (r) sum += r->sum;
        }
    };

    Vertex* build(int tl, int tr) {
        if (tl == tr)
            return new Vertex(0, 0);
        int tm = (tl + tr) / 2;
        return new Vertex(build(tl, tm), build(tm+1, tr));
    }

    pair<long long, int> get_sum(Vertex* v, int tl, int tr, int l, int r) {
        if (l > r)
            return {0, 0};
        if (l == tl && tr == r)
            return v->sum;
        int tm = (tl + tr) / 2;
        return get_sum(v->l, tl, tm, l, min(r, tm))
            + get_sum(v->r, tm+1, tr, max(l, tm+1), r);
    }

    Vertex* update(Vertex* v, int tl, int tr, int pos, pair<long long, int> new_val) {
        if (tl == tr)
            return new Vertex(new_val);
        int tm = (tl + tr) / 2;
        if (pos <= tm)
            return new Vertex(update(v->l, tl, tm, pos, new_val), v->r);
        else
            return new Vertex(v->l, update(v->r, tm+1, tr, pos, new_val));
    }
}

int32_t main() {
    ios_base::sync_with_stdio(false);
    cin.tie(0);
    int n, q;
    cin >> n >> q;
    const int MAXVALUE = 100'000;
    vector<long long> a((MAXVALUE << 1) + 1, 0);
    for (int i = 0; i < n; i++) {
        int x;
        cin >> x;
        a[x + MAXVALUE]++;
    }
    vector<long long> c = FFT::multiply(a, a);
    for (int i = 0; i < MAXVALUE * 4 + 1; ++i) {
        int x = i - MAXVALUE * 2;
        if (x % 2 == 0) {
            c[i] -= a[x / 2 + MAXVALUE];
        }
        c[i] >>= 1;
    }
    vector<PersistentSegTree::Vertex*> roots(q + 1);
    roots[0] = PersistentSegTree::build(0, (int)c.size() - 1);
    for (int i = 1; i <= q; ++i) {
        roots[i] = roots[i - 1];
        string op;
        cin >> op;
        if (op == "add") {
            int x;
            cin >> x;
            x += MAXVALUE * 2;
            if (x < 0 || x >= (int)c.size()) {
                continue;
            }
            auto value = PersistentSegTree::get_sum(roots[i], 0, (int)c.size() - 1, x, x);
            value += pair<long long, int> (c[x], 1);
            roots[i] = PersistentSegTree::update(roots[i], 0, (int)c.size() - 1, x, value);
        }
        else if (op == "del") {
            int x;
            cin >> x;
            x += MAXVALUE * 2;
            if (x < 0 || x >= (int)c.size()) {
                continue;
            }
            auto value = PersistentSegTree::get_sum(roots[i], 0, (int)c.size() - 1, x, x);
            int cnt = value.second;
            if (cnt == 0) {
                continue;
            }
            value += pair<long long, int>(-c[x], -1);
            roots[i] = PersistentSegTree::update(roots[i], 0, (int)c.size() - 1, x, value);
        }
        else if (op == "ask") {
            int l, r;
            cin >> l >> r;
            l += MAXVALUE * 2;
            r += MAXVALUE * 2;
            if (l < 0) {
                l = 0;
            }
            if (r >= (int)c.size()) {
                r = (int)c.size() - 1;
            }
            cout << PersistentSegTree::get_sum(roots[i], 0, (int)c.size() - 1, l, r).first << '\n';
        }
        else {
            int k;
            cin >> k;
            roots[i] = roots[k - 1];
        }
    }
    cerr << "Time elapsed: " << 1.0 * clock() / CLOCKS_PER_SEC << " s.\n";
}

Bình luận

Hãy đọc nội quy trước khi bình luận.


Không có bình luận tại thời điểm này.