bzoj3697 采药人的路径

题目描述

题解:

首先我们应该注意,这道题问的是:

对于点对(a,b),存在点c在ab路径上,且a<->c和b<->c都是阴阳平衡的合法点对(a,b)有多少对。

因此这玩意是树链统计。

阴阳平衡就是$1+(-1)=0$;

用点分治搞一搞。

仔细看一看,你很快发现如果a->b和a->b->c相等的话,b<->c一定是阴阳平衡的(废话)。

所以我们将状态分为两种,路径上没有阴阳平衡的,还有路径上没有阴阳平衡的。

所以代码:

#include<cstdio>
#include<cstring>
#include<algorithm>
using namespace std;
#define N 100050
#define ll long long
inline int rd()
{
    int f=1,c=0;char ch=getchar();
    while(ch<'0'||ch>'9'){if(ch=='-')f=-1;ch=getchar();}
    while(ch>='0'&&ch<='9'){c=10*c+ch-'0';ch=getchar();}
    return f*c;
}
int n,hed[N],cnt;
struct EG
{
    int to,nxt,v;
}e[2*N];
void ae(int f,int t,int v)
{
    e[++cnt].to = t;
    e[cnt].nxt = hed[f];
    e[cnt].v = v;
    hed[f] = cnt;
}
int w[N],rt,sum,mrk[N];
int siz[N];
ll ans;
void get_rt(int u,int fa)
{
    w[u] = 0,siz[u] = 1;
    for(int j=hed[u];j;j=e[j].nxt)
    {
        int to = e[j].to;
        if(to==fa||mrk[to])continue;
        get_rt(to,u);
        siz[u] += siz[to];
        if(siz[to]>w[u])w[u]=siz[to];
    }
    w[u] = max(w[u],sum-siz[u]);
    if(w[u]<w[rt])rt=u;
}
ll f[2*N][2],g[2*N][2];
int hs[2*N],max_dep;
void dfs(int u,int fa,int dep)
{
    max_dep = max(max_dep, (dep-N) * (dep<N?-1:1));
    if(hs[dep])f[dep][1]++;
    else f[dep][0]++;
    hs[dep]++;
    for(int j=hed[u];j;j=e[j].nxt)
    {
        int to = e[j].to;
        if(to==fa||mrk[to])continue;
        dfs(to,u,dep+e[j].v);
    }
    hs[dep]--;
}
void work(int u)
{
    mrk[u] = 1;g[N][0] = 1;int mxd = 0;
    for(int j=hed[u];j;j=e[j].nxt)
    {
        int to = e[j].to;
        if(mrk[to])continue;
        max_dep = 0;
        dfs(to,u,N+e[j].v);
        mxd = max(max_dep,mxd);
        ans+=f[N][0]*(g[N][0]-1);
        for(int i=-max_dep;i<=max_dep;i++)
            ans+= f[N+i][0]*g[N-i][1]+f[N+i][1]*g[N-i][0]+f[N+i][1]*g[N-i][1];
        for(int i=N-max_dep;i<=N+max_dep;i++)
        {
            g[i][0]+=f[i][0];
            g[i][1]+=f[i][1];
            f[i][0]=f[i][1]=0;
        }
    }
    for(int i=N-mxd;i<=N+mxd;i++)g[i][0]=g[i][1]=0;
    int sum0 = sum;
    for(int j=hed[u];j;j=e[j].nxt)
    {
        int to = e[j].to;
        if(mrk[to])continue;
        rt = 0,sum = (siz[to]>siz[u]?sum0-siz[u]:siz[to]);
        get_rt(to,0);
        work(rt);
    }
}
int main()
{
    n = rd();
    for(int f,t,v,i=1;i<n;i++)
    {
        f = rd(),t = rd(),v = rd();
        if(!v)v=-1;
        ae(f,t,v),ae(t,f,v);
    }
    w[0]=0x3f3f3f3f;
    rt = 0,sum = n;
    get_rt(1,0);
    work(rt);
    printf("%lld\n",ans);
    return 0;
}

 

posted @ 2018-12-28 13:28  LiGuanlin  阅读(126)  评论(0编辑  收藏  举报