树上差分

树上差分

# include <bits/stdc++.h>
using namespace std;

const int MAXN=1e5+100;
struct Edge{
   int to,w,next;
}edge[MAXN<<1];
int head[MAXN];
int tot=0;
void add(int u,int v,int w)
{
   edge[tot].to=v;
   edge[tot].w=w;
   edge[tot].next=head[u];
   head[u]=tot++;
}
int fa[MAXN],f[MAXN],d[MAXN],v[MAXN];
vector<int> query[MAXN];
int T,n,m,t;
int Find(int x)
{
   if(fa[x]!=x) fa[x]=Find(fa[x]);
   return fa[x];
}
void add_query(int x,int y)
{
   query[x].push_back(y);
   query[y].push_back(x);
}
void Tarjan(int x)
{
   v[x]=1;
   for(int i=head[x];~i;i=edge[i].next){
       Edge e=edge[i];
       if(v[e.to]) continue;
       Tarjan(e.to);
       fa[e.to]=x;
  }
   for(int i=0;i<query[x].size();i++){
       int y=query[x][i];
       if(v[y]==2){
           int lca=Find(y);
           f[x]++,f[y]++,f[lca]-=2;
      }
  }
   v[x]=2;
}
void dfs(int x,int fa)
{
   for(int i=head[x];~i;i=edge[i].next){
       Edge e=edge[i];
       if(e.to==fa) continue;
       dfs(e.to,x);
       f[x]+=f[e.to];
  }
}
int main()
{
   int n,m; scanf("%d%d",&n,&m);
   for(int i=0;i<=n;i++){
       head[i]=-1; fa[i]=i; v[i]=0;
       query[i].clear();
  }
   for(int i=1;i<n;i++){
       int u,v; scanf("%d%d",&u,&v);
       add(u,v,1); add(v,u,1);
  }
   for(int i=1;i<=m;i++){
       int u,v; scanf("%d%d",&u,&v);
       if(u==v) continue;
       else{
           add_query(u,v);
      }
  }
   Tarjan(1);
   long long ans=0;
   dfs(1,0);
   for(int i=2;i<=n;i++){
        if(f[i]==0){
           ans+=m;
        }else if(f[i]==1){
           ans++;
        }
  }
   printf("%lld\n",ans);
   return 0;
}
/*
9 2
1 2
1 3
1 4
2 5
2 6
4 7
4 8
7 9
6 7
8 9
*/



posted @ 2022-02-26 23:35  fengzlj  阅读(60)  评论(0)    收藏  举报