Hướng dẫn giải của Bedao Grand Contest 13 - FLOWER


Chỉ dùng lời giải này khi không có ý tưởng, và đừng copy-paste code từ lời giải này. Hãy tôn trọng người ra đề và người viết lời giải.
Nộp một lời giải chính thức trước khi tự giải là một hành động có thể bị ban.

Vì ~N~ khá bé, ta có thể cố định ~2~ hàng ~x_1, x_2~ và nhiệm vụ còn lại là tìm số cặp cột ~y_1, y_2~ thoả mãn điều kiện. Giả sử ta đang duyệt cột ~y~, ta có thể tìm cột xa nhất ~L_y~ ~(L_y \le y_2)~ sao cho

~H_{x_2, y} > max(H_{x1, L_y+1:y}, H_{x1+1:x2-1, L_y:y}, H_{x2, L_y:y-1})~

bằng cách chia nhị phân. Tương tự với mỗi cột ~y~ ta cũng có thể tìm cột xa nhất ~R_y~ ~(R_y \ge y_1)~ sao cho:

~H_{x_1, y} > max(H_{x1, y+1:R_y}, H_{x1+1:x2-1, y:R_y}, H_{x2, y:R_y-1})~

Giờ với mỗi ~y_2~, điều kiện để có các ~y_1~ thoả mãn là

~y1 \le y2~
~y_1 \ge L_{y_2}~
~R_{y_1} \ge y_2~

Duyệt các cặp ~(L_{y_2}, y_2)~ theo thứ tự ~y_2~ tăng dần. Khi dịch ~y_2~ từ ~y~ sang ~y + 1~, thêm ~(y + 1, R_{y + 1})~ và bỏ các cặp ~(y_1, R_{y_1})~ sao cho ~R_{y_1} < y + 1~ khỏi cấu trúc dữ liệu. Ta chỉ cần đếm số cặp ~(y_1, R_{y_1})~ trong cấu trúc dữ liệu thoả mãn ~y_1 \ge L_{y_2}~

Tổng độ phức tạp: ~O(N^2 \times MlogM)~

Code mẫu

#include <bits/stdc++.h>
#define BIT(x, i) (((x)>>(i))&1)
#define MASK(i) (1LL<<(i))
#define fi first
#define se second
#define all(x) x.begin(), x.end()
#define mp make_pair
#define pb push_back
#define TASK ""

using namespace std;
typedef long long ll;
typedef pair<int, int> pii;
typedef pair<ll, ll> pll;
typedef vector<int> vi;
typedef vector<pii> vii;
typedef vector<ll> vll;
typedef vector<pll> vlll;

template<typename T1, typename T2> bool mini(T1 &a, T2 b) {if(a>b) a=b; else return 0; return 1;}
template<typename T1, typename T2> bool maxi(T1 &a, T2 b) {if(a<b) a=b; else return 0; return 1;}
template<typename T1> T1 abs(T1 a) {if(a<0) a*=-1; return a;}

mt19937_64 rd(chrono::steady_clock::now().time_since_epoch().count());
ll rand(ll l, ll h){return l + ll(rd()) * rd() %(h-l+1);}

const int N = 1e5+7;

int n, m;
vector<vi> a;
vector<pii> id;

void compress()
{
    vi b;
    for(auto &v: a) for(auto &x: v) b.push_back(x);
    sort(all(b));
    for(auto &v: a) for(auto &x: v) x = lower_bound(all(b), x) - b.begin();

    id.assign(n*m, {0, 0});
    for(int i=0; i<n; i++) for(int j=0; j<m; j++) id[a[i][j]] = {i, j};
}

int findRootL(int j, int &x, vi &cnt, vi &Max, vi &root)
{
    int &ans = root[j];
    if(ans && cnt[ans-1] == -1 && x >= Max[ans-1])
    {
        ans--;
        return ans = findRootL(ans, x, cnt, Max, root);
    }
    return ans;
}

void extendL(int &j, int &x, vi &cnt, int x1, int x2)
{
    if(j == 0) return;
    if(cnt[j-1] != x2-x1) return;
    if(a[x1][j-1] < x) return;
    j--;
}

