LOJ 3312「ZJOI2020」传统艺能

每个点显然是独立的,我们对于每个点分开考虑 \(k\) 次操作后有标记的概率。

\(f_i\) 表示 \(i\) 次操作后该点存在标记的方案数,\(g_i\) 表示 \(i\) 次操作后该点到根路径上的点均不存在标记的方案数。

\(t=(\frac {n(n+1)}{2})^{i-1}\)。将点分类:

  • 终点。\(f_i+=t\)。该点不论前面如何操作,这次操作后必定有标记。
  • 途经点(不含终点)。\(g_i+=t\)。该点到根的路径上的标记全部被下放,操作后路径上不存在标记。
  • 终点子树(不含终点)。\(f_i+=f_{i-1}\)。该操作不会影响该点的标记情况,但是操作后到根的路径上一定存在标记。
  • 途经点儿子。\(f_i+=t-g_{i-1},g_i+=g_{i-1}\)。如果在操作之前该点到根的路径上有标记,则操作后一点会被下放到这个点上。
  • 途经点儿子子树。\(f_i+=f_{i-1},g_i+=g_{i-1}\)。这种操作对该点状态没有什么影响。

接下来只需要计算这五类点的数量即可,设它们分别为 \(a,b,c,d,e\)

发现 \(a\)\(d\) 比较好算。剩下的都可以通过 \(a\)\(d\) 推出。

  • 计算 \(a\) 只需要考虑包含该点并且不包含它的父亲即可。

  • 计算 \(d\) 只需要考虑经过它的兄弟,并且两个端点在该结点的同侧即可。

  • \(b[fa[x]]=d[x]+a[x]+b[x]\)。如果经过左儿子,那么左儿子就是终点或者途经点。如果不经过左儿子,那么左儿子就是途经点儿子。

  • \(c[x]=a[fa[x]]+c[fa[x]]\)。只需要考虑 \(fa[x]\) 是不是终点即可。

  • \(e[x]=d[fa[x]]+e[fa[x]]\)。只需要考虑 \(fa[x]\) 是不是途经点儿子即可。

这样 \(f_n\) 就可以用矩阵快速幂计算了。时间复杂度 \(O(n\log k)\)

代码

#include <bits/stdc++.h>
using namespace std;
#define N 400010
#define Re register
#define Mod 998244353
inline int read() {
    int x=0;
    char ch=getchar();
    while (!isdigit(ch)) ch=getchar();
    while (isdigit(ch)) x=x*10+ch-'0',ch=getchar();
    return x;
}
struct Matrix {
    int a[3][3];
    inline Matrix() {memset(a,0,sizeof(a));}
    inline friend Matrix operator * (const Matrix A,const Matrix B) {
        Matrix C;
        for (Re int i=0;i<3;i++)
            for (Re int j=0;j<3;j++)
                for (Re int k=0;k<3;k++)
                    (C.a[i][j]+=1LL*A.a[i][k]*B.a[k][j]%Mod)%=Mod;
        return C;
    }
}Bas,Tr;
inline Matrix Pow(Matrix A,int b) {
    Matrix res;
    res.a[0][0]=res.a[1][1]=res.a[2][2]=1;
    for (;b;b>>=1,A=A*A) if (b&1) res=res*A;
    return res;
}
inline int Pow(int a,int b,int p=Mod) {
    int res=1;
    for (;b;b>>=1,a=1LL*a*a%p)
        if (b&1) res=1LL*res*a%p;
    return res;
}
int cnt=0;
int L[N],R[N],ls[N],rs[N],fa[N],len[N],bro[N];
inline int build(int l,int r) {
    int k=++cnt;
    L[k]=l,R[k]=r,len[k]=r-l+1;
    if (l==r) return k;
    int mid=read();
    ls[k]=build(l,mid);
    rs[k]=build(mid+1,r);
    fa[ls[k]]=fa[rs[k]]=k;
    bro[ls[k]]=rs[k],bro[rs[k]]=ls[k];
    return k;
}
int lp[N],rp[N],A[N],B[N];
int c1[N],c2[N],c3[N],c4[N],c5[N];
inline void dfs1(int x,int n) {
    if (!x) return;
    if (x==ls[fa[x]]) (c1[x]+=1LL*len[bro[x]]*L[x]%Mod)%=Mod,c4[x]=1LL*(n+n-L[bro[x]]-R[bro[x]]+2)*len[bro[x]]/2%Mod;
    if (x==rs[fa[x]]) (c1[x]+=1LL*len[bro[x]]*(n-R[x]+1)%Mod)%=Mod,c4[x]=1LL*(L[bro[x]]+R[bro[x]])*len[bro[x]]/2%Mod;
    dfs1(ls[x],n),dfs1(rs[x],n);
}
inline void dfs3(int x) {
    if (!x) return;
    c3[x]=(c1[fa[x]]+c3[fa[x]])%Mod;
    c5[x]=(c5[fa[x]]+c4[fa[x]])%Mod;
    dfs3(ls[x]),dfs3(rs[x]);
    c2[fa[x]]=(1LL*c4[x]+c1[x]+c2[x])%Mod;
}
int main() {
    int n=read(),k=read();
    build(1,n),c1[1]=1,dfs1(1,n);
    Bas.a[0][1]=Bas.a[0][2]=1;
    A[0]=B[0]=c2[0]=0,dfs3(1);
    int All=1LL*n*(n+1)/2%Mod,ans=0;
    for (int i=1;i<=cnt;i++) {
        Tr.a[0][0]=(c3[i]+c5[i])%Mod;
        Tr.a[1][0]=(Mod-c4[i])%Mod;
        Tr.a[1][1]=(c4[i]+c5[i])%Mod;
        Tr.a[2][0]=(c1[i]+c4[i])%Mod;
        Tr.a[2][1]=c2[i];
        Tr.a[2][2]=All;
        (ans+=(Bas*Pow(Tr,k)).a[0][0])%=Mod;
    }
    printf("%d\n",1LL*ans*Pow(Pow(All,k),Mod-2)%Mod);
    return 0;
}
posted @ 2020-10-09 07:20  bo1949  阅读(138)  评论(0)    收藏  举报