BZOJ3127:[USACO2013OPEN]Yin and Yang

浅谈树分治:https://www.cnblogs.com/AKMer/p/10014803.html

题目传送门:https://www.lydsy.com/JudgeOnline/problem.php?id=3127

这题很合我的胃口,要我统计树上的“太极”路径。

假设没有要求中间某个点到两端黑白也是相等的,那么我们直接用\(f[i]\)记录黑白两种边的差值为\(i\)的路径有多少条就行了。但是要求路径上需要存在中间某个点到两个端点的黑白边也是相同的。那么我们就用\(f[i][0]\)记录差值为\(i\),不存在这种点的路径条数,\(f[i][1]\)记录存在这种点的路径条数。然后根据根到当前点的路径上有无这种点,分别用\(f[-num][0]+f[-num][1]\)\(f[-num][1]\)更新即可,\(num\)表示根到当前点的路径上黑白边的差值。如果判断路径上是否有这种点呢?记一个全局数组\(sum\)\(sum[i]\)表示在当前递归的栈里,从根到这个点路径上黑白边差值为\(i\)的点有多少个。如果\(sum[num]\)不为零,那么从根到当前点的路径上必然会有一个这样的点。

时间复杂度:\(O(nlogn)\)

空间复杂度:\(O(n)\)

点分治版代码如下:

#include <cstdio>
#include <algorithm>
using namespace std;
typedef long long ll;
typedef pair<int,int> pii;
#define fr first
#define sc second
 
const int maxn=1e5+5;
 
ll ans;
bool vis[maxn];
int n,tot,mx,rt,N;
int sum[maxn<<1],f[maxn<<1][2],siz[maxn];
int now[maxn],pre[maxn*2],son[maxn*2],val[maxn*2];
 
int read() {
    int x=0,f=1;char ch=getchar();
    for(;ch<'0'||ch>'9';ch=getchar())if(ch=='-')f=-1;
    for(;ch>='0'&&ch<='9';ch=getchar())x=x*10+ch-'0';
    return x*f;
}
 
void add(int a,int b,int c) {
    pre[++tot]=now[a];
    now[a]=tot,son[tot]=b,val[tot]=c;
}
 
struct rubbish {
    int top;
    pii sta[maxn];
    bool bo[maxn<<1][2];
 
    void clear() {
        while(top) {
            int num1=sta[top].fr,num2=sta[top].sc;
            bo[num1][num2]=0,f[num1][num2]=0,top--;
        }
    }
 
    void ins(int a,int b) {
        if(bo[a][b])return;
        bo[a][b]=1;
        sta[++top]=make_pair(a,b);
    }
}R;
 
void find_son(int fa,int u) {
    int res=0;siz[u]=1;
    for(int p=now[u],v=son[p];p;p=pre[p],v=son[p])
        if(!vis[v]&&v!=fa)find_son(u,v),siz[u]+=siz[v],res=max(res,siz[v]);
    res=max(res,N-siz[u]);
    if(res<mx)mx=res,rt=u;
}
 
void query(int fa,int u,int num) {
    if(sum[num+maxn]||num==0)ans+=f[-num+maxn][0]+f[-num+maxn][1];
    else ans+=f[-num+maxn][1];
    sum[num+maxn]++;
    for(int p=now[u],v=son[p];p;p=pre[p],v=son[p])
        if(!vis[v]&&v!=fa)query(u,v,num+val[p]);
    sum[num+maxn]--;
}
 
void solve(int fa,int u,int num) {
    f[num+maxn][sum[num+maxn]!=0]++;
    R.ins(num+maxn,sum[num+maxn]!=0);
    sum[num+maxn]++,siz[u]=1;
    for(int p=now[u],v=son[p];p;p=pre[p],v=son[p])
        if(!vis[v]&&v!=fa)solve(u,v,num+val[p]),siz[u]+=siz[v];
    sum[num+maxn]--;
}
 
void work(int u,int size) {
    N=size,mx=rt=n+1,find_son(0,u),u=rt,vis[u]=1;
    for(int p=now[u],v=son[p];p;p=pre[p],v=son[p])
        if(!vis[v]) {
            query(u,v,val[p]);
            solve(u,v,val[p]);
        }
    ans+=f[maxn][1];R.clear();
    for(int p=now[u],v=son[p];p;p=pre[p],v=son[p])
        if(!vis[v])work(v,siz[v]);
}
 
int main() {
    n=read();
    for(int i=1;i<n;i++) {
        int a=read(),b=read(),c=(read()<<1)-1;
        add(a,b,c),add(b,a,c);
    }work(1,n);printf("%lld\n",ans);
    return 0;
}

边分治版代码如下:

#include <cstdio>
#include <vector>
#include <cstring>
#include <algorithm>
using namespace std;
typedef long long ll;
typedef pair<int,int> pii;
#define fr first
#define sc second
 
const int maxn=2e5+5;
 
