没写完 (unfinished draft)
code:
#include <bits/stdc++.h>
#define N 200004   // array capacity (vertex arrays, Fenwick trees, event list q; edge arrays use N<<1)
#define ll long long
#define mod 1000000007   // prime modulus for all products; calc() inverts via Fermat (qpow(x, mod-2))
#define setIO(s) freopen(s".in","r",stdin)   // redirect stdin to the file "<s>.in"
using namespace std;
struct Lowbit
{
ll M[N];
int C[N];
int lowbit(int t)
{
return t&(-t);
}
void add(int x,int v)
{
while(x<N) C[x]+=v, x+=lowbit(x);
}
void mul(int x,int v)
{
while(x<N) M[x]=(ll)M[x]*v%mod,x+=lowbit(x);
}
int qsum(int x)
{
int re=0;
while(x>0) re+=C[x],x-=lowbit(x);
return re;
}
int qmul(int x)
{
ll re=1ll;
while(x>0) re=re*M[x]%mod,x-=lowbit(x);
return re;
}
void clr(int x)
{
while(x<N) M[x]=1,x+=lowbit(x);
}
}addv,mulv;
ll ans=1ll;              // accumulated answer: product of deltas from calc(), modulo `mod`
int root,edges,sn,tot;   // centroid found by getroot, edge counter, size of the component being
                         // split, number of events currently stored in q
