Editorial for Phân nhóm


Remember to use this editorial only when stuck, and not to copy-paste code from it. Please be respectful to the problem author and editorialist.
Submitting an official solution before solving the problem yourself is a bannable offence.

Lời giải

Tác giả: Lê Anh Đức - A2K42-PBC

Nhận xét 1

Nếu tồn tại hai cặp số ~(x_1, y_1)~ và ~(x_2, y_2)~ mà ~x_1 \gt x_2~ và ~y_1 \gt y_2~ thì ta nói cặp số ~(x_2, y_2)~ là không tiềm năng, và có thể loại bỏ không cần quan tâm tới nó. Bởi vì ta có thể cho nó vào cùng nhóm với cặp đầu tiên mà không làm tồi đi kết quả.

Như vậy ta có thể sắp xếp lại các cặp số tăng dần theo ~x~. Sau đó sử dụng stack để loại bỏ đi những cặp số không tiềm năng, cuối cùng còn lại một dãy các cặp với ~x~ tăng dần và ~y~ giảm dần. Từ đây ta chỉ cần giải bài toán cho dãy các cặp số này.

Nhận xét 2

Nếu ta có hai cặp số ~(x_1, y_1)~ và ~(x_2, y_2)~ thuộc cùng một nhóm, thì ta có thể thêm vào nhóm đấy các cặp số ~(x, y)~ mà ~x_1 \le x \le x_2~ và ~y_1 \ge y \ge y_2~ mà không làm tồi đi kết quả. Như vậy có nhận xét: các nhóm được phân hoạch gồm các phần tử liên tiếp nhau.

Quy hoạch động

Với hai nhận xét trên, ta đã có thể có được thuật toán QHĐ đơn giản với độ phức tạp đa thức. Gọi ~F(i)~ là chi phí nhỏ nhất để phân nhóm các phần tử có chỉ số không quá ~i~. Công thức truy hồi:

~F(i) = min[F(j) + x_i * y_{j+1}]~ với ~0 \le j \lt i~

Có thể cài đặt thuật toán trên với độ phức tạp ~O(N^2)~, tuy nhiên như vậy vẫn chưa đủ tốt với giới hạn của đề bài.

Áp dụng bao lồi

Đặt ~y_j=a~, ~x_i=x~, và ~F(j)=b~. Rõ ràng ta cần cực tiểu hóa một hàm bậc nhất ~y=a*x+b~ bằng việc chọn ~j~ hợp lí. Đồng thời trong bài toán này thì hệ số góc ~a~ của các đường thẳng là giảm dần, như vậy có thể áp dụng trực tiếp convex hull trick. Để ý một tí là các truy vấn (các giá trị ~x~) là tăng dần, nên ta không cần phải tìm kiếm nhị phân mà có thể tịnh tiến để tìm kết quả. Độ phức tạp cho phần QHĐ này là ~O(N)~.

Lưu ý: Các code mẫu dưới đây chỉ mang tính tham khảo và có thể không AC được bài tập này

Code mẫu của ladpro98

#include <bits/stdc++.h>
#define X first
#define Y second

const int N = 300005;

using namespace std;
typedef pair<long long, long long> Line;

int n;
pair<int, int> a[N];

long long eval(long long x, Line line) {
    return x * line.X + line.Y;
}

bool bad(Line d1, Line d2, Line d3) {
    return (d2.Y - d1.Y) * (d1.X - d3.X) >= (d3.Y - d1.Y) * (d1.X - d2.X);
}

int main() {
    scanf("%d", &n);
    int i;
    for (i = 1; i <= n; i++) scanf("%d %d", &a[i].X, &a[i].Y);
    sort(a + 1, a + 1 + n);
    vector<pair<int, int> > b;
    for (i = 1; i <= n; i++) {
        while (b.size() && b.back().Y < a[i].Y) b.pop_back();
        b.push_back(a[i]);
    }
    vector<Line> d; long long last = 0; Line new_line; int best = 0;
    for (i = 0; i < b.size(); i++) {
        new_line = Line(b[i].Y, last);
        while (d.size() >= 2 && bad(d[d.size() - 2], d[d.size() - 1], new_line)) {
            if (best >= d.size() - 1) best--; d.pop_back();
        }
        d.push_back(new_line);
        while (best + 1 < d.size() && 
        eval(b[i].X, d[best]) >= eval(b[i].X, d[best + 1])) best++;
        last = intersectX(b[i].X, d[best]);
    }
    cout << last << endl;
    return 0;
}

Code mẫu của flashmt

#include <iostream>
#include <algorithm>
#include <cstdio>
#include <vector>
#include <utility>
using namespace std;