ll ans;
bool vis[maxn];
int n,tot,mx,id,N;
int siz[maxn],f[maxn<<1][2],sum[maxn<<1];
int now[maxn],pre[maxn*2],son[maxn*2],val[maxn*2];
 
vector<pii>to[maxn];
vector<pii>::iterator it;
 
int read() {
    int x=0,f=1;char ch=getchar();
    for(;ch<'0'||ch>'9';ch=getchar())if(ch=='-')f=-1;
    for(;ch>='0'&&ch<='9';ch=getchar())x=x*10+ch-'0';
    return x*f;
}
 
void add(int a,int b,int c) {
    pre[++tot]=now[a];
    now[a]=tot,son[tot]=b,val[tot]=c;
}
 
void find_son(int fa,int u) {
    for(int p=now[u],v=son[p];p;p=pre[p],v=son[p])
        if(v!=fa)to[u].push_back(make_pair(v,val[p])),find_son(u,v);
}
 
void rebuild() {
    tot=1;memset(now,0,sizeof(now));
    for(int i=1;i<=n;i++) {
        int size=to[i].size();
        if(size<=2) {
            for(it=to[i].begin();it!=to[i].end();it++) {
                pii tmp=*it;
                add(i,tmp.fr,tmp.sc),add(tmp.fr,i,tmp.sc);
            }
        }
        else {
            pii u1=make_pair(++n,0),u2;
            if(size==3)u2=to[i].front();
            else u2=make_pair(++n,0);
            add(i,u1.fr,u1.sc),add(u1.fr,i,u1.sc);
            add(i,u2.fr,u2.sc),add(u2.fr,i,u2.sc);
            if(size==3) {
                for(int j=1;j<3;j++)
                    to[n].push_back(to[i].back()),to[i].pop_back();
            }
            else {
                int p=0;
                for(it=to[i].begin();it!=to[i].end();it++) {
                    if(!p)to[n-1].push_back(*it);
                    else to[n].push_back(*it);p^=1;
                }
            }
        }
    }
}
 
struct rubbish {
    int top;
    pii sta[maxn];
    bool bo[maxn<<1][2];
 
    void clear() {
        while(top) {
            int num1=sta[top].fr,num2=sta[top].sc;
            bo[num1][num2]=0,f[num1][num2]=0,top--;
        }
    }
 
    void ins(int a,int b) {
        if(bo[a][b])return;
        bo[a][b]=1,sta[++top]=make_pair(a,b);
    }
}R;
 
void find_edge(int fa,int u) {
    siz[u]=1;
    for(int p=now[u],v=son[p];p;p=pre[p],v=son[p])
        if(!vis[p>>1]&&v!=fa) {
            find_edge(u,v),siz[u]+=siz[v];
            if(abs(N-2*siz[v])<mx)
                mx=abs(N-2*siz[v]),id=p>>1;
        }
}
 
void solve(int fa,int u,int num,bool bo) {
    if(bo) {
        f[num+maxn][(sum[num+maxn]!=0)]++;
        R.ins(num+maxn,sum[num+maxn]!=0),sum[num+maxn]++;
    }siz[u]=1;
    for(int p=now[u],v=son[p];p;p=pre[p],v=son[p])
        if(!vis[p>>1]&&v!=fa)solve(u,v,num+val[p],val[p]!=0),siz[u]+=siz[v];
    if(bo)sum[num+maxn]--;
}
 
void query(int fa,int u,int num,bool bo) {
    if(bo) {
        if(!sum[num+maxn])ans+=f[-num+maxn][1];
        else {
            ans+=f[-num+maxn][0]+f[-num+maxn][1];
            if(val[id<<1]==0&&num==0)ans--;//这种情况就是,在u2所在的联通块里找到了一条满足条件的路径,但是会在这里被f[-num+maxn][0],也就是0号点到u1的路径匹配上算一次,到时候去处理u2所在的联通块会被重复计算,所以就减掉了。
        }
        sum[num+maxn]++;
    }siz[u]=1;
    for(int p=now[u],v=son[p];p;p=pre[p],v=son[p])
        if(!vis[p>>1]&&v!=fa)query(u,v,num+val[p],val[p]!=0),siz[u]+=siz[v];
    if(bo)sum[num+maxn]--;
}
 
void work(int u,int size) {
    if(size<2)return;
    N=size,mx=id=n+1,find_edge(0,u),vis[id]=1;
    int u1=son[id<<1],u2=son[id<<1|1];
    solve(0,u1,0,1),query(0,u2,val[id<<1],(val[id<<1]!=0));
    R.clear(),work(u1,siz[u1]),work(u2,siz[u2]);
}
 
int main() {
    n=read();
    for(int i=1;i<n;i++) {
        int a=read(),b=read(),c=(read()<<1)-1;
        add(a,b,c),add(b,a,c);
    }find_son(0,1),rebuild();
    work(1,n),printf("%lld\n",ans);
    return 0;
}
posted @ 2018-12-14 15:08  AKMer  阅读(212)  评论(0编辑  收藏  举报