Hướng dẫn giải của Sum And Query


Chỉ dùng lời giải này khi không có ý tưởng, và đừng copy-paste code từ lời giải này. Hãy tôn trọng người ra đề và người viết lời giải.
Nộp một lời giải chính thức trước khi tự giải là một hành động có thể bị ban.

Lời giải cho các subtask 1 và 2

Trước hết, chúng ta có một bổ đề:

  • Phân chia tối ưu cho ~k-1~ có thể được thu được từ ~k~ bằng cách gộp hai tập con.

Để chứng minh điều này, hãy theo dõi lời giải cho các subtask sau và suy ra từ đó.

Giả sử dãy ~\operatorname{cost}~ của ~k~ tập con là ~\{c_1,c_2,\ldots,c_k\}~. Chúng ta cần tìm hai chỉ số khác nhau ~i~ và ~j~ sao cho ~c_i + c_j - (c_i\,\&\, c_j)=c_i \,|\, c_j~ là tối thiểu, trong đó ~|~ là toán tử OR.

Để tìm cặp này, đặt ~l=29~ và lặp lại các bước sau:

  1. Nếu ~l<0~, chọn bất kỳ cặp phần tử nào làm đáp án.

  2. Chia ~X~ thành ~X_0~ và ~X_1~ trong đó ~X_1~ chứa các phần tử có bit thứ ~l~ là ~1~.

  3. Nếu ~|X_0|\ge 2~, thay thế ~X~ bằng ~X_0~, ngược lại không làm gì.

  4. Giảm ~l~ đi ~1~.

Chứng minh tính đúng đắn được để lại như một bài tập cho độc giả. Để cải thiện, ta có quan sát sau:

  • Chỉ cần xét các cặp nằm trong ~30~ phần tử bé nhất.

Chứng minh: Nhìn vào thuật toán trước, thực chất ta đang đi trên đường đi từ gốc tới một lá của trie xây dựng trên tập ~S~. Có một thay đổi là ở bước 3, nếu ~X_0=1~ thì ta sẽ mang theo phần tử duy nhất này đi cùng, việc này chỉ xảy ra tối đa ~30~ lần. Vì vậy khi dừng lại ở lá, tập ~X~ sẽ gồm đúng các phần tử trong lá đó và tối đa ~30~ phần tử khác. Dễ thấy rằng các phần tử này là các phần tử bé nhất của tập ~S~.

Độ phức tạp thời gian mới sẽ là ~\mathcal{O}(n\cdot 30^2)~.

Lời giải cho các subtask 3 và 4

Giả sử chúng ta đang làm việc trên đa tập ~X~ và sửa bài toán gốc một chút bằng cách thêm một tham số khác là ~l~:

  • Định nghĩa ~x^{(l)}~ là kết quả sau khi tắt tất cả các bit cao hơn ~l~ của một số nguyên ~x~.

  • Thay thế ~\operatorname{cost}(X)~ bằng ~\operatorname{cost}(X,l)=x_1^{(l)} \,\&\, x_2^{(l)} \,\&\,\ldots \,\&\, x_q^{(l)}~.

Ký hiệu tổng lớn nhất có thể đạt được của bài toán đã được sửa đổi của chúng ta là ~f(X, l, k)~.

Chia ~X~ thành ~X_0~ và ~X_1~ trong đó ~X_1~ chứa các phần tử có bit thứ ~l~ là ~1~. Để tính toán ~f(X, l, k)~, có năm trường hợp cần xem xét:

Trường hợp Điều kiện Giá trị
1 ~l\lt 0~ ~0~
2 ~|X_1|\lt k~ ~\sum_{x\in X_1} x^{(l)} + f(X_0, l-1, k-|X_1|)~
3 ~|X_1|\ge k~ và ~X_0~ rỗng ~k\cdot 2^l + f(X_1, l-1, k)~
4 ~|X_1|\ge k~ và ~X_0~ không rỗng ~(k-1)\cdot 2^l + f(X_1\cup \{\operatorname{cost}(X_0,l)\}, l-1, k)~

Các trường hợp 1 và 3 là hiển nhiên. Dưới đây là chứng minh của hai trường hợp còn lại:

  • Trường hợp 2: Giá trị trả về thể hiện rằng các phần tử của ~X_1~ phải nằm một mình trong các tập con của riêng chúng. Tổng cần tìm sẽ có đúng ~|X_1|~ (cao nhất có thể) số hạng với bit thứ ~l~ được bật trong tổng cần tìm, điều này rõ ràng là tối ưu.

  • Trường hợp 4: Giá trị trả về thể hiện rằng tất cả các phần tử của ~X_0~ phải ở trong cùng một tập con (có thể có thêm các số khác). Tổng cần tìm sẽ có đúng một số hạng (thấp nhất có thể) với bit thứ ~l~ tắt, điều này rõ ràng là tối ưu.