int n,tail=1,head;
long long f[300300],A[300300],B[300300];
double start[300300];
vector < pair<int,int> > a;

void addLine(long long a,long long b)
{
    long long x=-1e18;
    while (tail>head)
    {
        x=1.0*(b-B[tail])/(A[tail]-a);
        if (x>start[tail]) break;
        tail--;
    }
    if (tail>=head) x=1.0*(b-B[tail])/(A[tail]-a);
    A[++tail]=a; B[tail]=b; start[tail]=x;
}

long long findLowest(int x)
{
    while (head<tail && start[head+1]<=1.0*x) head++;
    return A[head]*x+B[head];
}

int main()
{
    int x,y;
    cin >> n;
    a.push_back(make_pair(-1,-1));
    for (int i=1;i<=n;i++) scanf("%d%d",&x,&y), a.push_back(make_pair(x,y));
    sort(a.begin(),a.end());
    n=1;
    for (int i=2;i<int(a.size());i++)
    {
        while (n && a[i].second>=a[n].second) n--;
        a[++n]=a[i];
    }
    addLine(a[1].second,0);
    for (int i=1;i<=n;i++)
    {
        f[i]=findLowest(a[i].first);
        if (i<n) addLine(a[i+1].second,f[i]);
    }
    cout << f[n] << endl;
    return 0;
}

Code mẫu của RR

#include <iostream>
#include <algorithm>
#include <deque>
#include <cstdio>
#include <cstring>
#include <cmath>

#define MAXN 300111
#define FOR(i,a,b)  for(long i=a; i<=b; i++)
#define FORD(i,a,b) for(long i=a; i>=b; i--)
#define PB push_back
#define P pair<long,long>
#define MP make_pair
#define F first
#define S second
#define SZ(x) (x.size())
#define ll long long
using namespace std;

long debug=0;
long n;
bool erased[MAXN];
P a[MAXN];
deque<long> q;
ll f[MAXN];

void inp() {
    scanf("%ld",&n);
    FOR(i,1,n) scanf("%ld %ld",&a[i].F,&a[i].S);
}

void init() {
    sort(a+1,a+n+1);
    erased[n]=false;
    long now=a[n].S;
    FORD(i,n-1,1)
        if (a[i].S<=now) erased[i]=true;
        else {
            erased[i]=false;
            now=a[i].S;
        }

    long j=0;
    FOR(i,1,n)
        if (!erased[i])
            a[++j]=a[i];
    n=j;
    if (debug) FOR(i,1,n) cout<<a[i].F<<" "<<a[i].S<<endl;
}

long long g(long j,long k) {
    return f[j+1]-f[k+1];
}

long long h(long j,long k) {
    return a[k].F-a[j].F;
}

template <typename T> void write(T a) {
    FOR(i,0,SZ(a)-1) cout<<a[i]<<" ";
    cout<<endl;
}

void solve() {
    f[n]=(long long)a[n].F*a[n].S;
    q.PB(n);
    FORD(i,n-1,1) {
        if (debug) cout<<"i = "<<i<<endl;
        while (q.size()>1 && g(q[1],q[0])<=a[i].S*h(q[1],q[0])) q.pop_front();
        if (debug) write(q);
        f[i]=f[q[0]+1]+(long long)a[q[0]].F*a[i].S;

        while (q.size()>1 
            && g(i,q[SZ(q)-1])*h(q[SZ(q)-1],q[SZ(q)-2])
            <= g(q[SZ(q)-1],q[SZ(q)-2])*h(i,q[SZ(q)-1]) ) q.pop_back();

        f[i]=min(f[i],f[i+1]+(long long)a[i].F*a[i].S);
        q.PB(i);

        if (debug) cout<<"f = "<<f[i]<<endl; 
    }
    cout<<f[1]<<endl;
}

int main() {
//    freopen("input.txt","r",stdin);
//    freopen("output.txt","w",stdout);
    inp();
    init();
    solve();
    return 0;
}

Code mẫu của hieult

#include <cstdio>
//#include <conio.h>
#include <cstdlib>
#include <iostream>
#define Max 333333

using namespace std;

struct diem { long long x,y;};

struct line{
     long long m,b;
     line(long long x = 0, long long y = 0) { m = x;b = y;}
};

int N,ldau,lcuoi;
diem p[Max];
long long kq[Max];
line lines[Max];

int cmp(const void *a, const void *b)
{
    diem p = *(diem *)a, q = *(diem *)b;
    if(p.x<q.x) return -1;
    if(p.x>q.x) return 1;
    if(p.y<q.y) return -1;
    return 1;
}

bool xau(line x,line y,line z){
          return ((y.m-z.m)*(y.b-x.b)>=(x.m-y.m)*(z.b-y.b));
}