void prepL(vi &L, int x1, int x2)
{
    vi cnt(m, 0), Max(m, 0), root(m, 0);

    for(int i=0; i<m; i++) root[i] = i+1;

    for(int x=0; x<n*m; x++) {
        int i = id[x].fi;
        int j = id[x].se;

        if(i < x1) continue;
        if(i > x2) continue;

        cnt[j]++;
        maxi(Max[j], x);

        if(cnt[j] < x2-x1+1) continue;

        cnt[j] = -1;

        if(i!=x2)
        {
            L[j] = j+1;
        }
        else
        {
            L[j] = findRootL(j, x, cnt, Max, root);
            extendL(L[j], x, cnt, x1, x2);
        }

        if(L[j] <= j) continue;
        bool ok = 1;
        for(int x3 = x1+1; x3 < x2; x3++) {
            if(!ok) break;
            ok &= a[x3][j] < a[x2][j];
        }
        if(ok) L[j] = j;
    }
}

int findRootR(int j, int &x, vi &cnt, vi &Max, vi &root)
{
    int &ans = root[j];
    if(ans<m-1 && cnt[ans+1] == -1 && x >= Max[ans+1])
    {
        ans++;
        return ans = findRootR(ans, x, cnt, Max, root);
    }
    return ans;
}

void extendR(int &j, int &x, vi &cnt, int x1, int x2)
{
    if(j == m-1) return;
    if(cnt[j+1] != x2-x1) return;
    if(a[x2][j+1] < x) return;
    j++;
}

void prepR(vi &R, int x1, int x2)
{
    vi cnt(m, 0), Max(m, 0), root(m, 0);

    for(int i=0; i<m; i++) root[i] = i-1;

    for(int x=0; x<n*m; x++) {
        int i = id[x].fi;
        int j = id[x].se;

        if(i < x1) continue;
        if(i > x2) continue;

        cnt[j]++;
        maxi(Max[j], x);

        if(cnt[j] < x2-x1+1) continue;

        cnt[j] = -1;

        if(i!=x1)
        {
            R[j] = j-1;
        }
        else
        {
            R[j] = findRootR(j, x, cnt, Max, root);
            extendR(R[j], x, cnt, x1, x2);
        }

        if(R[j] >= j) continue;
        bool ok = 1;
        for(int x3 = x1+1; x3 < x2; x3++) {
            if(!ok) break;
            ok &= a[x3][j] < a[x1][j];
        }
        if(ok) R[j] = j;
    }
}

int bit[N];

void add(int x, int v)
{
    for(x++; x<=m; x += x&-x) bit[x] += v;
}

int get(int x)
{
    int ans = 0;
    for(x++; x; x -= x&-x) ans += bit[x];
    return ans;
}

long long calc(vi &L, vi &R)
{
    long long ans = 0;

    priority_queue<pii> Q;
    for(int i=m-1; i>=0; i--) {
        Q.push({L[i], i});
        add(i, 1);
        while(Q.size())
        {
            pii p = Q.top();
            Q.pop();

            if(p.fi > i) add(p.se, -1);
            else
            {
                Q.push(p);
                break;
            }
        }
        ans += get(R[i]);
    }
    while(Q.size())
    {
        pii p = Q.top();
        Q.pop();
        add(p.se, -1);
    }
    return ans;
}

long long solve(int x1, int x2)
{
    vi L(m, -1), R(m, -1);

    prepL(L, x1, x2);
    prepR(R, x1, x2);
    return calc(L, R);
}

void proc(const int &TTT)
{
    cin>>n>>m;
    a.assign(n, vi(m, 0));
    for(auto &v: a) for(auto &x: v) cin>>x;

    compress();

    long long ans = 0;

    for(int x1 = 0; x1 < n; x1++)
    {
        for(int x2 = x1; x2 < n; x2++)
        {
            ans += solve(x1, x2);
        }
    }

    cout<<ans;
}

int main()
{
//    freopen(TASK".inp", "r", stdin);
//    freopen(TASK".out", "w", stdin);
    ios_base::sync_with_stdio(0); cin.tie(0);
    int numTest = 1;
//    cin>>numTest;
    for(int i=1; i<=numTest; i++) proc(i);
    return 0;
}

Bình luận

Hãy đọc nội quy trước khi bình luận.


Không có bình luận tại thời điểm này.