Hướng dẫn giải của Vị trí phong thủy


Chỉ dùng lời giải này khi không có ý tưởng, và đừng copy-paste code từ lời giải này. Hãy tôn trọng người ra đề và người viết lời giải.
Nộp một lời giải chính thức trước khi tự giải là một hành động có thể bị ban.

Dễ thấy đáp án = số cặp đỉnh có khoảng cách giữa chúng là số nguyên tố / ~(n * (n-1) / 2)~

Subtask 1:

Các bạn có thể DFS từ từng đỉnh để thu được khoảng cách từ các đỉnh còn lại đến đỉnh đấy, sau đó dùng mảng tính trước để kiểm tra xem mỗi khoảng cách đó có phải là số nguyên tố hay không.

Độ phức tạp: ~O(n^2)~

Subtask 2:

Chúng ta có thể dùng centroid decomposition. Tại mỗi bước, chúng ta sẽ xét các cặp đỉnh mà đường đi giữa chúng có đi qua centroid.

Gọi ~d_x~ là khoảng cách từ đỉnh ~x~ đến centroid. Khi đó, bài toán trở thành đếm số cặp ~(a,b)~ sao cho

  • ~a~ và ~b~ không nằm trên cùng cây con khi bỏ centroid

  • ~d_a + d_b~ là số nguyên tố

Để tính số cặp (a,b) thỏa mãn ~d_a + d_b~ là số nguyên tố, chúng ta có thể lập đa thức ~f(x) = \sum a_i*x^i~, trong đó ~a_i~ là số đỉnh ~u~ có ~d_u = i~, sau đó dùng FFT/NTT để tính ~f*f~ và lấy tổng hệ số của các bậc là số nguyên tố trong ~f*f~.

Để bỏ đi các cặp ~(a,b)~ mà ~a~ và ~b~ nằm trên cùng cây con, chúng ta có thể lặp lại quá trình trên cho từng cây con.

Độ phức tạp: ~O(n*log^2(n))~

#ifndef CPL_TEMPLATE
#define CPL_TEMPLATE
/*
    Template for solving centroid decomp problems.
*/
// Standard library in one include.
#include <bits/stdc++.h>
using namespace std;

// ordered_set library.
// #include <ext/pb_ds/assoc_container.hpp>
// #include <ext/pb_ds/tree_policy.hpp>
// using namespace __gnu_pbds;
// #define ordered_set(el) tree<el,null_type,less<el>,rb_tree_tag,tree_order_statistics_node_update>

// AtCoder library. (Comment out these two lines if you're not submitting in AtCoder.) (Or if you want to use it in other judges, run expander.py first.)
//#include <atcoder/all>
//using namespace atcoder;

//Pragmas (Comment out these three lines if you're submitting in szkopul or USACO.)
// #pragma comment(linker, "/stack:200000000")
// #pragma GCC optimize("Ofast,unroll-loops,tree-vectorize")
// #pragma GCC target("sse,sse2,sse3,ssse3,sse4,popcnt,abm,mmx,avx,avx2,tune=native")

//File I/O.
#define FILE_IN "cseq.inp"
#define FILE_OUT "cseq.out"
#define ofile freopen(FILE_IN,"r",stdin);freopen(FILE_OUT,"w",stdout)

//Fast I/O.
#define fio ios::sync_with_stdio(0);cin.tie(0)
#define nfio cin.tie(0)
#define endl "\n"

//Order checking.
#define ord(a,b,c) ((a>=b)and(b>=c))

//min/max redefines, so i dont have to resolve annoying compile errors.
#define min(a,b) (((a)<(b))?(a):(b))
#define max(a,b) (((a)>(b))?(a):(b))

// Fast min/max assigns to use with AVX.
// Requires g++ 9.2.0.
// template<typename T>
// __attribute__((always_inline)) void chkmin(T& a, const T& b) {
//     a=(a<b)?a:b;
// }

// template<typename T>
// __attribute__((always_inline)) void chkmax(T& a, const T& b) {
//     a=(a>b)?a:b;
// }

//Constants.
#define MOD (ll(998244353))
#define MAX 300001
#define mag 320
const long double PI=3.14159265358979;

//Pairs and 3-pairs.
#define p1 first
#define p2 second.first
#define p3 second.second
#define fi first
#define se second
#define pii(element_type) pair<element_type,element_type>
#define piii(element_type) pair<element_type,pii(element_type)>

//Quick power of 2.
#define pow2(x) (ll(1)<<x)

//Short for-loops.
#define ff(i,__,___) for(int i=__;i<=___;i++)
#define rr(i,__,___) for(int i=__;i>=___;i--)

//Typedefs.
#define bi BigInt
typedef long long ll;
typedef long double ld;
typedef short sh;

// Binpow and stuff
ll BOW(ll a, ll x, ll p)
{
    if (!x) return 1;
    ll res=BOW(a,x/2,p);
    res*=res;
    res%=p;
    if (x%2) res*=a;
    return res%p;
}
ll INV(ll a, ll p)
{
    return BOW(a,p-2,p);
}
//---------END-------//
#endif