int main()
{
    scanf("%d",&N);
    for(int i = 0 ;i<N;i++)
         scanf("%lld %lld",&p[i].x,&p[i].y);
    qsort(p,N,sizeof(p[0]),cmp);

    int tN = 0;
    for(int i = 0;i<N;i++)
    {
            p[tN] = p[i];
            while(tN>0 && p[tN].x>=p[tN-1].x && p[tN].y>= p[tN-1].y){
            p[tN-1] = p[tN];
            tN--;
            }
            tN++;
    }
    N = tN;
    kq[0] = 0;
    lines[0] = line(p[0].y,kq[0]);
    ldau = 0;
    lcuoi = 1;
    for(int i = 0;i<N;i++){
            for(int j = ldau;j<lcuoi;j++){
                 long long tmp = lines[j].m*p[i].x+lines[j].b;
                 if(j == ldau) kq[i+1] = tmp;
                 else if(tmp<kq[i+1]){
                      kq[i+1] = tmp;
                      ldau = j;
                 }
                 else break;
            }
            if(i<N-1){
                  lines[lcuoi] = line(p[i+1].y,kq[i+1]);
                  lcuoi++;

                  while(lcuoi-ldau>=3 && xau(lines[lcuoi-3],lines[lcuoi-2],lines[lcuoi-1])){
                       lines[lcuoi-2] = lines[lcuoi-1];
                       lcuoi--;
                  }
            }
    }
    printf("%lld",kq[N]);
    //getch();
}

Code mẫu của ll931110

//#pragma comment(linker, "/STACK:16777216")
#include <algorithm>
#include <bitset>
#include <cmath>
#include <ctime>
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <deque>
#include <fstream>
#include <functional>
#include <iostream>
#include <map>
#include <set>
#include <sstream>
#include <stack>
#include <queue>
#include <vector>
#include <utility>
using namespace std;

struct rect
{
    long long w,h;
};

struct line
{
    long long m,b;
    line(long long x = 0,long long y = 0)
    {
        m = x;  b = y;
    }  
};

bool cmp(rect A,rect B)
{
    if (A.w != B.w) return A.w < B.w;
    return A.h < B.h;
}

/*int cmp(rect p, rect q) {
  if (p.w < q.w) return -1;
  if (p.w > q.w) return 1;
  if (p.h < q.h) return -1;
  return 1;
}*/

bool cross_product(line x,line y,line z)
{
  return ((y.m-z.m) * (y.b-x.b) >= (x.m-y.m) * (z.b-y.b));    
/*    long long x1 = p.m - q.m,y1 = p.b - q.b;
    long long x2 = q.m - r.m,y2 = q.b - r.b;
    return x1 * y2 - x2 * y1;*/
}

int n;
rect plots[300010];
line lines[300010];
long long best[300010];

int main()
{
//    freopen("acquire.3.in","r",stdin);
//    freopen("acquire.ou","w",stdout);
    scanf("%d", &n);
    for (int i = 0; i < n; i++) scanf("%lld %lld", &plots[i].w, &plots[i].h);
    sort(plots,plots + n,cmp);

    int tn = 0;
    for (int i = 0; i < n; i++)
    {
        plots[tn] = plots[i];
        while (tn > 0 && plots[tn].w >= plots[tn - 1].w && plots[tn].h >= plots[tn - 1].h) 
        {
            plots[tn - 1] = plots[tn];  tn--;
        }
        tn++;
    }
    n = tn;
/*    for (int i = 0; i < n; i++) cout << plots[i].w << ' ';
    cout << endl;
    for (int i = 0; i < n; i++) cout << plots[i].h << ' ';
    cout << endl;*/
    best[0] = 0;
    lines[0] = line(plots[0].h,best[0]);
    int ls = 0,le = 1;
    for (int i = 0; i < n; i++)
    {
        for (int j = ls; j < le; j++)
        {
            long long tmp = lines[j].m * plots[i].w + lines[j].b;
            if (j == ls) best[i + 1] = tmp;
            else if (tmp < best[i + 1])
            {
                best[i + 1] = tmp;  ls = j;
            }
            else 
            {
//                cout << "ok: " << lines[ls].m << ' ' << plot[i].w << ' ' << lines[ls].b << endl;
                break;
            }
        }
//        cout << best[i + 1] << endl;
        if (i < n - 1)
        {
            lines[le] = line(plots[i + 1].h,best[i + 1]);  le++;
            while (le - ls >= 3 && cross_product(lines[le - 3],lines[le - 2],lines[le - 1]))
            {
                lines[le - 2] = lines[le - 1];  le--;
            }
        }        
    }       
    cout << best[n] << endl;
}

Comments

Please read the guidelines before commenting.


There are no comments at the moment.