loading

P8885 「JEOI-R1」子序列

我怎么这么不会做题?????

题意

给定一个字符串 \(s\),字符集为 0,1,?。有 \(q\) 次询问,每次询问给出 \(l,r\),求有多少种将 \([l,r]\) 内的 ? 替换成 0,1 的方案使得 有奇数个可空本质不同子序列 的 非空子串数 为奇数,模数 \(998244353\)

\(n\le 5\times10^4,q\le 3\times10^5\)

分析

关于本质不同子序列有一个经典的 dp 是设 \(f_{i,j}\) 表示前 \(i\) 个位置末尾为 \(j\) 的方案数,转移方程为 \(f_{i,a_i}\leftarrow \sum_j f_{i-1,j},f_{i,j}(j\neq a_i)=f_{i-1,j}\),我们只在乎它的奇偶性,所以只需要求出模 2 结果即可。考虑 dp 套 dp 并把这个 dp 当做内层 dp,设 \(g_{i,S,k}\) 表示前 \(i\) 个点,\(S\) 存储不同 \(f\) 状态下的奇偶性,\(k\) 为合法的非空子串个数的奇偶性,注意由于此时 \(f_0,f_1\) 中不可能同时为 \(1\),所以有效的状态只有三种,所以 \(g\) 总共只有 \(k=16\) 种状态,并且转移是线性关系,可以没那么简单的写出转移矩阵,朴素线段树 \(O(n+q\log nk^3)\),查询时使用向量乘矩阵是 \(O(nk^3+q\log nk^2)\)。线段树做法已经没有优化空间了,考虑猫树分治,复杂度 \(O(n\log nk^3+qk^2)\),仍然过不去,然而猫树分治预处理矩阵乘法的一方为转移矩阵本身,而转移矩阵的非零点数为 \(O(k)\),矩阵乘法可以优化到 \(O(k^2)\),所以最终复杂度 \(O((n\log n+q)k^2)\)

我还是太不会优化了。

