dtoj#4317. 随机(random)

题目描述:

有一棵$N$个点的树和$M$个操作$x,y,S$,表示给$x$到$y$的链上的节点都加入一个数字字符串$S$。

所有操作都结束后需要对每个点进行一次询问。首先在该节点的所有字符串形成的`trie`上随机选择一个点(此时为步数$0$),这里定义`trie`的根节点为空串,然后每次随机一个与当前点相邻的点移动过去且步数$+1$,如果移动到了某个字符串对应的节点则结束。询问每个点对应`trie`的期望步数。

输出模$998244353$意义下的值。假如答案为$\frac{x}{y}(gcd(x,y)=1)$,那么需要输出$r(0≤r<998244353)$满足$x \equiv y×r(mod~998244353)$。

数据保证答案不会存在$y \equiv 0(mod~998244353)$的情况。

思路:

考虑对于一棵固定的 $trie​$ 树,一个点走到终止点的期望会是
$$
f[x]=\frac{1}{d[x]}\times((\sum_{son}f[to])+f[fa])+1
$$
那么如果我们对于一颗 $trie$ 树 $dfs$ 时,对于孩子节点的情况已经求得了,我们把 $f[fax]$ 看作未知数, $f[x]$ 可以表示成一个关于 $f[fax]$ 的一次函数。
$$
f[x]=A_x\times f[fax]+B_x
$$
考虑如何确定每一个值的 $A_x$ 和 $B_x$ ,访问到 $x$ 对于每一个孩子的情况已经确定,即:
$$
f[to]=A_{to}\times f[x]+B_{to}
$$
那么:
$$
f[x]=\frac{1}{d[x]}\times((\sum_{son}A_{to}f[x]+B_{to})+f[fax])+1
$$
整理一下得到:
$$
f[x]=\frac{1}{d[x]-\sum_{son}A_{to}}f[fax]+\frac{(\sum_{son}B_{to})+d[x]}{d[x]-\sum_{son}A_{to}}
$$
容易知道最后对于一棵 $trie​$ 树的答案是
$$
Ans=\frac{\sum_{i=1}^{cnt} f[x]}{cnt}
$$
( $cnt$ 表示 $trie$ 树的点数 )

所以要考虑维护整个子树的和

我们再用一样的方法表示出子树的和 $g[x]$ :
$$
g[x]=f[x]+\sum_{son}g[to]
$$
同理 $g[x]$ 可以表示成关于 $ f[fax] $ 的一次函数:
$$
g[to]=C_xf[fax]+D_x
$$

$$
g[x]=f[x]+\sum_{son}(C_{x}f[x]+D_{x})
$$

$$
g[x]=A_xf[fax]+Bx+\sum_{son}(C_{to}(A_xf[fax]+B_x)+D_{to})
$$

$$
g[x]=((1+\sum_{son}C_{to})A_{x})f[fax]+(\sum_{son}C_{to}+1)B_x+\sum_{son}D_{to}
$$

对于树上结点的删除与添加在 $trie$ 树上动态修改

以下代码:

