BZOJ3451. Tyvj1953 Normal

传送门

考虑每个点 $i$ 对答案的贡献

当删去一个节点 $j$ 的时候, $i$ 会对 $j$ 产生 $1$ 的贡献当且仅当 $i,j$ 这条链上的所有点中,$j$ 是第一个删除的节点

显然链上每个节点第一个被删除的概率是一样的

所以点对 $i,j$ 的贡献就是 $\frac{1}{dis(i,j)}$,其中 $dis(i,i)=1$

那么答案就相当于 $\sum_{i}\sum_{j}\frac{1}{dis(i,j)}$

发现可以转化为求,对于每个值 $k$ ,$dis=k$ 的点对的数量

显然直接点分治

对于每一个分治节点,统计跨过它的各种长度的路径数量

设 $A[k]$ 表示当前节点 $x$ 的点分子树内所有到 $x$ 的路径长度恰好为 $k$ 的路径数量,设 $B[k]$ 为跨过 $x$ 的两点路径长度恰好为 $k$ 的路径的数量

则有 $B[k]=\sum_{j=0}^{k}A[j]A[k-j]$,是卷积的形式,可以 $FFT$ 优化

这样没有考虑两点在 $x$ 的同一儿子 $v$ 子树内的不合法情况,但是可以直接用同样的方法算出对于 $v$ 的 $B[]$,减一下就行了

这样总复杂度 $O(nlog_{n}^{2})$

只要熟悉点分治和 $FFT$ ,代码不难理解

#include<iostream>
#include<cstdio>
#include<algorithm>
#include<cstring>
#include<cmath>
using namespace std;
typedef long long ll;
typedef long double ldb;
inline int read()
{
    int x=0,f=1; char ch=getchar();
    while(ch<'0'||ch>'9') { if(ch=='-') f=-1; ch=getchar(); }
    while(ch>='0'&&ch<='9') { x=(x<<1)+(x<<3)+(ch^48); ch=getchar(); }
    return x*f;
}
const int N=4e5+7;
const ldb pi=acos(-1.0);
struct CP {
    ldb x,y;
    CP (ldb xx=0,ldb yy=0) { x=xx,y=yy; }
    inline CP operator + (const CP &tmp) const { return CP(x+tmp.x,y+tmp.y); }
    inline CP operator - (const CP &tmp) const { return CP(x-tmp.x,y-tmp.y); }
    inline CP operator * (const CP &tmp) const { return CP(x*tmp.x-y*tmp.y,x*tmp.y+y*tmp.x); }
}A[N];
int n,p[N];
void FFT(CP *A,int len,int type)
{
    for(int i=0;i<len;i++) if(i<p[i]) swap(A[i],A[p[i]]);
    for(int mid=1;mid<len;mid<<=1)
    {
        CP wn(cos(pi/mid),type*sin(pi/mid));
        for(int R=mid<<1,j=0;j<len;j+=R)
        {
            CP w(1,0);
            for(int k=0;k<mid;k++,w=w*wn)
            {
                CP x=A[j+k],y=w*A[j+mid+k];
                A[j+k]=x+y;
                A[j+mid+k]=x-y;
            }
        }
    }
}
int fir[N],from[N<<1],to[N<<1],cntt;
inline void add(int a,int b) { from[++cntt]=fir[a]; fir[a]=cntt; to[cntt]=b; }
int sz[N],mx[N],rt,tot;
bool vis[N];
void find_rt(int x,int fa)
{
    sz[x]=1; mx[x]=0;
    for(int i=fir[x];i;i=from[i])
    {
        int &v=to[i]; if(v==fa||vis[v]) continue;
        find_rt(v,x); sz[x]+=sz[v];
        mx[x]=max(mx[x],sz[v]);
    }
    mx[x]=max(mx[x],tot-sz[x]);
    if(mx[x]<mx[rt]) rt=x;
}
int st[N],Top;
void dfs(int x,int fa,int dis)
{
    st[++Top]=dis;
    for(int i=fir[x];i;i=from[i])
        if(to[i]!=fa&&!vis[to[i]]) dfs(to[i],x,dis+1);
}
ll ans[N];
void calc(int type)
{
    int mx=0,len=1,tot=0;
    for(int i=1;i<=Top;i++) mx=max(mx,st[i]);
    while(len<=2*mx) len<<=1,tot++;
    for(int i=0;i<=len;i++) A[i]=CP(0,0);
    for(int i=1;i<=Top;i++) A[st[i]].x++;
    for(int i=0;i<len;i++) p[i]=(p[i>>1]>>1)|((i&1)<<(tot-1));
    FFT(A,len,1);
    for(int i=0;i<=len;i++) A[i]=A[i]*A[i];
    FFT(A,len,-1);
    for(int i=0;i<=mx*2;i++) ans[i]+=1ll*type*ll(A[i].x/len+0.5);
}
void solve(int x)
{
    vis[x]=1; Top=0; dfs(x,0,0); calc(1);
    for(int i=fir[x];i;i=from[i])
    {
        int &v=to[i]; if(vis[v]) continue;
        Top=0; dfs(v,x,1); calc(-1);
        rt=0; tot=sz[v]; find_rt(v,x); solve(rt);
    }
}
ldb Ans=0;
int main()
{
    n=read(); int a,b;
    for(int i=1;i<n;i++)
    {
        a=read()+1,b=read()+1;
        add(a,b),add(b,a);
    }
    tot=n; mx[0]=2333333; find_rt(1,0); solve(rt);
    for(int i=0;i<n;i++) Ans+=(ldb)ans[i]/(i+1);
    printf("%.4Lf\n",Ans);
    return 0;
}

 

posted @ 2019-07-27 14:29  LLTYYC  阅读(292)  评论(0编辑  收藏  举报