Hướng dẫn giải của VNOJ Round 01 - OR PAIR
Nộp một lời giải chính thức trước khi tự giải là một hành động có thể bị ban.
Subtask1: ~n \leq 500~
Duyệt mọi cặp ~(i, j)~ và tính tổng OR của ~n - 2~ số còn lại bằng cách duyệt qua mọi số có vị trí khác ~i~ và ~j~.
Độ phức tạp: ~O(n ^ 3)~.
Subtask2: ~n \leq 10^4~
Gọi ~f(l, r)~ là tổng OR của các số có vị trí từ ~l~ đến ~r~. (nếu ~l~ > ~r~ thì ~f(l, r) = 0~)
~\Rightarrow~ Tổng OR của các số có vị trị khác ~i~ và ~j~ là: ~f(1, i - 1)~ OR ~f(i + 1, j - 1)~ OR ~f(j + 1, n)~.
Ta duyệt mọi cặp ~(i, j)~ và tính nhanh tổng OR của ~n - 2~ số còn lại bằng cách sử dụng cấu trúc dữ liệu Sparse Table để tính ~f(l, r)~ trong ~O(1)~.
Độ phức tạp: ~O(n ^ 2)~.
Subtask3: Không có ràng buộc gì thêm
Ta chia bài toán thành ~2~ trường hợp:
--- Tồn tại ~k~ (~1 \leq k \leq 20~) sao cho có đúng ~1~ số có bit thứ ~k~ là ~1~:
Kết quả là ~0~ vì với mọi cặp ~(i, j)~ thì giữa ~a_i~ OR ~a_j~ và tổng OR của ~n - 2~ số còn lại sẽ có một số có bit thứ ~k~ là ~1~, một số có bit thứ ~k~ là ~0~ nên hai số không thể bằng nhau.
--- Trường hợp trên không xảy ra:
Xét một cặp ~(i, j)~ thỏa mãn, ta thấy tổng OR của ~a_i~ và ~a_j~ bằng tổng OR của ~n - 2~ số còn lại và bằng ~a_1~ OR ~a_2~ OR ~\cdots~ OR ~a_n~ = ~S~.
Ta sẽ tính kết quả của bài toán bằng cách đếm số cặp ~(i, j)~ (~i~ và ~j~ không nhất thiết phải khác nhau) sao cho ~a_i~ OR ~a_j~ = ~S~ bằng SOS DP. Tuy nhiên trong những cặp ~(i, j)~ vừa đếm được có những cặp không thỏa mãn do ~i = j~ hoặc tồn tại ~k~ (~1 \leq k \leq 20~) sao cho ~a_i~ và ~a_j~ đều có bit thứ ~k~ là ~1~ nhưng ~n - 2~ số còn lại có bit thứ ~k~ là ~0~. Trường hợp ~i = j~ thì ta có thể dễ dàng kiểm tra được, trường hợp còn lại thì ta duyệt qua tất cả các số ~k~ mà có đúng ~2~ số có bit thứ ~k~ là ~1~.
Độ phức tạp: ~O((n + A) \cdot \log A)~ với ~A~ là ~\max{a_i}~.
#include "bits/stdc++.h" using namespace std; const int NM = 1e6 + 5; const int LG = 20; int n, a[NM]; int OR, cnt[LG]; int dp[LG + 1][1 << LG], sos[1 << LG]; set<int> s[NM]; void solve() { cin >> n; for (int i = 1; i <= n; ++i) { cin >> a[i]; OR |= a[i]; for (int j = 0; j < LG; ++j) cnt[j] += (a[i] >> j) & 1; } for (int i = 0; i < LG; ++i) if (cnt[i] == 1) { cout << 0; return; } for (int i = 1; i <= n; ++i) ++dp[LG][a[i]]; for (int i = LG - 1; i >= 0; --i) for (int mask = 0; mask < (1 << LG); ++mask) { dp[i][mask] = dp[i + 1][mask]; if (!(mask & (1 << i))) dp[i][mask] += dp[i + 1][mask ^ (1 << i)]; } for (int mask = 0; mask < (1 << LG); ++mask) sos[mask] = dp[0][mask]; for (int i = 1; i <= n; ++i) s[i].insert(i); for (int i = 0; i < LG; ++i) if (cnt[i] == 2) { vector<int> tmp; for (int j = 1; j <= n; ++j) if (a[j] & (1 << i)) tmp.push_back(j); s[tmp[0]].insert(tmp[1]); s[tmp[1]].insert(tmp[0]); } long long ans = 0; for (int i = 1; i <= n; ++i) { int mask = OR ^ a[i]; ans += sos[mask]; for (auto &j: s[i]) ans -= ((a[j] & mask) == mask); } cout << ans; } int main() { ios::sync_with_stdio(false); cin.tie(nullptr); solve(); }
Bình luận