Hướng dẫn giải của Sum And Query
Nộp một lời giải chính thức trước khi tự giải là một hành động có thể bị ban.
Lời giải cho các subtask 1 và 2
Trước hết, chúng ta có một bổ đề:
- Phân chia tối ưu cho ~k-1~ có thể được thu được từ ~k~ bằng cách gộp hai tập con.
Để chứng minh điều này, hãy theo dõi lời giải cho các subtask sau và suy ra từ đó.
Giả sử dãy ~\operatorname{cost}~ của ~k~ tập con là ~\{c_1,c_2,\ldots,c_k\}~. Chúng ta cần tìm hai chỉ số khác nhau ~i~ và ~j~ sao cho ~c_i + c_j - (c_i\,\&\, c_j)=c_i \,|\, c_j~ là tối thiểu, trong đó ~|~ là toán tử OR.
Để tìm cặp này, đặt ~l=29~ và lặp lại các bước sau:
Nếu ~l<0~, chọn bất kỳ cặp phần tử nào làm đáp án.
Chia ~X~ thành ~X_0~ và ~X_1~ trong đó ~X_1~ chứa các phần tử có bit thứ ~l~ là ~1~.
Nếu ~|X_0|\ge 2~, thay thế ~X~ bằng ~X_0~, ngược lại không làm gì.
Giảm ~l~ đi ~1~.
Chứng minh tính đúng đắn được để lại như một bài tập cho độc giả. Để cải thiện, ta có quan sát sau:
- Chỉ cần xét các cặp nằm trong ~30~ phần tử bé nhất.
Chứng minh: Nhìn vào thuật toán trước, thực chất ta đang đi trên đường đi từ gốc tới một lá của trie xây dựng trên tập ~S~. Có một thay đổi là ở bước 3, nếu ~X_0=1~ thì ta sẽ mang theo phần tử duy nhất này đi cùng, việc này chỉ xảy ra tối đa ~30~ lần. Vì vậy khi dừng lại ở lá, tập ~X~ sẽ gồm đúng các phần tử trong lá đó và tối đa ~30~ phần tử khác. Dễ thấy rằng các phần tử này là các phần tử bé nhất của tập ~S~.
Độ phức tạp thời gian mới sẽ là ~\mathcal{O}(n\cdot 30^2)~.
Lời giải cho các subtask 3 và 4
Giả sử chúng ta đang làm việc trên đa tập ~X~ và sửa bài toán gốc một chút bằng cách thêm một tham số khác là ~l~:
Định nghĩa ~x^{(l)}~ là kết quả sau khi tắt tất cả các bit cao hơn ~l~ của một số nguyên ~x~.
Thay thế ~\operatorname{cost}(X)~ bằng ~\operatorname{cost}(X,l)=x_1^{(l)} \,\&\, x_2^{(l)} \,\&\,\ldots \,\&\, x_q^{(l)}~.
Ký hiệu tổng lớn nhất có thể đạt được của bài toán đã được sửa đổi của chúng ta là ~f(X, l, k)~.
Chia ~X~ thành ~X_0~ và ~X_1~ trong đó ~X_1~ chứa các phần tử có bit thứ ~l~ là ~1~. Để tính toán ~f(X, l, k)~, có năm trường hợp cần xem xét:
Trường hợp | Điều kiện | Giá trị |
---|---|---|
1 | ~l\lt 0~ | ~0~ |
2 | ~|X_1|\lt k~ | ~\sum_{x\in X_1} x^{(l)} + f(X_0, l-1, k-|X_1|)~ |
3 | ~|X_1|\ge k~ và ~X_0~ rỗng | ~k\cdot 2^l + f(X_1, l-1, k)~ |
4 | ~|X_1|\ge k~ và ~X_0~ không rỗng | ~(k-1)\cdot 2^l + f(X_1\cup \{\operatorname{cost}(X_0,l)\}, l-1, k)~ |
Các trường hợp 1 và 3 là hiển nhiên. Dưới đây là chứng minh của hai trường hợp còn lại:
Trường hợp 2: Giá trị trả về thể hiện rằng các phần tử của ~X_1~ phải nằm một mình trong các tập con của riêng chúng. Tổng cần tìm sẽ có đúng ~|X_1|~ (cao nhất có thể) số hạng với bit thứ ~l~ được bật trong tổng cần tìm, điều này rõ ràng là tối ưu.
Trường hợp 4: Giá trị trả về thể hiện rằng tất cả các phần tử của ~X_0~ phải ở trong cùng một tập con (có thể có thêm các số khác). Tổng cần tìm sẽ có đúng một số hạng (thấp nhất có thể) với bit thứ ~l~ tắt, điều này rõ ràng là tối ưu.
Đáp án cho một truy vấn là ~f(S, 29, k)~. Triển khai một cách ngây thơ của thuật toán này sẽ kết thúc với một độ phức tạp thời gian là ~\mathcal{O}(|S|\cdot 30)~ cho mỗi truy vấn, đủ để vượt qua subtask 3.
Để cải thiện, xem xét chia ~X~ thành ~X_{old}~ và ~X_{new}~, trong đó ~X_{old}~ chứa các phần tử từ ~S~ và ~X_{new}~ chứa các phần tử mới được tạo ra từ trường hợp 4. Lưu ý rằng ~|X_{new}|\le 30~ và ~X_{old}~ có thể được biểu diễn bằng một cấu trúc dữ liệu như trie. Điều này sẽ dẫn đến một độ phức tạp thời gian là ~\mathcal{O}(30^2)~ cho mỗi truy vấn, đủ để vượt qua subtask 4.
#include <bits/stdc++.h> using namespace std; int turnOffK(int n, int k) { return n & ~(1 << k); } int firstK(int n, int k) { return n & ((1 << k) - 1); } using ll = long long; const int LOG = 30; struct Node; using NodePtr = Node *; struct Node { NodePtr child[2]; int cnt; long long sum; int andSum; void update(int level) { sum = child[0]->sum + child[1]->sum; sum += (ll)child[1]->cnt << level; andSum = child[0]->andSum & child[1]->andSum; if (child[0]->cnt) { andSum = turnOffK(andSum, level); } } }; Node pool[10000000]; NodePtr NIL = pool; void reset(NodePtr node) { node->cnt = 0; node->sum = 0; node->andSum = -1; node->child[0] = node->child[1] = NIL; } void initPool() { for (int i = 0; i < 10000000; i++) { reset(pool + i); } } void add(NodePtr &node, int level, int value, NodePtr (*newNode)()) { if (node == NIL) { node = newNode(); } node->cnt++; if (level < 0) { return; } int bit = (value >> level) & 1; if (node->child[bit] == NIL) { node->child[bit] = newNode(); } add(node->child[bit], level - 1, value, newNode); node->update(level); } void remove(NodePtr node, int level, int value) { node->cnt--; if (level < 0) { return; } int bit = (value >> level) & 1; remove(node->child[bit], level - 1, value); node->update(level); } int nodeIdx = 1000; NodePtr root = NIL; NodePtr newNode() { return pool + (++nodeIdx); } void add(int value) { add(root, LOG - 1, value, newNode); } void remove(int value) { remove(root, LOG - 1, value); } int tempNodeIdx = 0; NodePtr newTempNode() { return pool + (++tempNodeIdx); } void addTemp(NodePtr &node, int level, int value) { add(node, level, value, newTempNode); } void resetTemp() { while (tempNodeIdx > 0) { reset(pool + tempNodeIdx--); } } long long solve(NodePtr node, NodePtr tempNode, int level, int numGroups) { if (level < 0) { return 0; } if (numGroups == 1) { return firstK(node->andSum & tempNode->andSum, level + 1); } if (numGroups == node->cnt + tempNode->cnt) { return node->sum + tempNode->sum; } int child1Cnt = node->child[1]->cnt + tempNode->child[1]->cnt; if (child1Cnt < numGroups) { long long child1Sum = node->sum - node->child[0]->sum + tempNode->sum - tempNode->child[0]->sum; return child1Sum + solve( node->child[0], tempNode->child[0], level - 1, numGroups - child1Cnt ); } int child0Cnt = node->child[0]->cnt + tempNode->child[0]->cnt; if (child0Cnt) { int child0AndSum = node->child[0]->andSum & tempNode->child[0]->andSum; // Force it to go to the right child addTemp(tempNode, level, child0AndSum); return (ll(numGroups - 1) << level) + solve(node->child[1], tempNode->child[1], level - 1, numGroups); } return (ll(numGroups) << level) + solve(node->child[1], tempNode->child[1], level - 1, numGroups); } int main() { cin.tie(0)->sync_with_stdio(0); initPool(); int n, q; cin >> n >> q; for (int i = 0; i < n; i++) { int value; cin >> value; add(value); } while (q--) { int cmd; cin >> cmd; if (cmd == 1) { int value; cin >> value; add(value); } else if (cmd == 2) { int value; cin >> value; remove(value); } else { int numGroups; cin >> numGroups; resetTemp(); cout << solve(root, NIL, LOG - 1, numGroups) << '\n'; } } }
Bình luận