트 리 분할 + FFT

나 쁜 거리의 문 제 는 손 을 연습 하 는 것 일 뿐 이지 만 많은 문제 가 드 러 났 다. 그리고 나 서 나 는 시간 을 1s 정도 로 최적화 시 켜 서 야 그만 두 었 다.
http://www.codechef.com/problems/PRIMEDST
복잡 도 는 nlogn ^ 2 입 니 다. 통 계 를 합 칠 때마다 FFT 를 사 용 했 기 때문에 nlogn 입 니 다.
참고 로 컴 파일 러 의 차이 가 매우 크다.
#include <cstdio>
#include <cstring>
#include <algorithm>
#include <cmath>
#include <vector>
using std::vector;
typedef long long type;
struct comp{
    double x, y;
    comp(double _x=0, double _y=0) : x(_x), y(_y) {}
};
namespace FFT{
    const int N = 131072, MinSize = 400000;
    const double pi2 = 3.1415926535897932 * 2;
    comp a[N], b[N], tmp[N];
    int n, bn;
    type res[N];
    inline comp W(int n, bool inv) {
        double ang = inv ? -pi2 / n : pi2 / n;
        return comp(cos(ang), sin(ang));
    }
   inline  int bitrev(int x) {
        int ans = 0;
        for (int i=1; i<=bn; ++i)
            ans <<= 1, ans |= x & 1, x >>= 1;
        return ans;
    }
    void dft(comp *a,bool inv) {
        int step, to; comp w, wi, A, B;
        for (int i=0; i<n; ++i) {
            to = bitrev(i);
            if (to > i) std::swap(a[to], a[i]);
        }
        for (int i=1; i<=bn; ++i) {
            wi = W(1<<i, inv); w = comp(1, 0);
            step = 1 << (i-1);
            for (int k=0; k<step; ++k) {
                for (int j=0; j<n; j+=1<<i) {
                    int t = j | k, d = j|k|step;
                    A = a[t];
                    B.x  = w.x * a[d].x - w.y * a[d].y;
                    B.y  = w.x * a[d].y + w.y * a[d].x;
                    a[t].x = A.x + B.x, a[t].y = A.y + B.y;
                    a[d].x = A.x - B.x, a[d].y = A.y - B.y;
                }
                comp tmp;
                tmp.x = w.x * wi.x - w.y * wi.y;
                tmp.y = w.x * wi.y + w.y * wi.x;
                w = tmp;
            }
        }
    }
    int mul(int n1, int *x1, int n2, int *x2) {
        n = std::max(n1, n2);
        for (bn = 0; (1<<bn) < n; ++bn); ++bn;
        n = 1 << bn;
        for (int i=0; i<n1; ++i) a[i] = comp(x1[i], 0);
        for(int i=n1;i<n;i++) a[i] = comp(0,0);
        dft(a, false); 
        for (int i=0; i<n; ++i) {
            tmp[i].x = a[i].x * a[i].x - a[i].y * a[i].y;
            tmp[i].y = a[i].x * a[i].y + a[i].y * a[i].x;
        }
        dft(tmp, true);
        for (int i=0; i<n; ++i) res[i] = (type)(tmp[i].x/n + 0.1);
        for (--n; n && !res[n]; --n);
        return n+1;
    }
}
const int N = 50010;
bool vis[N];
int p[N],pn;
void init() {
    pn = 0;
    for(int i = 2; i < N; i++) {
        for(int j = i+i; j < N; j+=i) {
            vis[j] = true;
        }
    }
    for(int i = 2; i < N; i++) if(!vis[i]) p[pn++] = i;
}
int head[N],nxt[N*2],pnt[N*2];
int E,n;
void add(int a,int b) {
    pnt[E] = b;
    nxt[E] = head[a];
    head[a] = E++;
}
bool del[N];
int son[N],opt[N];
vector<int> alln;
void dfs(int u,int f) {
    alln.push_back(u);
    son[u] = 1;
    opt[u] = 0;
    for(int i = head[u]; i!=-1; i = nxt[i]) if(pnt[i]-f){
        if(del[pnt[i]]) continue;
        dfs(pnt[i],u);
        son[u] += son[pnt[i]];
        opt[u] = std::max(opt[u],son[pnt[i]]);
    }
}
int getcenter(int u) {
    alln.clear();
    dfs(u,-1);
    int mx = 0, ans = -1;
    int sz = alln.size();
    for(int i = 0; i < sz; i++)  {
        int v = alln[i];
        if(ans == -1) ans = v, mx = std::max(opt[v],sz-son[v]);
        else  {
            if(std::max(opt[v],sz-son[v]) < mx) {
                mx = std::max(opt[v],sz-son[v]);
                ans = v;
            }
        }
    }
    return ans;
}
int tot;
int D[N];
void getdist(int u,int f,int prew) {
    D[tot++] = prew;
    for(int i = head[u]; i != -1; i = nxt[i]) if(pnt[i]-f) {
        if(del[pnt[i]]) continue;
        getdist(pnt[i],u,prew+1);
    }
}
int cnt[50010];
inline long long calc() { // calcluate how many pair's sum of D[]  is a prime
    int len = *std::max_element(D,D+tot) + 1;
    std::fill(cnt,cnt+len,0);
    for(int i = 0; i < tot; i++) cnt[D[i]] ++;
    len = FFT::mul(len,cnt,len,cnt);
    for(int i = 0; i < tot; i++) FFT::res[D[i]+D[i]]--;
    for(int i = 0; i<pn && p[i] < len ; i++) FFT::res[p[i]] /= 2;
    long long ans = 0;
    for(int i = 0; i<pn && p[i] < len ; i++) ans += FFT::res[p[i]];
    return ans;
}
long long ans;
void solve(int u) {
    u = getcenter(u);
    tot=0;getdist(u,-1,0);
    ans += calc();
    for(int i = head[u]; i != -1; i = nxt[i]) {
        if(del[pnt[i]]) continue;
        tot = 0;
        getdist(pnt[i],u,1);
        ans -= calc();
    }
    del[u] = true;
    for(int i = head[u]; i != -1; i = nxt[i]) {
        if(del[pnt[i]]) continue;
        solve(pnt[i]);
    }
}
int main(){
    init();
    while(scanf("%d",&n)!=EOF) {
        E = 0; 
        std::fill(head,head+n+1,-1);
        std::fill(del,del+n+1,false);
        for(int i = 1,a,b; i < n; i++) {
            scanf("%d%d",&a,&b);
            add(a,b);
            add(b,a);
        }
        ans=0;
        solve(1);
        printf("%.8f
",1.0*ans*2/n/(n-1)); } return 0; }

좋은 웹페이지 즐겨찾기