【2025.10.17NOIP模拟】位集(wj)
题目描述
定义大小为 m 的 bitset 为长度为 m 的 bool 数组。
对大小为 m 的 bitset 定义如下四种运算:
- c=a and b:在这里,如果 ai=1 且 bi=1,则 ci=1;否则 ci=0。
- c=a or b:在这里,如果 ai=1 或 bi=1,则 ci=1;否则 ci=0。
- c=a xor b:在这里,如果 ai 和 bi 中恰好有一个为 1,则 ci=1;否则 ci=0。
- c=not a:在这里,如果 ai=0,则 ci=1;否则 ci=0。
给定一个大小为 n 的 bitset 数组 s1,s2,…,sn,编写程序来回答 k 个查询,每次查询给定 l,r,你需要使用以下公式计算 t:
- t=(sl and sl+1 and ⋯ and sr) xor (not (sl or sl+1 or ⋯ or sr))
求 t 中 1 的个数。
输入
从文件 wj.in 中读入数据。
第一行包含两个整数 n 和 m (1≤n,m≤105; n⋅m≤106)。接下来的 n 行描述了 n 个 bitset,每行由 m 个 0 或 1 组成,表示一个 bitset。
接下来的一行包含一个整数 k,表示查询的数量 (1≤k≤2×106)。
接下来的一行包含三个整数 x,y,z (1≤x,y,z≤109)。
查询是通过以 x,y,z 为参数的伪随机算法生成的,具体来说,考虑生成长度为 k 的序列 a,b:
- a1=1。
- b1=n。
- 对于 i>1,ai=(ai−1⋅x+qi−1⋅y+z)modn+1。
- 对于 i>1,bi=(bi−1⋅y+qi−1⋅z+x)modn+1。
其中,第 i 个询问的 l 是 min{ai,bi},r 是 max{ai,bi},公式里的 qi−1 表示第 i−1 个询问的答案。
输出
输出到文件 wj.out 中。
输出一个整数表示所有查询答案的总和。
样例数据
输入 #1 复制
4 10 1010110101 0101111001 1101101101 1011010000 4 10 5 4
输出 #1 复制
9
数据范围限制
对于所有数据,有:
- 1≤n,m≤105
- nm≤106
- 1≤k≤2×106
- 1≤x,y,z≤109
子任务:
| 子任务编号 | 特殊性质 | 分值 |
|---|---|---|
| 1 | n,m≤20,k≤50 | 40 |
| 2 | m=1 | 20 |
| 3 | k≤1×105 | 20 |
| 4 | y=z=0 | 10 |
| 5 | 无 | 10 |
提示
样例解释
| 询问编号 | l | r | 答案 |
|---|---|---|---|
| 1 | 1 | 4 | 1 |
| 2 | 3 | 4 | 3 |
| 3 | 2 | 4 | 2 |
| 4 | 1 | 3 | 3 |
思路
首先,看到区间操作,考虑线段树维护。
可以发现,其实是维护全是1数量与全是0数量之和。
考虑线段树,对每种颜色区间合并,复杂度O(nmlogn+qmlogn):
代码见下O(nmlogn+qmlogn):
#include<bits/stdc++.h>
using namespace std;
int n,m,q,x,y,z,l,r,a1,b1,nm,lk=0,kl=0;
struct one{
vector<bool> v;
}cd;
char s[1000006];
int te[400005],te2[4000006];
int te1[400005],te12[4000006];
inline void bu(int a1,int l,int r){
if(l==r){
int dm=0;
for(int i=1;i<=m;i++){
if(s[(l-1)*m+i]=='1'){
te[a1]++;
}
te2[(i-1)*nm+a1]=(int)(s[(l-1)*m+i]-'0');
}
return ;
}
int mid=(l+r)/2;
bu(a1*2,l,mid);
bu(a1*2+1,mid+1,r);
int dm=0;
for(int i=1;i<=m;i++){
if(te2[(i-1)*nm+a1*2]==1&&te2[(i-1)*nm+a1*2+1]==1){
te[a1]++;
te2[(i-1)*nm+a1]=1;
}
else{
te2[(i-1)*nm+a1]=0;
}
}
return ;
}
inline one co(int a1,int l,int r,int x,int y){
if(l>=x&&r<=y){
one dbdb;
dbdb.v.resize(m);
for(int i=1;i<=m;i++){
dbdb.v[i-1]=te2[(i-1)*nm+a1];
}
return dbdb;
}
int mid=(l+r)/2,utut=0;
one dbdb,ckck;
dbdb.v.resize(m);
if(x<=mid){
utut=1;
dbdb=co(a1*2,l,mid,x,y);
}
if(y>=mid+1){
if(utut==0){
dbdb=co(a1*2+1,mid+1,r,x,y);
}
else{
ckck.v.resize(m);
ckck=co(a1*2+1,mid+1,r,x,y);
for(int i=0;i<m;i++){
if(dbdb.v[i]==1&&ckck.v[i]==1){
dbdb.v[i]=1;
}
else{
dbdb.v[i]=0;
}
}
}
}
return dbdb;
}
inline void bu2(int a1,int l,int r){
if(l==r){
int dm=0;
for(int i=1;i<=m;i++){
if(s[(l-1)*m+i]=='0'){
te1[a1]++;
}
te12[(i-1)*nm+a1]=(int)(s[(l-1)*m+i]-'0');
}
return ;
}
int mid=(l+r)/2;
bu2(a1*2,l,mid);
bu2(a1*2+1,mid+1,r);
int dm=0;
for(int i=1;i<=m;i++){
if(te12[(i-1)*nm+a1*2]==0&&te12[(i-1)*nm+a1*2+1]==0){
te1[a1]++;
te12[(i-1)*nm+a1]=0;
}
else{
te12[(i-1)*nm+a1]=1;
}
}
return ;
}
inline one co2(int a1,int l,int r,int x,int y){
if(l>=x&&r<=y){
one dbdb;
dbdb.v.resize(m);
for(int i=1;i<=m;i++){
dbdb.v[i-1]=te12[(i-1)*nm+a1];
}
return dbdb;
}
int mid=(l+r)/2,utut=0;
one dbdb,ckck;
dbdb.v.resize(m);
if(x<=mid){
utut=1;
dbdb=co2(a1*2,l,mid,x,y);
}
if(y>=mid+1){
if(utut==0){
dbdb=co2(a1*2+1,mid+1,r,x,y);
}
else{
ckck.v.resize(m);
ckck=co2(a1*2+1,mid+1,r,x,y);
for(int i=0;i<m;i++){
if(dbdb.v[i]==0&&ckck.v[i]==0){
dbdb.v[i]=0;
}
else{
dbdb.v[i]=1;
}
}
}
}
return dbdb;
}
int main(){
ios::sync_with_stdio(false);
cin.tie(0);
cout.tie(0);
cin>>n>>m;
for(int i=1;i<=n;i++){
for(int j=1;j<=m;j++){
cin>>s[(i-1)*m+j];
}
}
nm=4*n;
bu(1,1,n);
bu2(1,1,n);
a1=1;
b1=n;
cin>>q;
cin>>x>>y>>z;
lk=0;
cd.v.resize(m);
for(int i=1;i<=q;i++){
l=min(a1,b1);
r=max(a1,b1);
//cout<<l<<" "<<r<<endl;
cd=co(1,1,n,l,r);
kl=0;
for(int j=0;j<m;j++){
if(cd.v[j]==1){
lk++;
kl++;
}
}
cd=co2(1,1,n,l,r);
for(int j=0;j<m;j++){
if(cd.v[j]==0){
lk++;
kl++;
}
}
//cout<<kl<<endl;
a1=(a1*x+kl*y+z)%n+1;
b1=(b1*y+kl*z+x)%n+1;
}
cout<<lk<<'\n';
return 0;
}
注意到全是1数量与全是0数量之和,考虑前缀和,最后查询m次,去除logn,复杂度O(nm+qm):
代码见下O(nm+qm):
#include<bits/stdc++.h>
using namespace std;
int n,m,q,x,y,z,l,r,a1,b1,nm,lk=0,kl=0;
char s[1000006];
int ss[1000006];
int main(){
ios::sync_with_stdio(false);
cin.tie(0);
cout.tie(0);
cin>>n>>m;
for(int i=1;i<=n;i++){
for(int j=1;j<=m;j++){
cin>>s[(i-1)*m+j];
}
}
for(int i=1;i<=n;i++){
for(int j=1;j<=m;j++){
if(i==1){
ss[(j-1)*n+i]=(long long)(s[(i-1)*m+j]-'0');
}
else{
ss[(j-1)*n+i]=ss[(j-1)*n+i-1]+(long long)(s[(i-1)*m+j]-'0');
}
}
}
a1=1;
b1=n;
cin>>q;
cin>>x>>y>>z;
lk=0;
for(int i=1;i<=q;i++){
l=min(a1,b1);
r=max(a1,b1);
kl=0;
for(int j=1;j<=m;j++){
if(l==1){
if(ss[(j-1)*n+r]==r||ss[(j-1)*n+r]==0){
lk++;
kl++;
}
}
else{
if(ss[(j-1)*n+r]-ss[(j-1)*n+l-1]==r-l+1||ss[(j-1)*n+r]-ss[(j-1)*n+l-1]==0){
lk++;
kl++;
}
}
}
//cout<<l<<" "<<r<<endl;
//cout<<kl<<endl;
a1=(a1*x+kl*y+z)%n+1;
b1=(b1*y+kl*z+x)%n+1;
}
cout<<lk<<'\n';
return 0;
}
发现qm会超时,思考如何消除m,可以查找对于每一个位置,从它开始,至多到多少才能达到a,复杂度O(nmlogn+qlogm):
代码见下
#include<bits/stdc++.h>
using namespace std;
long long n,m,q,x,y,z,l,r,a1,b1,nm,lk=0,kl=0;
char s[1000006];
long long ss[1000006],dl[1000006],sd[100005],sf=0;
int main(){
// freopen("wj.in","r",stdin);
// freopen("wj.out","w",stdout);
cin>>n>>m;
for(int i=1;i<=n;i++){
for(int j=1;j<=m;j++){
cin>>s[(i-1)*m+j];
}
}
for(int i=1;i<=n;i++){
for(int j=1;j<=m;j++){
if(i==1){
ss[(j-1)*n+i]=(long long)(s[(i-1)*m+j]-'0');
}
else{
ss[(j-1)*n+i]=ss[(j-1)*n+i-1]+(long long)(s[(i-1)*m+j]-'0');
}
}
}
for(int j=1;j<=m;j++){
sf=0;
for(int i=1;i<=n;i++){
if(i!=1){
while(sf+1<=n&&(ss[(j-1)*n+sf+1]-ss[(j-1)*n+i-1]==0||ss[(j-1)*n+sf+1]-ss[(j-1)*n+i-1]==sf+1-i+1)){
sf++;
}
}
else{
while(sf+1<=n&&(ss[(j-1)*n+sf+1]==0||ss[(j-1)*n+sf+1]==sf+1-i+1)){
sf++;
}
}
dl[(i-1)*m+j]=sf;
//cout<<j<<" "<<i<<" "<<sf<<" "<<ss[(j-1)*n+sf]<<" "<<ss[(j-1)*n+i]<<endl;
}
}
for(int i=1;i<=n;i++){
sort(dl+(i-1)*m+1,dl+i*m+1);
// for(int j=1;j<=m;j++){
// cout<<dl[(i-1)*m+j]<<" ";
// }
// cout<<endl;
}
a1=1;
b1=n;
cin>>q;
cin>>x>>y>>z;
lk=0;
for(int i=1;i<=q;i++){
l=min(a1,b1);
r=max(a1,b1);
long long ll=1,rr=m;
kl=m+1;
while(ll<=rr){
long long mid=(ll+rr)/2;
if(r<=dl[(l-1)*m+mid]){
kl=min(kl,mid);
rr=mid-1;
}
else{
ll=mid+1;
}
//cout<<mid<<" "<<dl[(l-1)*m+mid]<<endl;
}
kl=m-kl+1;
lk+=kl;
//cout<<kl<<endl;
a1=(a1*x+kl*y+z)%n+1;
b1=(b1*y+kl*z+x)%n+1;
}
cout<<lk<<'\n';
return 0;
}

浙公网安备 33010602011771号