Editorial for Bedao Regular Contest 04 - MAGIC


Remember to use this editorial only when stuck, and not to copy-paste code from it. Please be respectful to the problem author and editorialist.
Submitting an official solution before solving the problem yourself is a bannable offence.

Author: bedao

Dễ thấy trong đoạn mã giả ở đề bài, ta có ~1 \le i < j < x < y \le n~.

Bài toán từ đó trở thành đếm số bộ bốn phần tử phân biệt của dãy ~a~ có tổng bằng ~0~.

Gọi ~cnt[pos][sum]~ là số lượng cặp ~(i, j)~ sao cho ~1 \le i < j \le pos~ và ~a[i] + a[j] = sum~.

Giả sử tính được mảng ~cnt~ trên. Khi đó đáp án của bài toán là tổng của các ~cnt[x - 1][-(a[x] + a[y])]~ với mọi cặp ~(x, y)~ sao cho ~3 \le x < y \le n~.

Tuy nhiên, ta không thể lưu dữ liệu vào một mảng ~cnt[pos][sum]~ được vì ~-2\times 10^6 \le sum \le 2\times 10^6~ và ~pos \le n~.

Thay vào đó, ta tối ưu bằng cách rút gọn mảng ~cnt~ thành ~cnt[sum]~ và tính mảng này theo thứ tự tăng dần của ~pos~. Mặt khác, khi ta for tới ~pos~ và update giá trị thì ~cnt[sum]~ đang lưu giá trị của ~cnt[pos][sum]~.

Ta cần lưu thêm một vector những giá trị sẽ sử dụng của mảng ~cnt~ tại vị trí ~pos~, gọi là ~qu[pos]~.

Với mọi cặp ~(i, j)~ sao cho ~i < j~, ta có ~qu[i - 1].push_back(-(a[i] + a[j]))~ (Tương ứng với ta cần cộng kết quả cho ~cnt[i - 1][-(a[i] + a[j])]~).

Sau đó, ta lần lượt for ~i~ từ ~2~ đến ~n~ và ~j~ từ ~1~ đến ~i - 1~, cập nhật ~cnt[a[i] + a[j]] += 1~.

Dễ thấy khi đã for xong với ~i=pos~ thì mảng ~cnt[sum]~ lúc này lưu giá trị của ~cnt[pos][sum]~. Ta chỉ cần for qua những phần tử trong ~qu[i]~ (gọi là ~sum~) và cộng vào đáp án lượng ~cnt[sum]~ tương ứng.

Chú ý rằng ~sum~ có thể âm nên bạn cần shift giá trị từ đoạn ~[-2\times 10^6, 2\times 10^6]~ lên ~[0, 4\times 10^6]~ để có thể sử dụng được mảng ~cnt~. (Nghĩa là ~cnt[0]~ sẽ lưu giá trị của ~cnt[-2\times 10^6]~ ban đầu)

Độ phức tạp thời gian: Vì chỉ có ~O(N^2)~ cặp ~(i,j)~ nên tổng tất cả phần tử của ~n~ vector ~qu[]~ là ~O(N^2)~, việc update mảng ~cnt~ cũng như cộng đáp án là ~O(N^2)~. Tổng là ~O(N^2)~.

Độ phức tạp bộ nhớ: ~O(2 \times max(a[i] + a[j]) + N^2)~.

Code mẫu

#pragma GCC optimize("Ofast")
#pragma GCC optimize("unroll-loops")
#pragma GCC target("avx,avx2,fma")

#include <bits/stdc++.h>

#define for1(i,a,b) for (int i = a; i <= b; i++)
#define for2(i,a,b) for (int i = a; i >= b; i--)
// #define int long long

#define sz(a) (int)a.size()
#define pii pair<int,int>
#define pb push_back

/*
__builtin_popcountll(x) : Number of 1-bit
__builtin_ctzll(x) : Number of trailing 0
*/

const long double PI = 3.1415926535897932384626433832795;
const int INF = 1000000000000000000;
const int MOD = 1000000007;
const int MOD2 = 1000000009;
const long double EPS = 1e-6;

using namespace std;

#include <ext/pb_ds/assoc_container.hpp>
using namespace __gnu_pbds;

const int N = 2e3 + 5;
int n, res;
int a[N], cnt[4000005];
gp_hash_table<int, int> mp[N];

int get1(int x) {
    return cnt[x + (int)2e6];
}
int get2(int i, int x) {
    if (mp[i].find(x) != mp[i].end()) return mp[i][x];
    return 0;
}

signed main() {

    ios_base::sync_with_stdio(false); cin.tie(NULL); cout.tie(NULL);

    // freopen("cf.inp", "r", stdin);
    // freopen("cf.out", "w", stdout);

    cin >> n;
    for1(i,1,n) cin >> a[i];

    for1(i,1,n) {
        for1(j,i + 1,n) {
            cnt[a[i] + a[j] + (int)2e6]++;
            mp[i][a[i] + a[j]]++;
            mp[j][a[i] + a[j]]++;
        }
    }

    for1(i,1,n) {
        for1(j,i + 1,n) {
            int x = -a[i] - a[j];
            res += get1(x) - get2(i, x) - get2(j, x) + (x == 0);
            // cout <<i << " " << j << " " << res << "\n";
        }
    }
    cout << res / 6;
}

Comments

Please read the guidelines before commenting.


There are no comments at the moment.