Hướng dẫn giải của Bedao Regular Contest 04 - MAGIC
Nộp một lời giải chính thức trước khi tự giải là một hành động có thể bị ban.
Tác giả:
Dễ thấy trong đoạn mã giả ở đề bài, ta có ~1 \le i < j < x < y \le n~.
Bài toán từ đó trở thành đếm số bộ bốn phần tử phân biệt của dãy ~a~ có tổng bằng ~0~.
Gọi ~cnt[pos][sum]~ là số lượng cặp ~(i, j)~ sao cho ~1 \le i < j \le pos~ và ~a[i] + a[j] = sum~.
Giả sử tính được mảng ~cnt~ trên. Khi đó đáp án của bài toán là tổng của các ~cnt[x - 1][-(a[x] + a[y])]~ với mọi cặp ~(x, y)~ sao cho ~3 \le x < y \le n~.
Tuy nhiên, ta không thể lưu dữ liệu vào một mảng ~cnt[pos][sum]~ được vì ~-2\times 10^6 \le sum \le 2\times 10^6~ và ~pos \le n~.
Thay vào đó, ta tối ưu bằng cách rút gọn mảng ~cnt~ thành ~cnt[sum]~ và tính mảng này theo thứ tự tăng dần của ~pos~. Mặt khác, khi ta for tới ~pos~ và update giá trị thì ~cnt[sum]~ đang lưu giá trị của ~cnt[pos][sum]~.
Ta cần lưu thêm một vector những giá trị sẽ sử dụng của mảng ~cnt~ tại vị trí ~pos~, gọi là ~qu[pos]~.
Với mọi cặp ~(i, j)~ sao cho ~i < j~, ta có ~qu[i - 1].push_back(-(a[i] + a[j]))~ (Tương ứng với ta cần cộng kết quả cho ~cnt[i - 1][-(a[i] + a[j])]~).
Sau đó, ta lần lượt for ~i~ từ ~2~ đến ~n~ và ~j~ từ ~1~ đến ~i - 1~, cập nhật ~cnt[a[i] + a[j]] += 1~.
Dễ thấy khi đã for xong với ~i=pos~ thì mảng ~cnt[sum]~ lúc này lưu giá trị của ~cnt[pos][sum]~. Ta chỉ cần for qua những phần tử trong ~qu[i]~ (gọi là ~sum~) và cộng vào đáp án lượng ~cnt[sum]~ tương ứng.
Chú ý rằng ~sum~ có thể âm nên bạn cần shift giá trị từ đoạn ~[-2\times 10^6, 2\times 10^6]~ lên ~[0, 4\times 10^6]~ để có thể sử dụng được mảng ~cnt~. (Nghĩa là ~cnt[0]~ sẽ lưu giá trị của ~cnt[-2\times 10^6]~ ban đầu)
Độ phức tạp thời gian: Vì chỉ có ~O(N^2)~ cặp ~(i,j)~ nên tổng tất cả phần tử của ~n~ vector ~qu[]~ là ~O(N^2)~, việc update mảng ~cnt~ cũng như cộng đáp án là ~O(N^2)~. Tổng là ~O(N^2)~.
Độ phức tạp bộ nhớ: ~O(2 \times max(a[i] + a[j]) + N^2)~.
Code mẫu
#pragma GCC optimize("Ofast") #pragma GCC optimize("unroll-loops") #pragma GCC target("avx,avx2,fma") #include <bits/stdc++.h> #define for1(i,a,b) for (int i = a; i <= b; i++) #define for2(i,a,b) for (int i = a; i >= b; i--) // #define int long long #define sz(a) (int)a.size() #define pii pair<int,int> #define pb push_back /* __builtin_popcountll(x) : Number of 1-bit __builtin_ctzll(x) : Number of trailing 0 */ const long double PI = 3.1415926535897932384626433832795; const int INF = 1000000000000000000; const int MOD = 1000000007; const int MOD2 = 1000000009; const long double EPS = 1e-6; using namespace std; #include <ext/pb_ds/assoc_container.hpp> using namespace __gnu_pbds; const int N = 2e3 + 5; int n, res; int a[N], cnt[4000005]; gp_hash_table<int, int> mp[N]; int get1(int x) { return cnt[x + (int)2e6]; } int get2(int i, int x) { if (mp[i].find(x) != mp[i].end()) return mp[i][x]; return 0; } signed main() { ios_base::sync_with_stdio(false); cin.tie(NULL); cout.tie(NULL); // freopen("cf.inp", "r", stdin); // freopen("cf.out", "w", stdout); cin >> n; for1(i,1,n) cin >> a[i]; for1(i,1,n) { for1(j,i + 1,n) { cnt[a[i] + a[j] + (int)2e6]++; mp[i][a[i] + a[j]]++; mp[j][a[i] + a[j]]++; } } for1(i,1,n) { for1(j,i + 1,n) { int x = -a[i] - a[j]; res += get1(x) - get2(i, x) - get2(j, x) + (x == 0); // cout <<i << " " << j << " " << res << "\n"; } } cout << res / 6; }
Bình luận