BZOJ 3697 采药人的路径

Posted on 2017-03-11 18:20  ziliuziliu  阅读(161)  评论(0)  编辑  收藏  举报

点分治。把边权记为 +1/-1,对每个点记录它到分治中心的路径上是否存在与它 dis 相同的祖先(即两点之间 len=0 的点,可以作为休息站),然后直接统计答案。

1A了赞。

#include<iostream>
#include<cstdio>
#include<cstring>
#include<algorithm>
// Capacity of every per-vertex array (g, size, mx, dis, vis).
#define maxv 100050
// Capacity of the edge pool e; each input edge is stored twice (once per direction).
#define maxe 200050
// Sentinel stored in mx[0] by main so that root==0 loses every comparison in get_root.
#define inf 1000000007
// NOTE(review): with this using-directive in effect, the global name `size` below can
// clash with std::size when compiling as C++17 or later — confirm the target standard.
using namespace std;
// n          : vertex count read by main.
// x, y, z    : scratch variables main uses while reading one edge (endpoints, type).
// g[u]       : index in e of the first edge leaving u; 0 means "no more edges".
// nume       : index of the last used slot in e (starts at 1, so the first edge lands in slot 2).
// add        : offset (n+1, set in main) added to dis values so they can index the cnt arrays.
// cnt1[d+add]: incremented by dfs1's insert pass for every visited vertex whose dis is d.
// cnt2[d+add]: incremented by dfs1's insert pass only when that vertex also has an
//              ancestor with the same dis on the current recursion path.
// cnts[d+add]: number of vertices on the current dfs1 recursion path whose dis is d.
// sum        : size of the component handed to get_root.
// root       : best vertex found so far by get_root (smallest mx).
int n,x,y,z,g[maxv],nume=1,add,cnt1[maxv*30],cnt2[maxv*30],cnts[maxv*30],sum,root;
// size[v]: subtree size computed by get_root / dfs2.
// mx[v]  : size of the largest piece left when v is removed, computed by get_root.
// dis[v] : signed sum of edge weights from the vertex currently being solved down to v.
int size[maxv],mx[maxv],dis[maxv];
// Running answer, printed by main.
long long ans=0;
// One directed half of a tree edge, kept in a pooled singly linked adjacency list.
struct edge
{
    int v;   // vertex this edge points to
    int w;   // weight as stored by addedge: +1 or -1
    int nxt; // index of the next edge leaving the same vertex, 0 terminates the list
};
// Edge pool; slots are handed out sequentially by addedge.
edge e[maxe];
// vis[v] is set by solve() when v is processed; every traversal skips vertices with vis set.
bool vis[maxv];
void addedge(int u,int v,int w)
{
    e[++nume].v=v;e[nume].w=(w==1)?1:-1;
    e[nume].nxt=g[u];g[u]=nume;
}
// Walks the not-yet-visited component containing x (entered from fath), filling
// size[] and mx[] and leaving in `root` the vertex with the smallest mx seen so far.
// The caller must set `sum` to the component size and `root` to 0 beforehand.
void get_root(int x,int fath)
{
    size[x]=1;
    mx[x]=0;
    for (int id=g[x];id!=0;id=e[id].nxt)
    {
        const int child=e[id].v;
        if (child!=fath && !vis[child])
        {
            get_root(child,x);
            size[x]+=size[child];
            if (size[child]>mx[x]) mx[x]=size[child];
        }
    }
    // The part of the component that lies on the parent's side of x.
    const int upward=sum-size[x];
    if (upward>mx[x]) mx[x]=upward;
    if (mx[x]<mx[root]) root=x;
}
// Traverses one subtree hanging off the vertex being solved.
//   type > 0 : query pass — adds to `ans` using the cnt1/cnt2 tallies collected so far.
//   type <= 0: insert pass — records this subtree's vertices into cnt1/cnt2.
// dis[x] must already be set by the caller; dis of descendants is filled in here.
void dfs1(int x,int fath,int type)
{
    const int here=dis[x]+add;     // shifted index of dis[x]
    const int mirror=add-dis[x];   // shifted index of -dis[x]
    ++cnts[here];
    // True when some ancestor on the current recursion path has the same dis as x.
    const bool repeated=(cnts[here]>=2);
    if (type<=0)
    {
        ++cnt1[here];
        if (repeated) ++cnt2[here];
    }
    else if (dis[x]==0)
    {
        // Pairs with every recorded dis==0 vertex, plus one more when x itself
        // has a same-dis ancestor.
        ans+=static_cast<long long>(cnt1[add])+(repeated?1:0);
    }
    else if (repeated)
    {
        ans+=static_cast<long long>(cnt1[mirror]);
    }
    else
    {
        ans+=static_cast<long long>(cnt2[mirror]);
    }
    for (int id=g[x];id!=0;id=e[id].nxt)
    {
        const int child=e[id].v;
        if (child!=fath && !vis[child])
        {
            dis[child]=dis[x]+e[id].w;
            dfs1(child,x,type);
        }
    }
    // Leaving x: take it off the path tally.
    --cnts[here];
}
// Cleanup walk over one subtree: zeroes the cnt1/cnt2 entries addressed by each
// vertex's dis value and recomputes size[] for the subtree.
// dis[] is read as-is here; it is not recomputed for descendants.
void dfs2(int x,int fath)
{
    const int slot=dis[x]+add;
    cnt2[slot]=0;
    cnt1[slot]=0;
    size[x]=1;
    for (int id=g[x];id!=0;id=e[id].nxt)
    {
        const int child=e[id].v;
        if (child!=fath && !vis[child])
        {
            dfs2(child,x);
            size[x]+=size[child];
        }
    }
}
// Processes x as the root of its component, then recurses into each remaining piece.
void solve(int x)
{
    vis[x]=true;

    // Pass 1: for each neighbouring subtree, first query against the subtrees
    // already recorded, then record this one.
    for (int id=g[x];id!=0;id=e[id].nxt)
    {
        const int child=e[id].v;
        if (!vis[child])
        {
            dis[child]=e[id].w;
            dfs1(child,0,1);
            dis[child]=e[id].w;
            dfs1(child,0,-1);
        }
    }

    // Pass 2: wipe the tallies written above and refresh subtree sizes.
    for (int id=g[x];id!=0;id=e[id].nxt)
    {
        const int child=e[id].v;
        if (!vis[child])
        {
            dis[child]=e[id].w;
            dfs2(child,0);
        }
    }

    // Pass 3: pick a new root inside every remaining piece and recurse.
    for (int id=g[x];id!=0;id=e[id].nxt)
    {
        const int child=e[id].v;
        if (!vis[child])
        {
            sum=size[child];
            root=0;
            get_root(child,0);
            solve(root);
        }
    }
}
// Reads n and the n-1 edges, builds the adjacency lists, runs the decomposition
// from the first root and prints the accumulated answer.
int main()
{
    scanf("%d",&n);
    add=n+1;
    mx[0]=inf;   // makes root==0 lose every comparison in get_root

    for (int k=0;k<n-1;++k)
    {
        scanf("%d%d%d",&x,&y,&z);
        // Store both directions of the undirected edge.
        addedge(x,y,z);
        addedge(y,x,z);
    }

    sum=n;
    root=0;
    get_root(1,0);
    solve(root);

    printf("%lld\n",ans);
    return 0;
}