随机爬树题解
随机爬树题解
\(n^2\) 暴力:
思路:
- 求期望,即求所有点的权值乘上概率后的和,即:
\[ans=\sum_{u \in V}{P_u a_u}
\]
-
求每个点的概率 \(P_u\) :
- 由题,令走到父亲的概率为 \(P_f\),走到儿子 \(s\) 的概率则为 \(P_f \times \frac{w_s}{sum_f}\)(其中 \(sum_f\) 为 \(f\) 所有儿子的 \(w\) 之和)。
-
统计答案:
- 记 \(ans_u\) 表示 \(u\) 子树(不含 \(u\) 本身)的答案之和,最终答案为 \(ans_1+a_1\)。
- 暴力修改,跑 DFS 暴力求和即可。
代码:
//n^2暴力 60pts
#include<iostream>
#include<cstdio>
#include<cstring>
#include<vector>
using namespace std;
typedef long long ll;
const int N=1e5+5,Mod=998244353;
int n,q,fa[N];
ll sum[N],inv[N],w[N],a[N];
ll p[N],ans[N];
vector <int> e[N];
ll qpow(ll a,ll b){
ll res=1;
while(b){
if(b&1) (res*=a)%=Mod;
(a*=a)%=Mod;
b>>=1;
}
return res%Mod;
}
void dfs(int u){
ans[u]=0;
inv[u]=qpow(sum[u],Mod-2);
for(int i=0;i<e[u].size();i++){
int v=e[u][i];
p[v]=p[u]*w[v]%Mod*inv[u]%Mod;
dfs(v);
(ans[u]+=(ans[v]+p[v]*a[v]%Mod)%Mod)%=Mod;
}
}
int main(){
scanf("%d",&n);
for(int i=2;i<=n;i++){
scanf("%d",&fa[i]);
e[fa[i]].push_back(i);
}
for(int i=1;i<=n;i++){
scanf("%lld",&w[i]);
(sum[fa[i]]+=w[i])%=Mod;
}
for(int i=1;i<=n;i++){
scanf("%lld",&a[i]);
}
p[1]=1;
dfs(1);
printf("%lld\n",(ans[1]+a[1])%Mod);
scanf("%d",&q);
int u;
ll ww,aa;
for(int i=1;i<=q;i++){
scanf("%d",&u);
sum[fa[u]]=(sum[fa[u]]-w[u]+Mod)%Mod;
scanf("%lld%lld",&w[u],&a[u]);
sum[fa[u]]=(sum[fa[u]]+w[u])%Mod;
dfs(1);
printf("%lld\n",(ans[1]+a[1])%Mod);
}
return 0;
}
优化后正解:
思路:
-
每次修改 \(u\) 只对 \(f\) 的整棵子树产生影响,故用线段树维护子树和。
-
考虑有哪些影响:
-
\(w_u\) 修改为 \(ww\),使得 \(sum_f\) 发生改变,故整棵子树的概率都会变化:
-
对于 \(f\) 子树中每个节点 \(t\)(不含 \(f\)),原概率为 \(P_t=P_{fa_t}\times \frac{w_t}{sum_{fa_t}}\),修改后变为 \(P_t'=P_{fa_t}\times \frac{w_t}{sum_{fa_t}-w_u+ww}\)。
\[P_t'=P_t\times\frac{sum_{fa_t}}{sum_{fa_t}-w_u+ww} \]则有 \(\Delta P=\frac{sum_{fa_t}}{sum_{fa_t}-w_u+ww}\)。
-
对于点 \(u\),原概率为 \(P_u=P_f\times \frac{w_u}{sum_f}\),修改后变为 \(P_u'=P_f \times \frac{ww}{sum_f-w_u+ww}\)。
\[P_u'=P_u \times \frac{ww}{wu} \times \Delta P \]则有 \(\Delta w=\frac{ww}{wu}\)。
-
对于 \(u\) 子树中的每个点 \(t\)(不含 \(u\)),原概率为 \(P_t=P_{fa_t}\times \frac{w_t}{sum_{fa_t}}\),修改后变为 \(P_t'=P_{fa_t}\times \Delta w \times \frac{w_t}{sum_{fa_t}}\)。
-
-
\(a_u\) 修改为 \(aa\),只对 \(u\) 的贡献产生影响,原贡献为 \(ans=P_u \times a_u\),修改后变为 \(ans'=P_u \times aa\),即:
\[ans'=ans \times \frac{aa}{au} \]
-
-
综上,变化有:
- \(P_t'=P_t\times\frac{sum_{fa_t}}{sum_{fa_t}-w_u+ww}\);
- \(P_u'=P_u \times \frac{ww}{wu} \times \frac{sum_{fa_t}}{sum_{fa_t}-w_u+ww}\);
- \(P_t'=P_{fa_t}\times \frac{ww}{wu} \times \frac{w_t}{sum_{fa_t}}\);
- \(ans'=ans \times \frac{aa}{au}\)。
其中操作 \(1\)、\(2\) 与 \(2\)、\(3\) 可以合并。
代码:
#include<iostream>
#include<cstdio>
#include<vector>
using namespace std;
typedef long long ll;
const int N=1e5+5,Mod=998244353;
int n,q,fa[N];
ll sum[N],w[N],a[N],p[N],val[N];
int tid,dfn[N],siz[N];
vector <int> e[N];
struct Tree{
ll sum,tag;
}tr[N<<2];
ll qpow(ll a,ll b){
ll res=1;
while(b){
if(b&1) (res*=a)%=Mod;
(a*=a)%=Mod;
b>>=1;
}
return res%Mod;
}
ll inv(ll x){
return qpow(x,Mod-2);
}
void dfs(int u){
dfn[u]=++tid;
val[tid]=p[u]*a[u]%Mod;
siz[u]=1;
for(int i=0;i<e[u].size();i++){
int v=e[u][i];
p[v]=p[u]*w[v]%Mod*inv(sum[u])%Mod;
dfs(v);
siz[u]+=siz[v];
}
}
void update(int k){
tr[k].sum=(tr[k<<1].sum+tr[k<<1|1].sum)%Mod;
}
void pushdown(int k,int l,int r){
tr[k<<1].sum=tr[k<<1].sum*tr[k].tag%Mod;
tr[k<<1].tag=tr[k<<1].tag*tr[k].tag%Mod;
tr[k<<1|1].sum=tr[k<<1|1].sum*tr[k].tag%Mod;
tr[k<<1|1].tag=tr[k<<1|1].tag*tr[k].tag%Mod;
tr[k].tag=1;
}
void build(int k,int l,int r){
tr[k].sum=0;
tr[k].tag=1;
if(l==r){
tr[k].sum=val[l]%Mod;
return ;
}
int mid=(l+r)>>1;
build(k<<1,l,mid);
build(k<<1|1,mid+1,r);
update(k);
}
void modify(int k,int l,int r,int x,int y,int v){
if(x<=l&&r<=y){
tr[k].sum=tr[k].sum*v%Mod;
tr[k].tag=tr[k].tag*v%Mod;
return ;
}
pushdown(k,l,r);
int mid=(l+r)>>1;
if(x<=mid) modify(k<<1,l,mid,x,y,v);
if(mid<y) modify(k<<1|1,mid+1,r,x,y,v);
update(k);
}
int main(){
//freopen("climb2.in","r",stdin);
//freopen("climb.out","w",stdout);
scanf("%d",&n);
for(int i=2;i<=n;i++){
scanf("%d",&fa[i]);
e[fa[i]].push_back(i);
}
for(int i=1;i<=n;i++){
scanf("%lld",&w[i]);
sum[fa[i]]=(sum[fa[i]]+w[i])%Mod;
}
for(int i=1;i<=n;i++){
scanf("%lld",&a[i]);
}
p[1]=1;
dfs(1);
build(1,1,n);
printf("%lld\n",tr[1].sum%Mod);
scanf("%d",&q);
int u;
ll ww,aa;
for(int i=1;i<=q;i++){
scanf("%d%lld%lld",&u,&ww,&aa);
if(fa[u]){
/*
1.f子树(除f本身):Pt=Pf*(wt/sumf)* (1/(sumf-wu+ww))*sumf
2.u:pu=pf*(wu/sumf)*(ww/wu)* (1/(sumf-wu+ww))*sumf
3.u子树(除u本身):Pt=Pft* (ww/wu) *(wt/sumft)
*/
//1. 2. 修改f子树(除f本身)
modify(1,1,n,dfn[fa[u]]+1,dfn[fa[u]]+siz[fa[u]]-1,inv(((sum[fa[u]]-w[u]+ww)%Mod+Mod)%Mod)%Mod*sum[fa[u]]%Mod);
//2. 3.
modify(1,1,n,dfn[u],dfn[u]+siz[u]-1,ww*inv(w[u])%Mod);
sum[fa[u]]=((sum[fa[u]]-w[u]+ww)%Mod+Mod)%Mod;
}
w[u]=ww;
//修改au:Pu*au* (aa/au)
modify(1,1,n,dfn[u],dfn[u],aa*inv(a[u])%Mod);
a[u]=aa;
printf("%lld\n",tr[1].sum%Mod);
}
return 0;
}

浙公网安备 33010602011771号