Editorial for Zero AND Sum


Remember to use this editorial only when stuck, and not to copy-paste code from it. Please be respectful to the problem author and editorialist.
Submitting an official solution before solving the problem yourself is a bannable offence.
#include <bits/stdc++.h>

using namespace std;

// Prime modulus: every answer is reported modulo this value.
constexpr int MOD = 998244353;

// Size of the global tables (factorials, inverses, bitmask buffers).
constexpr int N = (1 << 20) + 10;

// (x + y) mod MOD for residues x, y in [0, MOD).
int add(int x, int y) {
    const int s = x + y;
    return s >= MOD ? s - MOD : s;
}

// (x - y) mod MOD for residues x, y in [0, MOD).
int sub(int x, int y) {
    const int d = x - y;
    return d < 0 ? d + MOD : d;
}

// (x * y) mod MOD; the product is widened to 64 bits to avoid overflow.
int mul(int x, int y) {
    return static_cast<int>(static_cast<long long>(x) * y % MOD);
}

// inv[i]: modular inverse of i (i >= 1); fac[i]: i! mod MOD. Both are filled by prepare().
// inv[0] is never assigned and keeps its static zero value.
int inv[N];
int fac[N];
// Current test case: n values a[0..n-1], each expected to fit in k bits.
int n;
int k;
int a[N];
// Per-test bitmask buffers: f is turned into a subset-sum count, g into a superset-sum count.
int f[N];
int g[N];

/**
 * Fills fac[0..N-1] with factorials and inv[1..N-1] with modular inverses (mod MOD).
 * Must run once before any test case is solved. inv[0] is left untouched (zero).
 */
void prepare() {
    fac[0] = 1;
    fac[1] = 1;
    inv[1] = 1;
    for (int i = 2; i < N; ++i) {
        fac[i] = mul(fac[i - 1], i);
        // From MOD = (MOD / i) * i + MOD % i:  i^{-1} == -(MOD / i) * (MOD % i)^{-1}  (mod MOD).
        // MOD % i < i, so its inverse is already known.
        inv[i] = MOD - mul(MOD / i, inv[MOD % i]);
    }
}

/**
 * Weight used by solve() for a bitmask S that has x "supermask" elements.
 *
 * Returns  n! * x / ((n - x) * (n - x + 1))  (mod MOD).
 * For x < n this equals  sum_{j=0..x} j * x!/(x-j)! * (n-j-1)!.
 * NOTE(review): that sum is the total, over permutations whose first j entries come
 * from the x-element set and whose (j+1)-th entry is one fixed outside element, of j.
 * This interpretation is inferred from how solve() uses the value; confirm against
 * the problem statement.
 *
 * @param x  count of elements in [0, n]; reads the global n and the tables from prepare().
 * @return   the weight above, or 0 when x >= n (the denominator would contain 0).
 */
int cal(int x) {
    // n - x == 0 has no modular inverse (inv[0] is only a zero placeholder), so handle it
    // explicitly. In solve() this happens only for the empty mask, whose value is unused.
    if (x >= n) return 0;
    // n! is read directly, equal to n * (n - 1) * (n - 2)!, so n < 2 never indexes fac[-1].
    int res = mul(fac[n], x);
    res = mul(res, inv[n - x]);
    return mul(res, inv[n - x + 1]);
}

/**
 * Solves one test case: reads n, k and a[0..n-1] and prints the answer modulo MOD.
 *
 * NOTE(review): judging from the n * n! shortcut and the weights produced by cal(),
 * each permutation seems to contribute the number of its prefixes whose bitwise AND
 * is non-zero. Confirm against the problem statement.
 *
 * When the total AND is 0, the answer is the sum over non-empty masks M of
 *   (weight of permutations whose prefix AND is exactly M) * (#elements disjoint from M).
 * The "exactly M" weights come from a superset-sum count, mapped through cal(), and then
 * inverted with a superset Mobius transform.
 */
void solve() {
    cin >> n >> k;

    // Read the values while folding their total AND (the identity for AND is all ones).
    int total_and = ~0;
    for (int idx = 0; idx < n; ++idx) {
        cin >> a[idx];
        total_and &= a[idx];
    }

    // No prefix can ever reach AND 0, so each of the n! permutations contributes n.
    if (total_and) {
        cout << mul(n, fac[n]) << '\n';
        return;
    }

    // Single zero element. Returning here also keeps cal() away from n < 2.
    if (n == 1) {
        cout << 0 << '\n';
        return;
    }

    const int full = (1 << k) - 1;
    for (int m = 0; m <= full; ++m) {
        f[m] = 0;
        g[m] = 0;
    }
    for (int idx = 0; idx < n; ++idx) {
        ++f[a[idx]];
        ++g[a[idx]];
    }

    // SOS DP, one bit at a time:
    // f[m] becomes #elements that are submasks of m; g[m] becomes #elements that are supermasks of m.
    for (int bit = 0; bit < k; ++bit) {
        const int b = 1 << bit;
        for (int m = 0; m <= full; ++m) {
            if (m & b) {
                f[m] += f[m ^ b];
            } else {
                g[m] += g[m | b];
            }
        }
    }

    for (int m = 0; m <= full; ++m) {
        g[m] = cal(g[m]);
    }

    // Superset Mobius inversion: "all elements contain S" becomes "AND is exactly S".
    for (int bit = 0; bit < k; ++bit) {
        const int b = 1 << bit;
        for (int m = 0; m <= full; ++m) {
            if (!(m & b)) {
                g[m] = sub(g[m], g[m | b]);
            }
        }
    }

    // f[full ^ m] counts the elements sharing no bit with m.
    int answer = 0;
    for (int m = 1; m <= full; ++m) {
        answer = add(answer, mul(g[m], f[full ^ m]));
    }

    cout << answer << '\n';
}

int main() {
    cin.tie(0)->sync_with_stdio(false);

    prepare();

    int t;
    cin >> t;
    while (t--) {
        solve();
    }
}

Comments

Please read the guidelines before commenting.


There are no comments at the moment.