CodeForces 1327F AND Segments
题意
给三个整数 \(n,k,m\) 和 \(m\) 个限制 \((l_i,r_i,x_i)\),求有多少个长度为 \(n\) 的序列 \(a\) 满足:
-
对于 \(1\leq i\leq n\) 有 \(0\leq a_i<2^k\)
-
对于 \(1\leq i\leq m\) 有 \(a_{l_i} \operatorname{and} a_{l_i+1}\operatorname{and}\cdots\operatorname{and} a_{r_i}=x_i\)
对 \(998244353\) 取模。
\(\texttt{Data Range:}1\leq n\leq 5\times 10^5,1\leq k\leq 30,0\leq m\leq 5\times 10^5\)
题解
毒瘤题。
一个非常显然的想法是拆位,所以变成每个位置填 \(0\) 或 \(1\) 然后满足所有条件的限制的方案数,总的方案数就是每一位的方案数乘起来就好了。
如果一段区间限制为 \(1\) 的话那么所有数都必须填 \(1\),如果限制是 \(0\) 的话那么至少有一个是 \(0\)。
设 \(f_{i,j}\) 表示当前在位置 \(i\),最后一个 \(0\) 在位置 \(j\) 的方案数,然后你会发现这个东西不好做。
考虑设一个 \(p_i\) 表示 \(i\) 位置前(不包括 \(i\) 位置)第一个 \(0\) 最小能填到哪个位置。
当 \(j<p_i\) 的时候很明显 \(f_{i,j}=0\)。
当 \(p_i\leq j<i\) 的时候,因为 \(i\) 位置没有填,所以 \(f_{i,j}=f_{i-1,j}\)。
当 \(j=i\) 的时候,如果这个位置强制选 \(1\) 的话那么 \(f_{i,j}=0\),否则枚举一下上一个 \(0\) 的位置得到 \(f_{i,j}=\sum\limits_{k<j}f_{i-1,k}\)。
注意到 \(i\) 这一维可以滚掉,而 \(p_i\) 又是单调不降的,所以可以考虑用一个指针来维护一下满足 \(f_{i,j}\neq 0\) 的最小的 \(j\)。
至于第三种操作,因为当 \(i<j\) 的时候 \(f_{i,j}=0\),所以可以直接维护当前 \(i\) 的所有 \(f_{i,j}\) 的和即可。
然后处理出哪个位置要强制选 \(1\) 的话可以对 \(1\) 的限制涉及到的区间做区间加,可以差分一下再前缀和一下。
处理 \(p_i\) 可以考虑每个为 \(0\) 的限制 \((l,r,0)\),记 \(p_{r+1}=l\) 即可。
代码
#include<bits/stdc++.h>
using namespace std;
typedef int ll;
typedef long long int li;
const ll MAXN=5e5+51,MOD=998244353;
ll n,kk,m,res=1,sum,ptr;
ll l[MAXN],r[MAXN],x[MAXN],pos[MAXN],sel[MAXN],f[MAXN];
inline ll read()
{
register ll num=0,neg=1;
register char ch=getchar();
while(!isdigit(ch)&&ch!='-')
{
ch=getchar();
}
if(ch=='-')
{
neg=-1;
ch=getchar();
}
while(isdigit(ch))
{
num=(num<<3)+(num<<1)+(ch-'0');
ch=getchar();
}
return num*neg;
}
inline void calc(ll bit)
{
for(register int i=1;i<=m;i++)
{
if(x[i]&(1<<bit))
{
sel[l[i]]++,sel[r[i]+1]--;
}
else
{
pos[r[i]+1]=max(pos[r[i]+1],l[i]);
}
}
f[0]=sum=1,ptr=0;
for(register int i=2;i<=n+1;i++)
{
sel[i]+=sel[i-1],pos[i]=max(pos[i],pos[i-1]);
}
for(register int i=1;i<=n+1;i++)
{
for(;ptr<pos[i];sum=(sum-f[ptr]+MOD)%MOD,f[ptr++]=0);
f[i]=sel[i]?0:sum,sum=(sum+f[i])%MOD;
}
res=(li)res*f[n+1]%MOD;
for(register int i=0;i<=n+1;i++)
{
sel[i]=pos[i]=f[i]=0;
}
}
int main()
{
n=read(),kk=read(),m=read();
for(register int i=1;i<=m;i++)
{
l[i]=read(),r[i]=read(),x[i]=read();
}
for(register int i=0;i<kk;i++)
{
calc(i);
}
printf("%d\n",res);
}