ABC265G
Solution
区间修改和区间查询,考虑线段树。
由于所有数 \(\in\{0,1,2\}\),对于一个区间,维护以下信息:
-
\(ans(i,j)\) 表示形如 \((i,j)\) 这样的逆序对。
-
\(sum(i)\) 表示数字 \(i\) 出现的个数。
-
\(s,t,u\) 的 \(tag\) 标记。
上传标记很简单,\(sum\) 就直接加,\(ans(i,j)\) 就是左区间 \(ans\) 和右区间 \(ans\) 和左区间 \(sum(i)\) 乘右区间 \(sum(j)\)。
void pushup(int p){
for(int i=0;i<3;i++)sum[p][i]=sum[ls[p]][i]+sum[rs[p]][i];
for(int i=0;i<3;i++)for(int j=0;j<3;j++)
ans[p][i][j]=ans[ls[p]][i][j]+ans[rs[p]][i][j]+1ll*sum[ls[p]][i]*sum[rs[p]][j];
}
主要是怎么处理标记,注意到原序列 \(0\to s,1\to t,2\to u\),仅仅涉及到一个交换的问题,只需要开一个临时变量,然后将其赋回来即可。
tmp[0]=s;tmp[1]=t;tmp[2]=u;
for(int i=0;i<3;i++){
a[tmp[i]]+=sum[p][i];//注意是+,因为没有保证 s,t,u 互不相等
for(int j=0;j<3;j++)res[tmp[i]][tmp[j]]+=ans[p][i][j];
}
for(int i=0;i<3;i++)sum[p][i]=a[i];
for(int i=0;i<3;i++)for(int j=0;j<3;j++)ans[p][i][j]=res[i][j];
下传标记要注意,并非直接覆盖,例如原来有 \(tags=1\),现在多一个 \(tagt=2\),等价于 \(tags=2\)。也就是说我们得看上一次变成了什么,结合代码理解。
tags[p]=tmp[tags[p]];
tagt[p]=tmp[tagt[p]];
tagu[p]=tmp[tagu[p]];//tmp 的定义同上一段代码
查询就是将多个区间合并,区间内的直接加 \(ans\),跨区间的需要记一下前面的区间出现了多少个 \(1,2\),然后将能产生贡献的数乘起来加入答案即可。
Code
细节多,有点难写,2.78KB 成功压到了 99 行。
#include<bits/stdc++.h>
using namespace std;
#define pb push_back
#define mkp make_pair
#define fi first
#define se second
using ll=long long;
using pii=pair<int,int>;
using pll=pair<ll,ll>;
using ull=unsigned long long;
inline void read(int &x){
char ch=getchar();
int r=0,w=1;
while(!isdigit(ch))w=ch=='-'?-1:1,ch=getchar();
while(isdigit(ch))r=(r<<1)+(r<<3)+(ch^48),ch=getchar();
x=r*w;
}
const int N=1e5+7;
int n,q,x,lst1,lst2;
ll ss;
struct seg_tree{
int root,tot,sum[N*4][3],tags[N*4],tagt[N*4],tagu[N*4],ls[N*4],rs[N*4];
ll ans[N*4][3][3];
void pushup(int p){
for(int i=0;i<3;i++)sum[p][i]=sum[ls[p]][i]+sum[rs[p]][i];
for(int i=0;i<3;i++)for(int j=0;j<3;j++)
ans[p][i][j]=ans[ls[p]][i][j]+ans[rs[p]][i][j]+1ll*sum[ls[p]][i]*sum[rs[p]][j];
}
void update(int p,int s,int t,int u){
int a[3],tmp[3];
ll res[3][3];
for(int i=0;i<3;i++)a[i]=0;
for(int i=0;i<3;i++)for(int j=0;j<3;j++)res[i][j]=0;
tmp[0]=s;tmp[1]=t;tmp[2]=u;
for(int i=0;i<3;i++){
a[tmp[i]]+=sum[p][i];
for(int j=0;j<3;j++)res[tmp[i]][tmp[j]]+=ans[p][i][j];
}
for(int i=0;i<3;i++)sum[p][i]=a[i];
for(int i=0;i<3;i++)for(int j=0;j<3;j++)ans[p][i][j]=res[i][j];
if(tags[p]==-1)tags[p]=s,tagt[p]=t,tagu[p]=u;
else{
tags[p]=tmp[tags[p]];
tagt[p]=tmp[tagt[p]];
tagu[p]=tmp[tagu[p]];
}
}
void pushdown(int p){
update(ls[p],tags[p],tagt[p],tagu[p]);
update(rs[p],tags[p],tagt[p],tagu[p]);
tags[p]=tagt[p]=tagu[p]=-1;
}
void pre(int &p,int l=1,int r=n){
if(!p)p=++tot;
if(l==r){read(x),sum[p][x]++;return;}
int mid=l+r>>1;
pre(ls[p],l,mid);pre(rs[p],mid+1,r);
pushup(p);tags[p]=tagt[p]=tagu[p]=-1;
}
void modify(int &p,int ql,int qr,int s,int t,int u,int l=1,int r=n){
if(!p)p=++tot;
if(ql<=l&&r<=qr){
update(p,s,t,u);
return;
}
if(tags[p]!=-1)pushdown(p);
int mid=l+r>>1;
if(ql<=mid)modify(ls[p],ql,qr,s,t,u,l,mid);
if(qr>mid)modify(rs[p],ql,qr,s,t,u,mid+1,r);
pushup(p);
}
ll query(int &p,int ql,int qr,int l=1,int r=n){
if(!p)p=++tot;
if(ql<=l&&r<=qr){
ss+=lst2*(sum[p][0]+sum[p][1]);
ss+=lst1*sum[p][0];
lst1+=sum[p][1];lst2+=sum[p][2];
return ans[p][2][0]+ans[p][2][1]+ans[p][1][0];
}
if(tags[p]!=-1)pushdown(p);
int mid=l+r>>1;
ll tmp=0;
if(ql<=mid)tmp+=query(ls[p],ql,qr,l,mid);
if(qr>mid)tmp+=query(rs[p],ql,qr,mid+1,r);
return tmp;
}
}tr;
int main(){
read(n);read(q);tr.pre(tr.root);
while(q--){
int op,l,r,s,t,u;read(op);read(l);read(r);
if(op==1){
ss=lst1=lst2=0;
printf("%lld\n",tr.query(tr.root,l,r)+ss);
}
else read(s),read(t),read(u),tr.modify(tr.root,l,r,s,t,u);
}
return 0;
}