洛谷 P5326 [ZJOI2019]开关

洛谷 P5326 [ZJOI2019]开关

https://www.luogu.com.cn/problem/P5326

Snipaste_2020-06-30_18-18-35.png

Snipaste_2020-06-30_18-18-29.png

Snipaste_2020-06-30_18-18-44.png

Tutorial

https://www.luogu.com.cn/blog/xht37/solution-p5326

https://www.cnblogs.com/PinkRabbit/p/ZJOI2019D2T1.html

\(p_i=\dfrac {p_i}{\sum p_i}\)

\(f(x)\)表示在第\(k\)步到达合法状态的概率的生成函数,因为只关心第一次到达合法状态的情况,所以设\(g(x)\)表示走\(k\)步后回到原来的状态的概率,\(h(x)\)表示第\(k\)步第一次走到合法状态的概率,则有\(f(x)=g(x)h(x) \to h(x)=\dfrac{f(x)}{g(x)}\) .设\(h(x)=\sum a_k x^k\),则我们要求就是

\[\sum ka_k=h'(1)=\dfrac{f'(1)g(1)-f(1)g'(1)}{g^2(1)} \]

考虑如何求\(f(x)\).到达合法状态的条件为选择开关\(i\)的次数与\(s_i\)相等.则有

\[F_i(x)=\dfrac{e^{p_ix}+(-1)^{s_i}e^{-p_ix}}2 \]

发现\(f(x)\)是OGF,\(F_i(x)\)为EGF,为了相互转化,将\(\prod F_i(x)\)表示为\(\sum c_k(e^x)^k\)的形式,其中\(c_k\)可以用背包在\(O(n\sum p)\)的时间求得,最后得到

\[\begin{align} f(x)&=\sum_k ([x^k]k!\sum_i c_i(e^x)^i)x^k \\ &=\sum_k(k!\sum_i c_i [x^k](e^x)^i)x^k \\ &=\sum_k(k!\sum_ic_i\dfrac{i^k}{k!})x^k \\ &=\sum_k (\sum_i c_ii^k)x^k \\ &=\sum_ic_i\sum_{k}i^kx^k \\ &=\sum_i\dfrac{c_i}{1-ix} \end{align} \]

\(g(x)\)的处理类似,最后得到

\[g(x)=\sum_i\dfrac{d_i}{1-ix} \]

但是发现当\(i=1\)时会有\(1-x\)这一项,所以不能直接将\(x=1\)带入,考虑分子分母同乘\((1-x)\),得到新的\(f(x),g(x)\)

\[f(x)=\sum_i\dfrac{c_i(1-x)}{1-ix}=c_1+\sum_{i\not=1}\dfrac{c_i(1-x)}{1-ix} \]

所以此时\(f(1)=c_1\)

\[f'(x)=\sum_{i\not=1}\dfrac{c_i(ix-1)+ic_i(1-x)}{(1-ix)^2} \\ f'(1)=\sum_{i\not=1}\dfrac{c_i(i-1)}{(1-i)^2}=\sum_{i\not=1}\dfrac{c_i}{i-1} \]

\(g(1),g'(1)\)也类似计算,即可得到答案.

Code

#include <cstdio>
#include <cstring>
#include <iostream>
#define debug(...) fprintf(stderr,__VA_ARGS__)
#define inver(a) power(a,mod-2)
using namespace std;
inline char gc() {
//	return getchar();
	static char buf[100000],*l=buf,*r=buf;
	return l==r&&(r=(l=buf)+fread(buf,1,100000,stdin),l==r)?EOF:*l++;
}
template<class T> void rd(T &x) {
	x=0; int f=1,ch=gc();
	while(ch<'0'||ch>'9'){if(ch=='-')f=-1;ch=gc();}
	while(ch>='0'&&ch<='9'){x=x*10-'0'+ch;ch=gc();}
	x*=f;
}
typedef long long ll;
const int mod=998244353,r2=(mod+1)>>1;
const int maxn=100+5,maxP=1e5+50;
int n,P,s[maxn],p[maxn];
int c[2][maxP],d[2][maxP];
inline int sub(int x) {return x<0?x+mod:x;}
ll power(ll x,ll y) {
	ll re=1;
	while(y) {
		if(y&1) re=re*x%mod;
		x=x*x%mod;
		y>>=1;
	}
	return re;
}
inline int sqr(int x) {return (ll)x*x%mod;}
inline void upd(int *a,int *b,int v,int w) {
	for(int i=0;i<=(P<<1);++i) if(b[i]) {
		a[i+w]=(a[i+w]+(ll)v*b[i])%mod;
	}
}
int main() {
	rd(n);
	for(int i=1;i<=n;++i) rd(s[i]);
	for(int i=1;i<=n;++i) rd(p[i]),P+=p[i];
	int cur=0;
	c[cur][P]=d[cur][P]=1;
	for(int i=1;i<=n;++i) {
		cur^=1;
		memset(c[cur],0,sizeof(c[cur])),memset(d[cur],0,sizeof(d[cur]));
		upd(c[cur],c[cur^1],r2,p[i]),upd(c[cur],c[cur^1],(ll)r2*(s[i]==1?mod-1:1)%mod,-p[i]);
		upd(d[cur],d[cur^1],r2,p[i]),upd(d[cur],d[cur^1],r2,-p[i]);
	}
	int an=0,c1=c[cur][P<<1],d1=d[cur][P<<1],t=inver(P);
	for(int i=-P;i<P;++i) {
		an=(an+inver(sub((ll)i*t%mod-1))*sub((ll)c[cur][i+P]*d1%mod-(ll)c1*d[cur][i+P]%mod))%mod;
	}
	an=(ll)an*sqr(inver(d1))%mod;
	printf("%d\n",an);
	return 0;
}
posted @ 2020-06-30 19:04  LJZ_C  阅读(166)  评论(0编辑  收藏  举报