Hướng dẫn giải của VNOJ Round 01 - OR PAIR


Chỉ dùng lời giải này khi không có ý tưởng, và đừng copy-paste code từ lời giải này. Hãy tôn trọng người ra đề và người viết lời giải.
Nộp một lời giải chính thức trước khi tự giải là một hành động có thể bị ban.

Subtask1: ~n \leq 500~

Duyệt mọi cặp ~(i, j)~ và tính tổng OR của ~n - 2~ số còn lại bằng cách duyệt qua mọi số có vị trí khác ~i~ và ~j~.

Độ phức tạp: ~O(n ^ 3)~.

Subtask2: ~n \leq 10^4~

Gọi ~f(l, r)~ là tổng OR của các số có vị trí từ ~l~ đến ~r~. (nếu ~l~ > ~r~ thì ~f(l, r) = 0~)

~\Rightarrow~ Tổng OR của các số có vị trị khác ~i~ và ~j~ là: ~f(1, i - 1)~ OR ~f(i + 1, j - 1)~ OR ~f(j + 1, n)~.

Ta duyệt mọi cặp ~(i, j)~ và tính nhanh tổng OR của ~n - 2~ số còn lại bằng cách sử dụng cấu trúc dữ liệu Sparse Table để tính ~f(l, r)~ trong ~O(1)~.

Độ phức tạp: ~O(n ^ 2)~.

Subtask3: Không có ràng buộc gì thêm

Ta chia bài toán thành ~2~ trường hợp:

--- Tồn tại ~k~ (~1 \leq k \leq 20~) sao cho có đúng ~1~ số có bit thứ ~k~ là ~1~:

Kết quả là ~0~ vì với mọi cặp ~(i, j)~ thì giữa ~a_i~ OR ~a_j~ và tổng OR của ~n - 2~ số còn lại sẽ có một số có bit thứ ~k~ là ~1~, một số có bit thứ ~k~ là ~0~ nên hai số không thể bằng nhau.

--- Trường hợp trên không xảy ra:

Xét một cặp ~(i, j)~ thỏa mãn, ta thấy tổng OR của ~a_i~ và ~a_j~ bằng tổng OR của ~n - 2~ số còn lại và bằng ~a_1~ OR ~a_2~ OR ~\cdots~ OR ~a_n~ = ~S~.

Ta sẽ tính kết quả của bài toán bằng cách đếm số cặp ~(i, j)~ (~i~ và ~j~ không nhất thiết phải khác nhau) sao cho ~a_i~ OR ~a_j~ = ~S~ bằng SOS DP. Tuy nhiên trong những cặp ~(i, j)~ vừa đếm được có những cặp không thỏa mãn do ~i = j~ hoặc tồn tại ~k~ (~1 \leq k \leq 20~) sao cho ~a_i~ và ~a_j~ đều có bit thứ ~k~ là ~1~ nhưng ~n - 2~ số còn lại có bit thứ ~k~ là ~0~. Trường hợp ~i = j~ thì ta có thể dễ dàng kiểm tra được, trường hợp còn lại thì ta duyệt qua tất cả các số ~k~ mà có đúng ~2~ số có bit thứ ~k~ là ~1~.

Độ phức tạp: ~O((n + A) \cdot \log A)~ với ~A~ là ~\max{a_i}~.

#include "bits/stdc++.h"
using namespace std;

const int NM = 1e6 + 5;
const int LG = 20;

int n, a[NM];
int OR, cnt[LG];
int dp[LG + 1][1 << LG], sos[1 << LG];
set<int> s[NM];

void solve() {
    cin >> n;
    for (int i = 1; i <= n; ++i) {
        cin >> a[i];

        OR |= a[i];
        for (int j = 0; j < LG; ++j)
            cnt[j] += (a[i] >> j) & 1;
    }

    for (int i = 0; i < LG; ++i)
        if (cnt[i] == 1) {
            cout << 0;
            return;
        }

    for (int i = 1; i <= n; ++i)
        ++dp[LG][a[i]];

    for (int i = LG - 1; i >= 0; --i)
        for (int mask = 0; mask < (1 << LG); ++mask) {
            dp[i][mask] = dp[i + 1][mask];

            if (!(mask & (1 << i)))
                dp[i][mask] += dp[i + 1][mask ^ (1 << i)];
        }

    for (int mask = 0; mask < (1 << LG); ++mask)
        sos[mask] = dp[0][mask];

    for (int i = 1; i <= n; ++i)
        s[i].insert(i);

    for (int i = 0; i < LG; ++i)
        if (cnt[i] == 2) {
            vector<int> tmp;

            for (int j = 1; j <= n; ++j)
                if (a[j] & (1 << i))
                    tmp.push_back(j);

            s[tmp[0]].insert(tmp[1]);
            s[tmp[1]].insert(tmp[0]);
        }

    long long ans = 0;

    for (int i = 1; i <= n; ++i) {
        int mask = OR ^ a[i];
        ans += sos[mask];

        for (auto &j: s[i])
            ans -= ((a[j] & mask) == mask);
    }

    cout << ans;
}

int main() {
    ios::sync_with_stdio(false);
    cin.tie(nullptr);
    solve();
}

Bình luận

Hãy đọc nội quy trước khi bình luận.


Không có bình luận tại thời điểm này.