Alpha

img

img

题目大意

给定初始状态全部为 \(0\) 的长度为 \(10^9\) 的区间,$ n $ 次操作,每次操作有 \(p\) 的概率给 \([l,r]\) 的数 \(+1\) 求最终序列中 \(K\) 的期望出现次数。

题解

首先将区间离散化,将离散化后的每一个小区间看做一个有权值的点,建线段树,然后考虑对每一个区间维护一个多项式,\(\sum a_ix^i\) 表示区间内 \(i\) 的出现次数,那么最终每一个位置的多项式为所有覆盖它的操作的 \(1-p+px\) 之积。每一步操作在线段树对应节点打标记,最终每一个位置的多项式就是其权值乘上所有其线段树父节点的多项式的结果。那么考虑每个线段树节点由若干个形如 \(px+(1-p)\) 的多项式相乘而得,由于这样的所有节点的一次多项式的数量之和是 \(n\log n\) 级别的,那么对于每一个线段树节点分别做一次分治 \(NTT\) 即可得到该节点的多项式,而多项式次数之和不超过 \(n\log n\) ,这部分时间复杂度 \(O(n\log ^3 n)\)

但是还有问题,如果将父节点的多项式乘给儿子节点复杂度会达到 \(O(n^2\log^2 n)\) 级别,但是由于只需要求所有叶子节点多项式某一项之和,可以直接反过来将儿子节点的多项式相加再乘以父节点的多项式。

这样每一个节点的多项式最多会被乘以它线段树祖先个数遍( \(O(\log n)\) 量级)所有多项式次数之和是 \(O(n\log n)\) 级,再乘 \(NTT\)\(\log\) 复杂度,那么最终就是 \(O(n\log ^3 n)\)