// Forward-star adjacency list: hd[u] is u's first edge id, nex[e] the next edge
// id of the same tail; to/col/val hold the head vertex, colour and value.
int hd[N],to[N<<1],nex[N<<1],col[N<<1],val[N<<1];
// Per-vertex scratch for the centroid search plus the removed-vertex flags.
// NOTE(review): with `using namespace std` under C++17, the global name `size`
// is ambiguous with std::size; renaming it would also require updating
// getroot/solve.
int size[N],mx[N],vis[N];
// Appends directed edge u -> v with value vv and colour c.
void addedge(int u,int v,int vv,int c)
{
    ++edges;
    to[edges]=v;
    val[edges]=vv;
    col[edges]=c;
    nex[edges]=hd[u];
    hd[u]=edges;
}
// One sweep event produced by dfs() and consumed by calc().
struct node
{
    int x, y;  // sweep coordinates
    int op;    // 0 = point inserted into the Fenwick trees, 1 = query
    int v;     // path product (modulo `mod`) attached to the event
    node(int px=0,int py=0,int pop=0,int pv=0) : x(px), y(py), op(pop), v(pv) {}
}q[N];
bool cmp(node a,node b)
{
return (a.x==b.x&&a.y==b.y)?(a.z>b.z):(a.x==b.x?a.y<b.y:a.x<b.x);
}
// Square-and-multiply modular exponentiation: returns x^y mod `mod`.
// y is expected to be non-negative.
inline int qpow(int x,int y)
{
    long long result=1;
    long long base=x;
    while(y!=0)
    {
        if(y&1) result=result*base%mod;
        base=base*base%mod;
        y>>=1;
    }
    return static_cast<int>(result);
}
// Centroid search over the component containing u (parent ff, visited vertices
// excluded). Fills size[] with subtree sizes and mx[] with the largest piece
// left if the vertex were removed (sn is the component size), and keeps in
// `root` the vertex with the smallest such piece seen so far.
void getroot(int u,int ff)
{
    size[u]=1;
    mx[u]=0;
    for(int e=hd[u];e!=0;e=nex[e])
    {
        const int w=to[e];
        if(w==ff||vis[w]) continue;
        getroot(w,u);
        size[u]+=size[w];
        if(size[w]>mx[u]) mx[u]=size[w];
    }
    const int rest=sn-size[u];   // the part of the component "above" u
    if(rest>mx[u]) mx[u]=rest;
    if(mx[u]<mx[root]) root=u;
}
// Walks every vertex reachable from u (not going back to ff, not entering
// visited vertices) and appends two events per vertex to q[1..tot]:
//   point (op 0): (2y - x, y - 2x)
//   query (op 1): (x - 2y, 2x - y)   -- the exact negation of the point
// x / y are counters incremented for each traversed edge of colour 0 / 1;
// v is the product (mod `mod`) of edge values along the path.
void dfs(int u,int ff,int x,int y,int v)
{
    q[++tot]=node(2*y-x,y-2*x,0,v);
    q[++tot]=node(x-2*y,2*x-y,1,v);
    for(int i=hd[u];i;i=nex[i])
    {
        // Named `w`, not `v`: shadowing the path product here previously made
        // children receive (neighbour id * edge value) as their product.
        int w=to[i];
        if(w==ff||vis[w]) continue;
        dfs(w,u,x+(col[i]==0),y+(col[i]==1),static_cast<int>(1ll*v*val[i]%mod));
    }
}
// Collects every vertex reachable from u (through unvisited vertices) with
// dfs(), sweeps the events in cmp order, and for each query multiplies into
// `ans` the combined value of all points processed before it whose y is <= the
// query's y:
//     delta = (product of those points' v) * (query v)^(number of such points)
// flag == -1 multiplies by the modular inverse of delta instead (solve() uses
// this to cancel pairs that lie inside a single child subtree).
// pre is the colour of the edge leading into u (-1 when u is the centroid, so
// both counters start at 0); v is the starting path product.
// NOTE(review): points and queries are exact negations of each other (see
// dfs), so the match test is essentially symmetric in the two endpoints and a
// pair of distinct vertices may be matched from both sides -- confirm intended.
void calc(int u,int flag,int pre,int v)
{
    tot=0;
    dfs(u,0,pre==0,pre==1,v);
    sort(q+1,q+1+tot,cmp);

    // Event y-coordinates are differences such as y - 2x, so they are often 0
    // or negative (the centroid's own point has y == 0). Fenwick indices must
    // be >= 1: lowbit(0) == 0 makes add/mul/clr loop forever, and a negative
    // index writes out of bounds. Compress the inserted y values to 1..m,
    // which keeps the "point y <= query y" prefix semantics unchanged.
    std::vector<int> ys;
    ys.reserve(tot);
    for(int i=1;i<=tot;++i)
        if(q[i].op==0) ys.push_back(q[i].y);
    std::sort(ys.begin(),ys.end());
    ys.erase(std::unique(ys.begin(),ys.end()),ys.end());

    for(int i=1;i<=tot;++i)
    {
        if(q[i].op==0)
        {
            // 1-based rank of this point's y among the distinct point ys
            int pos=static_cast<int>(std::lower_bound(ys.begin(),ys.end(),q[i].y)-ys.begin())+1;
            addv.add(pos,1);
            mulv.mul(pos,q[i].v);
        }
        else
        {
            // number of distinct point ys <= the query's y (0 = empty prefix)
            int pos=static_cast<int>(std::upper_bound(ys.begin(),ys.end(),q[i].y)-ys.begin());
            int a1=mulv.qmul(pos);
            int a2=addv.qsum(pos);
            int delta=1ll*a1*qpow(q[i].v,a2)%mod;
            if(flag==-1) delta=qpow(delta,mod-2);   // Fermat inverse; mod is prime
            ans=1ll*ans*delta%mod;
        }
    }
    // Undo this call's insertions so both trees are clean for the next calc().
    for(int i=1;i<=tot;++i)
    {
        if(q[i].op==0)
        {
            int pos=static_cast<int>(std::lower_bound(ys.begin(),ys.end(),q[i].y)-ys.begin())+1;
            addv.add(pos,-1);
            mulv.clr(pos);
        }
    }
}
// Centroid-decomposition driver for centroid u: accounts for all pairs seen
// from u, cancels (flag -1) the pairs re-collected from each child with the
// connecting edge pre-counted, marks u removed, then recurses into the centroid
// of each remaining component.
void solve(int u)
{
    calc(u,1,-1,1);
    vis[u]=1;
    for(int e=hd[u];e;e=nex[e])
    {
        const int child=to[e];
        if(!vis[child]) calc(child,-1,col[e],val[e]);
    }
    for(int e=hd[u];e;e=nex[e])
    {
        const int child=to[e];
        if(vis[child]) continue;
        sn=size[child];
        root=0;
        getroot(child,u);
        solve(root);
    }
}
int main()
{
setIO("input");
memset(mulv.M,1,sizeof(mulv.M));
int i,j,n;
scanf("%d",&n);
for(i=1;i<n;++i)
{
int x,y,z,c;
scanf("%d%d%d%d",&x,&y,&z,&c);
add(x,y,z,c),add(y,x,z,c);
}
sn=mx[0]=n,getroot(1,0),solve(root);
printf("%lld\n",ans);
return 0;
}

浙公网安备 33010602011771号