#include<bits/stdc++.h>
#define il inline
#define pb push_back
#define LL long long
#define _(d) while(d(isdigit(ch=getchar())))
using namespace std;
const int N=3e5+5,M=2e6+5,p=998244353;
char s[M],c[M];
int n,head[N],ne[N<<1],to[N<<1],cnt,fa[N][21],d[N],res[N],be[N];
int sz[M],rt[N],A[M],B[M],C[M],D[M],ch[M][12],num[M],tag[M],m;
struct node{int x,c;};
vector<node> v[N];
il int read(){
   int x,f=1;char ch;
   _(!)ch=='-'?f=-1:f;x=ch^48;
   _()x=(x<<1)+(x<<3)+(ch^48);
   return f*x;
}
il int mu(int x,int y){
    return x+y>=p?x+y-p:x+y;
}
il int ksm(LL a,int y){
    LL b=1;
    while(y){
        if(y&1)b=b*a%p;
        a=a*a%p;y>>=1;
    }
    return b;
}
il void ins(int x,int y){
    ne[++cnt]=head[x];
    head[x]=cnt;to[cnt]=y;
}
il void dfs1(int x){
    for(int i=1;fa[x][i-1];i++)fa[x][i]=fa[fa[x][i-1]][i-1];
    for(int i=head[x];i;i=ne[i]){
        if(fa[x][0]==to[i])continue;
        fa[to[i]][0]=x;
        d[to[i]]=d[x]+1;dfs1(to[i]);
    }
}
il int Lca(int x,int y){
    if(d[x]<d[y])swap(x,y);
    for(int i=20;i>=0;i--)if(d[fa[x][i]]>=d[y])x=fa[x][i];
    if(x==y)return x;
    for(int i=20;i>=0;i--)if(fa[x][i]!=fa[y][i])x=fa[x][i],y=fa[y][i];
    return fa[x][0];
}
il void update(int x,int f=0){
    int Sa=0,Sb=0,Sc=0,Sd=0,d=0;
    A[x]=B[x]=C[x]=D[x]=0;d=f^1;
    if(!num[x]){sz[x]=0;return;}sz[x]=1;
    for(int i=0,to;i<=9;i++)if(sz[to=ch[x][i]]>0){
        d++;sz[x]+=sz[to];
        Sa=mu(Sa,A[to]);Sb=mu(Sb,B[to]);
        Sc=mu(Sc,C[to]);Sd=mu(Sd,D[to]);
    }
    if(tag[x]){D[x]=Sd;return;}
    A[x]=ksm(mu(d,p-Sa),p-2);B[x]=1ll*A[x]*mu(Sb,d)%p;
    C[x]=1ll*A[x]*mu(Sc,1)%p;D[x]=mu(1ll*B[x]*mu(Sc,1)%p,Sd);
}
il void add(int &x,int id,int l){
    if(!x)x=++cnt;num[x]++;
    if(l>=be[id+1]){tag[x]++;update(x);return;}
    add(ch[x][s[l]-'0'],id,l+1);update(x);
}
il void del(int x,int id,int l){
    num[x]--;
    if(l>=be[id+1]){tag[x]--;update(x);return;}
    del(ch[x][s[l]-'0'],id,l+1);update(x);
}
il void merge(int &x,int y){
    if(!num[x]||!num[y]){x=(num[x]?x:y);return;}
    num[x]+=num[y];tag[x]+=tag[y];
    for(int i=0;i<=9;i++)merge(ch[x][i],ch[y][i]);
    update(x);
}
il int query(int x){
    update(x,1);
    return 1ll*D[x]*ksm(sz[x],p-2)%p;
}
il void dfs(int x,int fa){
    int son=0;
    for(int i=head[x];i;i=ne[i]){
        if(fa^to[i]){
            dfs(to[i],x);son=to[i];
        }
    }
    if(son)rt[x]=rt[son];
    for(int i=head[x];i;i=ne[i])
        if(fa^to[i]&&to[i]^son)merge(rt[x],rt[to[i]]);
    for(int i=0;i<v[x].size();i++){
        node k=v[x][i];
        if(k.c>0)add(rt[x],k.x,be[k.x]);
        else del(rt[x],k.x,be[k.x]);
    }
    res[x]=query(rt[x]);
}
int main()
{
    n=read();
    for(int i=1;i<n;i++){
        int x=read(),y=read();
        ins(x,y);ins(y,x);
    }
    d[1]=1;dfs1(1);m=read();
    int now=1;
    for(int i=1;i<=m;i++){
        int x=read(),y=read();
        scanf(" %s",c+1);
        be[i]=now;int l=strlen(c+1);
        for(int i=1;i<=l;i++)s[now++]=c[i];
        int lca=Lca(x,y);
        if(d[x]>d[y])swap(x,y);
        if(x==lca)v[fa[x][0]].pb((node){i,-1}),v[y].pb((node){i,1});
        else v[x].pb((node){i,1}),v[y].pb((node){i,1}),v[lca].pb((node){i,-1}),v[fa[lca][0]].pb((node){i,-1});
    }
    be[m+1]=now;
    cnt=0;dfs(1,0);
    for(int i=1;i<=n;i++)printf("%d\n",res[i]);
    return 0;
} 
View Code

 

posted @ 2019-03-31 20:43  Jessiejzy  阅读(200)  评论(0编辑  收藏  举报