Hướng dẫn giải của Bedao Mini Contest 15 - 2SEG


Chỉ dùng lời giải này khi không có ý tưởng, và đừng copy-paste code từ lời giải này. Hãy tôn trọng người ra đề và người viết lời giải.
Nộp một lời giải chính thức trước khi tự giải là một hành động có thể bị ban.

Tác giả: bedao

  • Subtask 1: Duyệt trâu

  • Subtask 2: Vì ~a[i] \leq 20~ nên kết quả của ta sẽ có tối đa 20 phần tử, vậy nên ta sẽ duyệt trâu đoạn 1, lưu các phần tử của đoạn 1 dưới dạng bitmask và tìm đoạn 2 là một bitmask có nhiêu bit 1 nhất sao cho nó là mask con của phần bù của đoạn 1. Phần còn lại xin phép nhường bạn đọc

  • Subtask 3: Đầu tiên, cần tính mảng ~dp[l][r]~ là dãy con liên tiếp có các phần tử phân biệt dài nhất nếu xét các phần từ trong đoạn từ ~l~ đến ~r~. Ta có thể đơn giản dùng 2 con trỏ đến tính mảng này

Tiếp theo, cố định độ dài của dãy con liên tiếp thứ nhất ~[l, r]~ (phần tử phân biệt), vậy ta cần tìm đoạn ~[u, v]~ dài nhất sử dụng các phần tử khác với các phần tử trong đoạn thứ nhất và ~r < u \leq v~

Nhận xét: khi ta giảm từ đoạn ~[l, r]~ sang ~[l, r - 1]~, một số phần tử mới (cụ thể là bằng chính ~a[r]~) sẽ có thể được sử dụng cho đoạn thứ 2. Vì vậy, ta sẽ xử lí bằng cách cố định ~l~ và giảm dần ~r~ từ ~n~ về ~l~, khi giảm ~r~ ta sẽ duy trì tập các phân tử có thể sử dụng cho đoạn 2. Tập S sẽ gồm các vị trí ~S = \{i_1, i_2, i_3, ..., i_k\}~ và ta cần duy trì ~max(dp[u][v])~ sao cho đoạn ~[u, v]~ chỉ chứa các giá trị có thể sử dụng và đó cũng chính là độ dài lớn nhất của đoạn 2.

Code mẫu

#include <bits/stdc++.h>
using namespace std;
using ll = long long;

const int N = 5005;

struct DSU {
    vector<int> lab;
    vector<pair<int, int>> bound;

    DSU(int n) {
        lab.assign(n + 5, -1);
        bound.resize(n + 5);
        for (int i = 1; i <= n; ++ i) {
            bound[i] = {i, i};
        }
    }

    int find_set(int u) {
        return lab[u] < 0? u : lab[u] = find_set(lab[u]);
    }

    pair<int, int> merge(int u, int v) {
        u = find_set(u);
        v = find_set(v);
        if (lab[u] > lab[v]) swap(u, v);
        auto& [l, r] = bound[u];
        auto [L, R] = bound[v];
        l = min(L, l);
        r = max(r, R);
        lab[u] += lab[v];
        lab[v] = u;
        return bound[u];
    }
};

int n, a[1000005];
vector<int> pos[N];
int vis[N];
int open[N];
int dp[N][N];

void prepare() {
    vector<int> vt(a + 1, a + 1 + n);
    sort(vt.begin(), vt.end());
    vt.resize(unique(vt.begin(), vt.end()) - vt.begin());
    for (int i = 1; i <= n; ++ i) {
        pos[i].clear();
    }
    for (int i = 1; i <= n; ++ i) {
        a[i] = upper_bound(vt.begin(), vt.end(), a[i]) - vt.begin();
        pos[a[i]].push_back(i);
    }

    for (int i = n; i >= 1; -- i) {
        memset(vis, 0, sizeof vis);
        bool ok = true;
        for (int j = i; j <= n; ++ j) {
            if (vis[a[j]]) ok = false;
            vis[a[j]] = 1;
            if (ok) {
                dp[i][j] = j - i + 1;
            } else {
                dp[i][j] = max(dp[i + 1][j], dp[i][j - 1]);
            }
        }
    }
}

void update(int value, int limit, DSU& dsu, int& res) {
    for (auto idx : pos[value]) {
        if (idx <= limit) continue;
        res = max(res, 1);
        if (idx > 1 && open[idx - 1]) {
            auto [l, r] = dsu.merge(idx - 1, idx);
            res = max(res, dp[l][r]);
        }
        if (idx < n && open[idx + 1]) {
            auto [l, r] = dsu.merge(idx, idx + 1);
            res = max(res, dp[l][r]);
        }
        open[idx] = 1;
    }
}

void solve3() {

    prepare();
    int ans = 0;
    for (int i = 1; i <= n; ++ i) {
        DSU dsu(n);
        memset(vis, 0, sizeof vis);
        int cnt = 0;
        for (int j = 1; j <= i; ++ j) {
            if (++ vis[a[j]] == 2) {
                ++ cnt;
            }
        }
        memset(open, 0, sizeof open);
        int res = 0;
        for (int j = 1; j <= n; ++ j) {
            if (vis[j] == 0) {
                update(j, i, dsu, res);
            }
        }

        for (int j = 1; j <= i; ++ j) {
            if (cnt == 0) ans = max(ans, res + i - j + 1);
            -- vis[a[j]];
            if (vis[a[j]] == 1) {
                -- cnt;
            }
            if (vis[a[j]] == 0) {
                update(a[j], i, dsu, res);
            }
        }
    }
    cout << ans << "\n";
}

int f[1 << 20];
int g[1 << 20], timer = 0;

void solve2() {
    ++ timer;
    for (int i = 1 ; i <= n; ++ i) {
        -- a[i];
    }
    for (int i = 1; i <= n; ++ i) {
        int mask = 0;
        for (int j = i; j <= n; ++ j) {
            // cout << i << " " << j << " " << mask << " " << a[j] << endl;
            if (mask >> a[j] & 1) {
                break;
            }
            mask |= (1 << a[j]);
            if (g[mask] != timer) g[mask] = timer, f[mask] = 0;
            f[mask] = j - i + 1;
        }
    }
    for (int mask = 0; mask < (1 << 20); ++ mask) {
        for (int i = 0; i < 20; ++ i) {
            if (mask >> i & 1) {
                if (g[mask] != timer) g[mask] = timer, f[mask] = 0;
                if (g[mask ^ (1 << i)] != timer) g[mask ^ (1 << i)] = timer, f[mask ^ (1 << i)] = 0;
                f[mask] = max(f[mask], f[mask ^ (1 << i)]);
            }
        }
    }
    int res = 0;
    for (int mask = 0; mask < (1 << 20); ++ mask) {
        int state = ((1 << 20) - 1) ^ mask;
        if (g[mask] != timer) g[mask] = timer, f[mask] = 0;
        if (g[state] != timer) g[state] = timer, f[state] = 0;
        res = max(res, f[mask] + f[state]);
    }
    cout << res << "\n";
}

main() {
    cin.tie(0)->sync_with_stdio(0);
    // freopen("main.inp", "r", stdin);
    int tt;
    cin >> tt;
    while (tt --) {
        cin >> n;
        int mx = 0;
        for (int i = 1; i <= n; ++ i) {
            cin >> a[i];
            mx = max(mx, a[i]);
        }
        if (mx <= 20) {
            solve2();
        } else {
            solve3();
        }
    }
}

Bình luận

Hãy đọc nội quy trước khi bình luận.


Không có bình luận tại thời điểm này.