namespace CPL_NTT
{
    ll Mod=0;
    ll Root=0;
    ll Level=0;
    vector<ll> roots;
    ll inv2;
    ll bow(ll a, ll x, ll p) // exponentation by squaring
    {
        if (!x) return 1;
        ll res=bow(a,x/2,p);
        res*=res;
        res%=Mod;
        if (x%2) res*=a;
        return res%Mod;
    }
    void generate() // Generate roots of unity
    {
        roots.clear();
        ll u=1;
        for (ll i=0;i<Level;i++) 
        {
            roots.push_back(u);
            u*=Root;
            u%=Mod;
        }
        inv2=bow(2,Mod-2,Mod);
    }
    vector<ll> transform(const vector<ll>& vec, int inv) //Fourier ttransform
    {
        if (vec.size()==1) return {vec[0]};
        ll coeff=Level/vec.size(),lvl=log2(vec.size());
        vector<ll> res;
        for (ll i=0;i<vec.size();i++) // Building reverse-bit permutation of original array
        {
            ll u=0;
            for (ll j=0;j<lvl;j++) u^=(((i&(1<<j))>>j)<<(lvl-1-j));
            res.push_back(vec[u]);
        }
        for (ll t=0;t<lvl;t++)
        {
            coeff=Level/(1<<(t+1));
        for (ll i=0;i<vec.size();i+=(1<<(t+1)))
        {   // Apply merge for this segment
            for (ll j=0;j<(1<<t);j++)
            {
                ll a=res[i+j];
                ll b=res[i+j+(1<<t)];
                if (!inv)
                {
                res[i+j]=(a+roots[coeff*j]*b)%Mod;
                res[i+j+(1<<t)]=((a-roots[coeff*j]*b)%Mod+Mod)%Mod;
                }
                else
                {
                res[i+j]=(a+roots[(Level-coeff*j)%Level]*b)%Mod;
                res[i+j+(1<<t)]=((a-roots[(Level-coeff*j)%Level]*b)%Mod+Mod)%Mod;
                }
            }
        }
        }
        if (inv)
        {
            ll mul=bow(inv2,lvl,Mod);
            for (ll i=0;i<res.size();i++) res[i]=(res[i]*(mul))%Mod;
        }
        return res;
    }
    vector<ll> multiply(vector<ll> a, vector<ll> b) // Actual multiplication
    {
        ll u=1;
        while((u<a.size())or(u<b.size())) u*=2;
        u*=2;
        while(a.size()<u) a.push_back(0);
        while(b.size()<u) b.push_back(0);
        //for (auto g : a) cout<<g<<' '; cout<<endl;
        //for (auto g : b) cout<<g<<' '; cout<<endl;
        vector<ll> ra=transform(a,0),rb=transform(b,0);
        for (ll i=0;i<u;i++)
        {
        //  cout<<i<<' '<<ra[i]<<' '<<rb[i]<<endl;
            ra[i]=((ra[i]*rb[i])%Mod);
        }
        vector<ll> res=transform(ra,1);
        return res;
    }
};


set<pii(int)> gt[200001];
ll n,m,i,j,k,t,t1,u,v,a,b;
ll eu[200001],ev[200001],ec[200001];

ll sz[200001];
ll dep[200001];

ll pr[200001];

ll cnt = 0;

vector<ll> a1, a2;

void preJob(int x) {
    sz[x]=1;
    for (auto g : gt[x]) if (!sz[g.fi]) {
        dep[g.fi] = dep[x] + 1;
        preJob(g.fi);
        sz[x]+=sz[g.fi];
    }
}

void postJob(int x) {
    sz[x]=0;
    dep[x]=0;
    for (auto g : gt[x]) if (sz[g.fi]) {
        postJob(g.fi);
    }
}



void dfs(int x) {
    if (dep[x]>=a1.size()) a1.push_back(0);
    a1[dep[x]]++;

    if (dep[x]>=a2.size()) a2.push_back(0);
    a2[dep[x]]++;

    for (auto g : gt[x]) if (sz[g.fi] < sz[x]) {
        dfs(g.fi);
    }
}

ll calc(vector<ll> lmao) {
    // for (auto g : lmao) cout<<g<<' '; cout<<endl;
    auto prod = CPL_NTT::multiply(lmao,lmao);
    // for (auto g : prod) cout<<g<<' '; cout<<endl;
    ll res = 0;
    for (int i = 0; i < prod.size(); i++) res+=prod[i]*pr[i];
    return res;
}

void solve(int x) {
    preJob(x);

    int n = sz[x];

    int rt,v,p;
    rt=x;
    p=0;
    while(true) {
        v=0;
        for (auto g : gt[rt]) if (g.fi!=p && sz[g.fi]>sz[v]) v=g.fi;
        if (sz[v]*2<=n && sz[rt]*2>=n) break;
        else {
            p=rt;
            rt=v;
        }
    }

    postJob(x);
    preJob(rt);

    // START SOLVING CODE

    a1.clear(); a1.push_back(1);
    for (auto g : gt[rt]) {
        // cout<<rt<<' '<<g.fi<<endl;
        a2.clear();
        a2.push_back(0);
        dfs(g.fi);
        cnt-=calc(a2);
    }

    cnt+=calc(a1);

    // END SOLVING CODE

    postJob(rt);

    for (auto g : gt[rt]) {
        gt[g.fi].erase({rt,g.se});
        solve(g.fi);
    }
}

int main()
{
    fio;
    cin>>n;

    for(i=2;i<=n;i++) {
        pr[i] = 1;
    }
    for (i=2;i*i<=n;i++) if (pr[i]) {
        for (j=i*i;j<=n;j+=i) {
            pr[j] = 0;
        }
    }

    // I have to use this cursed modulo since 50000^2>998244353

    CPL_NTT::Mod = 2533359617;
    CPL_NTT::Root = CPL_NTT::bow(3,151,CPL_NTT::Mod);
    CPL_NTT::Level = 1<<24;
    CPL_NTT::generate();

    for (i=1;i<n;i++) {
        cin>>eu[i]>>ev[i];
        gt[eu[i]].insert({ev[i],0});
        gt[ev[i]].insert({eu[i],0});
    }

    solve(1);

    // cout<<cnt<<endl;
    cout<<fixed<<setprecision(9)<<(ld)cnt/n/(n-1);
}

// a;

Bình luận

Hãy đọc nội quy trước khi bình luận.


Không có bình luận tại thời điểm này.