LOJ#575. 「LibreOJ NOI Round #2」不等关系 容斥+分治NTT
容斥+分治NTT.
令 $dp[i]$ 表示以 $i$ 结尾的方案数.
如果只有小于号的话 $dp[i]$ 是非常好求的:$\frac{n!}{\prod a_{i}}$ 即总阶乘除以每一个小于号连续段.
有大于号的时候考虑容斥:
遇到第一个大于号的时候先不考虑当前位置关系,方案数就是 $dp[j] \times \binom{i}{i-j}$.
那么我们多加了当前位置是小于号的情况,需要在下一次减掉.
遇到第二个大于号的时候也不考虑当前位置关系,减掉 $dp[j] \times \binom{i}{i-j}$,这时将上面多加的减掉了,但是又多减了两个位置都是小于号的方案数.
所以我们就得到了一个容斥式子:$dp[i]=\sum_{j=0}^{i-1} [s_{j}='>'] (-1)^{c[i-1]-c[j]}dp[j] \times \binom{i}{i-j}$
这个式子可以用分治 NTT 优化到 $O(n \log^2 n)$.
code:
#include <vector>
#include <cstdio>
#include <cstring>
#include <algorithm>
#define N 100007
#define ll long long
#define mod 998244353
#define setIO(s) freopen(s".in","r",stdin)
using namespace std;
char str[N];
int A[N<<2],B[N<<2],bu[N];
int c[N],fac[N],inv[N],dp[N],g[N],n;
int qpow(int x,int y) {
int tmp=1;
for(;y;y>>=1,x=(ll)x*x%mod) {
if(y&1) tmp=(ll)tmp*x%mod;
}
return tmp;
}
int get_inv(int x) {
return qpow(x,mod-2);
}
void NTT(int *a,int len,int op) {
for(int i=0,k=0;i<len;++i) {
if(i>k) swap(a[i],a[k]);
for(int j=len>>1;(k^=j)<j;j>>=1);
}
for(int l=1;l<len;l<<=1) {
int wn=qpow(3,(mod-1)/(l<<1));
if(op==-1) {
wn=get_inv(wn);
}
for(int i=0;i<len;i+=l<<1) {
int w=1,x,y;
for(int j=0;j<l;++j) {
x=a[i+j],y=(ll)w*a[i+j+l]%mod;
a[i+j]=(ll)(x+y)%mod;
a[i+j+l]=(ll)(x-y+mod)%mod;
w=(ll)w*wn%mod;
}
}
}
if(op==-1) {
int iv=get_inv(len);
for(int i=0;i<len;++i) {
a[i]=(ll)a[i]*iv%mod;
}
}
}
void init() {
fac[0]=inv[1]=1;
for(int i=1;i<N;++i) {
fac[i]=(ll)fac[i-1]*i%mod;
}
for(int i=2;i<N;++i) {
inv[i]=(ll)(mod-mod/i)*inv[mod%i]%mod;
}
inv[0]=1;
for(int i=1;i<N;++i) inv[i]=(ll)inv[i-1]*inv[i]%mod;
}
void solve(int l,int r) {
if(l==r) {
return;
}
int mid=(l+r)>>1,s1=0,s2=0,lim;
solve(l,mid);
for(int i=l;i<=mid;++i) {
if(str[i]=='<') A[s1++]=0;
else A[s1++]=(ll)dp[i]*bu[i+1]%mod;
}
for(int i=0;i<=r-l;++i) {
B[s2++]=g[i];
}
for(lim=1;lim<(s1+s2);lim<<=1);
for(int i=s1;i<lim;++i) A[i]=0;
for(int i=s2;i<lim;++i) B[i]=0;
NTT(A,lim,1),NTT(B,lim,1);
for(int i=0;i<lim;++i) {
A[i]=(ll)A[i]*B[i]%mod;
}
NTT(A,lim,-1);
for(int i=mid+1;i<=r;++i) {
(dp[i]+=(ll)bu[i]*A[i-l]%mod)%=mod;
}
for(int i=0;i<lim;++i) A[i]=B[i]=0;
solve(mid+1,r);
}
int main() {
// setIO("input");
init();
scanf("%s",str+1);
n=strlen(str+1)+1;
for(int i=1;i<n;++i) {
c[i]=c[i-1]+(str[i]=='>');
}
for(int i=1;i<=n;++i) {
if(c[i-1]&1) bu[i]=mod-1;
else bu[i]=1;
}
dp[0]=1;
for(int i=1;i<=n;++i)
g[i]=inv[i];
solve(0,n);
printf("%d\n",(ll)dp[n]*fac[n]%mod);
return 0;
}

浙公网安备 33010602011771号