Đáp án cho một truy vấn là ~f(S, 29, k)~. Triển khai một cách ngây thơ của thuật toán này sẽ kết thúc với một độ phức tạp thời gian là ~\mathcal{O}(|S|\cdot 30)~ cho mỗi truy vấn, đủ để vượt qua subtask 3.

Để cải thiện, xem xét chia ~X~ thành ~X_{old}~ và ~X_{new}~, trong đó ~X_{old}~ chứa các phần tử từ ~S~ và ~X_{new}~ chứa các phần tử mới được tạo ra từ trường hợp 4. Lưu ý rằng ~|X_{new}|\le 30~ và ~X_{old}~ có thể được biểu diễn bằng một cấu trúc dữ liệu như trie. Điều này sẽ dẫn đến một độ phức tạp thời gian là ~\mathcal{O}(30^2)~ cho mỗi truy vấn, đủ để vượt qua subtask 4.

#include <bits/stdc++.h>

using namespace std;

int turnOffK(int n, int k) { return n & ~(1 << k); }
int firstK(int n, int k) { return n & ((1 << k) - 1); }

using ll = long long;

const int LOG = 30;

struct Node;
using NodePtr = Node *;

struct Node {
    NodePtr child[2];
    int cnt;
    long long sum;
    int andSum;

    void update(int level) {
        sum = child[0]->sum + child[1]->sum;
        sum += (ll)child[1]->cnt << level;
        andSum = child[0]->andSum & child[1]->andSum;
        if (child[0]->cnt) {
            andSum = turnOffK(andSum, level);
        }
    }
};

Node pool[10000000];
NodePtr NIL = pool;

void reset(NodePtr node) {
    node->cnt = 0;
    node->sum = 0;
    node->andSum = -1;
    node->child[0] = node->child[1] = NIL;
}

void initPool() {
    for (int i = 0; i < 10000000; i++) {
        reset(pool + i);
    }
}

void add(NodePtr &node, int level, int value, NodePtr (*newNode)()) {
    if (node == NIL) {
        node = newNode();
    }
    node->cnt++;
    if (level < 0) {
        return;
    }
    int bit = (value >> level) & 1;
    if (node->child[bit] == NIL) {
        node->child[bit] = newNode();
    }
    add(node->child[bit], level - 1, value, newNode);
    node->update(level);
}

void remove(NodePtr node, int level, int value) {
    node->cnt--;
    if (level < 0) {
        return;
    }
    int bit = (value >> level) & 1;
    remove(node->child[bit], level - 1, value);
    node->update(level);
}

int nodeIdx = 1000;
NodePtr root = NIL;

NodePtr newNode() { return pool + (++nodeIdx); }

void add(int value) { add(root, LOG - 1, value, newNode); }

void remove(int value) { remove(root, LOG - 1, value); }

int tempNodeIdx = 0;

NodePtr newTempNode() { return pool + (++tempNodeIdx); }

void addTemp(NodePtr &node, int level, int value) {
    add(node, level, value, newTempNode);
}

void resetTemp() {
    while (tempNodeIdx > 0) {
        reset(pool + tempNodeIdx--);
    }
}

long long solve(NodePtr node, NodePtr tempNode, int level, int numGroups) {
    if (level < 0) {
        return 0;
    }
    if (numGroups == 1) {
        return firstK(node->andSum & tempNode->andSum, level + 1);
    }
    if (numGroups == node->cnt + tempNode->cnt) {
        return node->sum + tempNode->sum;
    }

    int child1Cnt = node->child[1]->cnt + tempNode->child[1]->cnt;
    if (child1Cnt < numGroups) {
        long long child1Sum = node->sum - node->child[0]->sum + tempNode->sum -
                              tempNode->child[0]->sum;
        return child1Sum + solve(
                               node->child[0], tempNode->child[0], level - 1,
                               numGroups - child1Cnt
                           );
    }

    int child0Cnt = node->child[0]->cnt + tempNode->child[0]->cnt;
    if (child0Cnt) {
        int child0AndSum = node->child[0]->andSum & tempNode->child[0]->andSum;
        // Force it to go to the right child
        addTemp(tempNode, level, child0AndSum);
        return (ll(numGroups - 1) << level) +
               solve(node->child[1], tempNode->child[1], level - 1, numGroups);
    }

    return (ll(numGroups) << level) +
           solve(node->child[1], tempNode->child[1], level - 1, numGroups);
}

int main() {
    cin.tie(0)->sync_with_stdio(0);
    initPool();
    int n, q;
    cin >> n >> q;
    for (int i = 0; i < n; i++) {
        int value;
        cin >> value;
        add(value);
    }
    while (q--) {
        int cmd;
        cin >> cmd;
        if (cmd == 1) {
            int value;
            cin >> value;
            add(value);
        } else if (cmd == 2) {
            int value;
            cin >> value;
            remove(value);
        } else {
            int numGroups;
            cin >> numGroups;
            resetTemp();
            cout << solve(root, NIL, LOG - 1, numGroups) << '\n';
        }
    }
}

Bình luận

Hãy đọc nội quy trước khi bình luận.


Không có bình luận tại thời điểm này.