点击查看代码
#include<iostream>
#include<cstdio>
#include<cstring>
#include<string>
#include<algorithm>
#include<cmath>
#include<map>
#include<unordered_map>
#include<vector>
#include<queue>
#include<stack>
#include<bitset>
#include<set>
#include<array>
#include<tuple>
#include<ctime>
#include<random>
#include<cassert>
#include<chrono>
#define x1 xx1
#define y1 yy1
#define IOS ios::sync_with_stdio(false)
#define ITIE cin.tie(0)
#define OTIE cout.tie(0)
#define PY puts("Yes")
#define PN puts("No")
#define PW puts("-1")
#define P0 puts("0")
#define P__ puts("")
#define PU puts("--------------------")
#define mp make_pair
#define fi first
#define se second
#define gc getchar
#define pc putchar
#define pb emplace_back
#define un using namespace
#define il inline
#define all(x) x.begin(),x.end()
#define mem(x,y) memset(x,y,sizeof x)
#define popc __builtin_popcountll
#define rep(a,b,c) for(int a=(b);a<=(c);++a)
#define per(a,b,c) for(int a=(b);a>=(c);--a)
#define reprange(a,b,c,d) for(int a=(b);a<=(c);a+=(d))
#define perrange(a,b,c,d) for(int a=(b);a>=(c);a-=(d))
#define graph(i,j,k,l) for(int i=k[j];i;i=l[i].nxt)
#define lowbit(x) ((x)&-(x))
#define lson(x) ((x)<<1)
#define rson(x) ((x)<<1|1)
//#define double long double
//#define int long long
//#define int __int128
using namespace std;
using i64=long long;
using u64=unsigned long long;
using pii=pair<int,int>;
template<typename T1,typename T2>inline bool ckmx(T1 &x,T2 y){return x>=y?0:(x=y,1);}
template<typename T1,typename T2>inline bool ckmn(T1 &x,T2 y){return x<=y?0:(x=y,1);}
inline auto rd(){
	int qwqx=0,qwqf=1;char ch=getchar();
	while(ch<'0'||ch>'9'){if(ch=='-')qwqf=-1;ch=getchar();}
	while(ch>='0'&&ch<='9'){qwqx=(qwqx<<1)+(qwqx<<3)+ch-48;ch=getchar();}return qwqx*qwqf;
}
template<typename T>inline void write(T qwqx,char ch='\n'){
	if(qwqx<0){qwqx=-qwqx;putchar('-');}
	int qwqy=0;static char qwqz[40];
	while(qwqx||!qwqy){qwqz[qwqy++]=qwqx%10+48;qwqx/=10;}
	while(qwqy--){putchar(qwqz[qwqy]);}if(ch)putchar(ch);
}
bool Mbg;
const int mod=998244353;
template<typename T1,typename T2>inline void adder(T1 &x,T2 y){x+=y,x=x>=mod?x-mod:x;}
template<typename T1,typename T2>inline void suber(T1 &x,T2 y){x-=y,x=x<0?x+mod:x;}
const int maxn=5e4+5,maxq=3e5+5,k=16,inf=0x3f3f3f3f;
const long long llinf=0x3f3f3f3f3f3f3f3f;
int n,Q,a[maxn];
char s[maxn];
struct Matrix{
	int mat[k][k];
	Matrix(){mem(mat,0);}
} Z[3],B[maxn];
struct Vector{
	int mat[k];
	Vector(){mem(mat,0);}
	il int getval(){
		int res=0;
		rep(i,0,k-1)if(i&1)adder(res,mat[i]);
		return res;
	}
} A;
Matrix operator+(Matrix x,Matrix y){
	rep(i,0,k-1)rep(j,0,k-1)adder(x.mat[i][j],y.mat[i][j]);
	return x;
}
Matrix operator*(Matrix x,Matrix y){
	Matrix res;
	rep(i,0,k-1)rep(l,0,k-1)if(x.mat[i][l])rep(j,0,k-1)adder(res.mat[i][j],1ll*x.mat[i][l]*y.mat[l][j]%mod);
	return res;
}
Matrix operator^(Matrix x,Matrix y){
	Matrix res;
	rep(l,0,k-1)rep(j,0,k-1)if(y.mat[l][j])rep(i,0,k-1)adder(res.mat[i][j],1ll*x.mat[i][l]*y.mat[l][j]%mod);
	return res;
}
Vector operator*(Vector x,Matrix y){
	Vector res;
	rep(i,0,k-1)rep(j,0,k-1)adder(res.mat[i],1ll*x.mat[j]*y.mat[j][i]%mod);
	return res;
}
int ans[maxq],L[maxq],R[maxq];
void solve(int l,int r,vector<int>q){
	if(q.empty())return;
	if(l==r){
		for(int i:q){
			A=Vector(),A.mat[0]=1;
			A=A*Z[a[l]],ans[i]=A.getval();
		}
		return;
	}
	int mid=(l+r)>>1;
	B[mid]=Z[a[mid]],B[mid+1]=Z[a[mid+1]];
	per(i,mid-1,l)B[i]=Z[a[i]]^B[i+1];
	rep(i,mid+2,r)B[i]=B[i-1]*Z[a[i]];
	vector<int>q1,q2;
	for(int i:q){
		if(R[i]<=mid)q1.pb(i);
		else if(L[i]>mid)q2.pb(i);
		else{
			A=Vector(),A.mat[0]=1;
			A=(A*B[L[i]])*B[R[i]],ans[i]=A.getval();
		}
	}
	solve(l,mid,q1),solve(mid+1,r,q2);
}
inline void solve_the_problem(){
	n=rd(),scanf("%s",s+1);
	rep(i,1,n)a[i]=(s[i]=='?'?2:s[i]-'0');
	rep(i,0,1){
		rep(a,0,1)rep(b,0,1)rep(c,0,1)rep(d,0,1){
			int S=(a<<3)|(b<<2)|(c<<1)|d;
			int e=a,f=b,g=c,h=d;
			if(!i)swap(e,g),e^=1;
			else swap(f,g),f^=1;
			h=(h+g)%2;
			int T=(e<<3)|(f<<2)|(g<<1)|h;
			Z[i].mat[S][T]=1;
		}
	}
	Z[2]=Z[0]+Z[1];
	Q=rd();
	rep(i,1,Q)L[i]=rd(),R[i]=rd();
	vector<int>ALL;
	rep(i,1,Q)ALL.pb(i);
	solve(1,n,ALL);
	rep(i,1,Q)write(ans[i]);
}
bool Med;
signed main(){
//	freopen(".in","r",stdin);freopen(".out","w",stdout);
	fprintf(stderr,"%.3lfMB\n",(&Mbg-&Med)/1048576.0);
	int _=1;
	while(_--)solve_the_problem();
}
/*
5
10001
1
1 5
*/

posted @ 2025-09-18 21:09  dcytrl  阅读(25)  评论(1)    收藏  举报