#include<bits/stdc++.h>
#define debug(x) cerr<<#x<<" = "<<x
#define sp <<"  "
#define el <<endl
#define LL long long
#define M 800020
#define inv3 332748118
#define mod 998244353
using namespace std;
namespace IO{
    const int BS=(1<<23)+5; int Top=0;
    char Buffer[BS],OT[BS],*OS=OT,*HD,*TL,SS[20]; const char *fin=OT+BS-1;
    char Getchar(){if(HD==TL){TL=(HD=Buffer)+fread(Buffer,1,BS,stdin);} return (HD==TL)?EOF:*HD++;}
    void flush(){fwrite(OT,1,OS-OT,stdout),OS=OT;}
    void Putchar(char c){*OS++ =c;if(OS==fin)flush();}
    void write(int x){
        if(!x){Putchar('0');return;} if(x<0) x=-x,Putchar('-');
        while(x) SS[++Top]=x%10,x/=10;
        while(Top) Putchar(SS[Top]+'0'),--Top;
    }
    int read(){
        int nm=0,fh=1; char cw=Getchar();
        for(;!isdigit(cw);cw=Getchar()) if(cw=='-') fh=-fh;
        for(;isdigit(cw);cw=Getchar()) nm=nm*10+(cw-'0');
        return nm*fh;
    }
} using namespace IO;
inline int mul(int x,int y){return (LL)x*(LL)y%mod;}
inline int add(int x,int y){x+=y; return (x>=mod)?(x-mod):x;}
inline int mns(int x,int y){x-=y; return (x<0)?(x+mod):x;}
int qpow(int x,int sq){
	int res=1;
	for(;sq;sq>>=1,x=mul(x,x)) if(sq&1) res=mul(res,x);
	return res;
}
int od[M],cod[M],tot,n,m,A[M],B[M],G[M];
int L[M],R[M],P[M],tl[M],tr[M];
int X[40][M],S[M*41],K,T[M*41];
void NTT(int *x,int len,int kd){
	for(int i=1;i<len;i++) if(i<od[i]) swap(x[i],x[od[i]]);
	for(int gg,tt=1;tt<len;tt<<=1){
		gg=qpow((kd>0)?3:inv3,(mod-1)/(tt<<1));
		for(int now=1,st=0;st<len;st+=(tt<<1),now=1){
			for(int pos=st;pos<st+tt;pos++,now=mul(gg,now)){
				int t1=x[pos],t2=mul(now,x[pos+tt]);
				x[pos]=add(t1,t2),x[pos+tt]=mns(t1,t2);
			}
		}
	}
	if(kd<0) for(int iv=qpow(len,mod-2),i=0;i<len;i++) x[i]=mul(x[i],iv);
}
vector<int> id[M];
void mdf(int x,int l,int r,int ls,int rs,int ID){
	if(ls<=tl[l]&&tr[r]<=rs){id[x].push_back(ID);return;}
	if(rs<tl[l]||tr[r]<ls) return; int mid=((l+r)>>1);
	mdf(x<<1,l,mid,ls,rs,ID),mdf(x<<1|1,mid+1,r,ls,rs,ID);
}
void tms(int *x,int *a,int la,int *b,int lb){
	for(int i=0;i<=la;i++) A[i]=a[i];
	for(int i=0;i<=lb;i++) B[i]=b[i];
	if((la<=25||lb<=25)){
		for(int i=0;i<=la+lb;i++) x[i]=0;
		for(int i=0;i<=la;i++) for(int j=0;j<=lb;j++)
			x[i+j]=add(x[i+j],mul(A[i],B[j])); return;
	} int len=1,nw=-1;
	while(len<=la+lb) len<<=1,nw++;
	for(int i=0;i<len;i++) od[i]=((od[i>>1]>>1)|((i&1)<<nw));
	for(int i=la+1;i<len;i++) A[i]=0; NTT(A,len,1);
	for(int i=lb+1;i<len;i++) B[i]=0; NTT(B,len,1);
	for(int i=0;i<len;i++) G[i]=mul(A[i],B[i]); NTT(G,len,-1);
	for(int i=0;i<=la+lb;i++) x[i]=G[i];
}
void mult(int *x,int l,int r,int k){
	if(l==r){x[1]=P[id[k][l]],x[0]=mns(1,P[id[k][l]]);return;}
	int mid=((l+r)>>1),ls,rs; ls=mid-l+1,rs=r-mid;
	mult(x,l,mid,k),mult(x+ls+2,mid+1,r,k),tms(x,x,ls,x+ls+2,rs);
}
int solve(int x,int l,int r,int kd){
	int mid=((l+r)>>1),ls,rs,len;
	if(l<r) ls=solve(x<<1,l,mid,kd),rs=solve(x<<1|1,mid+1,r,kd+1);
	if(id[x].size()) mult(T,0,id[x].size()-1,x); else T[0]=1;
	if(l==r){
		S[0]=(tr[r]-tl[l]+1)%mod;
		tms(X[kd],S,0,T,id[x].size());
		return id[x].size();
	} len=max(ls,rs);
	for(int i=ls+1;i<=rs;i++) X[kd][i]=0;
	for(int i=rs+1;i<=ls;i++) X[kd+1][i]=0;
	for(int i=0;i<=len;i++) X[kd][i]=add(X[kd][i],X[kd+1][i]);
	tms(X[kd],T,id[x].size(),X[kd],len);
	return id[x].size()+len;
}
int main(){
	n=read(),cod[++tot]=1,cod[++tot]=1000000001;
	for(int i=1;i<=n;i++){
		L[i]=read(),R[i]=read(),P[i]=read();
		cod[++tot]=L[i],cod[++tot]=R[i]+1;
	} sort(cod+1,cod+tot+1),K=read();
	for(int i=2;i<=tot;i++) if(cod[i]>cod[i-1]) tl[++m]=cod[i-1],tr[m]=cod[i]-1;
	for(int i=1;i<=n;i++) mdf(1,1,m,L[i],R[i],i);
	solve(1,1,m,0); printf("%d\n",X[0][K]); return 0;
}
posted @ 2018-12-10 19:32  OYJason  阅读(341)  评论(2编辑  收藏  举报