CF1148 H. Holy Diver
题目叙述
长度为 \(n\) 的序列,每次向后添加一个数然后询问区间 \([l,r]\) 的子区间中 mex 为 \(k\) 的区间一共有多少个。要求 \(\mathcal O(n\log_2 n)\)
题解
思路
区间子区间问题。考虑维护每个点作为右端点,每个左端点的 mex 是多少。用一个数据结构维护贡献,求历史和这样的东西即可。
但这个题需要查询 mex 为 \(k\) 的 \([l,r]\) 的子区间有多少个。因此改成对于每个 \(k\) 开一个线段树,查询区间 \([l,r]\) 的历史和。
然后考虑怎么维护 mex 。如果设 \(las_i\) 表示 \(i\) 的上一次出现位置,容易发现 mex 的分段情况相当于值域上的单调栈。每次从 \(r\) 变为 \(r+1\) ,相当于把 \(a_{r+1}\) 的上一次出现位置改为 \(r+1\) 。用 set 维护单调栈,每次修改需要把因为 \(a_{r+1}\) 很小而删掉的数加入即可。需要使用线段树二分维护 last 数组中在 \(r+1\) 后面且比 \(a_{r+1}\) 小的第一个数是多少。set 修改的同时在求历史和的线段树上相应操作即可。
如何实现求历史和
考虑区间加带来的影响。应当把它写成算式的形式而不是放到线段树上当一种标记处理。会很麻烦。当然也得试着做一下,因为这是似乎是可行的。
首先这个贡献可以拆分为一个数一个数地看,其次对于一个数而言,当前时刻 \(+c\) 对时间 \(t\) 的贡献相当于,如果后面还要经过 \(t\) 的时间,那么就会贡献 \(tc\) 。这相当于查询的时间与加入这个标记的时间差。因此只要维护每个标记与时间的乘积之和就可以了还有目前的区间和就可以了。
总结
区间子区间问题\(\rightarrow\)
区间mex\(\rightarrow\) 单调栈。
另外要观察问题的特征。比如单调栈里的元素只增不减。
尽量把答案通过算式化简然后再维护,不要吧一大堆东西都放到标记上。
代码
#include <cstdio>
#include <iostream>
#include <set>
#include <algorithm>
#include <vector>
#include <cassert>
using namespace std;
typedef long long LL;
const int MN=2e5+5;
namespace ST1{
int minv[MN*4];
void upd(int o){minv[o]=min(minv[o<<1],minv[o<<1|1]);}
void modify(int o,int l,int r,int p,int v){
if(l==r)return minv[o]=v,void();
int mid=(l+r)>>1;
if(p<=mid)modify(o<<1,l,mid,p,v);
else modify(o<<1|1,mid+1,r,p,v);
upd(o);
}
// 找到下一个 <=v 的值,找不着返回N+1
int find_next(int o,int l,int r,int p,int v){
if(r<=p)return -1;
if(minv[o]>=v)return -1;
if(l==r)return l;
int mid=(l+r)>>1;
int la=find_next(o<<1,l,mid,p,v);
if(la!=-1)return la;
else return find_next(o<<1|1,mid+1,r,p,v);
}
}
int N;
namespace ST2{
const int MND=MN*4*20;
int ls[MND],rs[MND],totnode;
LL s1[MND],s2[MND],t1[MND],t2[MND];
void new_node(int &o){
++totnode;
ls[totnode]=ls[o],rs[totnode]=rs[o];
s1[totnode]=s1[o],s2[totnode]=s2[o];
t1[totnode]=t1[o],t2[totnode]=t2[o];
o=totnode;
}
void modify(int &o,int l,int r,int ql,int qr,int v1,int v2){
if(l>qr||r<ql)return;
new_node(o);
if(l>qr||r<ql)return;
if(ql<=l&&r<=qr)return s1[o]+=(LL)(r-l+1)*v1,s2[o]+=(LL)(r-l+1)*v2,t1[o]+=v1,t2[o]+=v2,void();
int mid=(l+r)>>1;
modify(ls[o],l,mid,ql,qr,v1,v2),modify(rs[o],mid+1,r,ql,qr,v1,v2);
s1[o]=s1[ls[o]]+s1[rs[o]]+(LL)(r-l+1)*t1[o],s2[o]=s2[ls[o]]+s2[rs[o]]+(LL)(r-l+1)*t2[o];
}
#define fi first
#define se second
pair<LL,LL> query(int o,int l,int r,int ql,int qr){
if(!o)return make_pair(0,0);
if(l>qr||r<ql)return make_pair(0,0);
if(ql<=l&&r<=qr)return make_pair(s1[o],s2[o]);
int mid=(l+r)>>1;
auto al=query(ls[o],l,mid,ql,qr),ar=query(rs[o],mid+1,r,ql,qr);
LL len=min(qr,r)-max(l,ql)+1;
return make_pair(al.fi+ar.fi+len*t1[o],al.se+ar.se+len*t2[o]);
}
vector<int> pos[MN],rt[MN];
void append(int k,int tim,int l,int r,int w1,int w2){
pos[k].push_back(tim),rt[k].push_back(0);
int siz=rt[k].size();
rt[k][siz-1]=((siz>=2)?rt[k][siz-2]:0);
modify(rt[k][siz-1],1,N,l,r,w1,w2);
}
LL query(int k,int l,int r){
int p=upper_bound(pos[k].begin(),pos[k].end(),r)-pos[k].begin()-1;
if(p<0)return 0;
auto tmp=query(rt[k][p],1,N,l,r);
return tmp.fi*(r+1)-tmp.se;
}
}
set<int> s;
int last[MN];
int next_val(set<int>::iterator it){return *(++it);}
int prev_val(set<int>::iterator it){return *(--it);}
int main(){
freopen("h.in","r",stdin);
// freopen(".out","w",stdout);
scanf("%d",&N);
s.insert(0);
LL lastans=0;
for(int i=1;i<=N;++i){
int a=0,l=0,r=0,k=0;
scanf("%d%d%d%d",&a,&l,&r,&k);
a=(a+lastans)%(N+1),l=(l+lastans)%i+1,r=(r+lastans)%i+1,k=(k+lastans)%(N+1);
if(l>r)swap(l,r);
// 维护的是 last 数组
ST1::modify(1,0,N,a,i);
int tmp=last[a];
last[a]=i;
if(s.count(a)){
auto p=s.find(a);
if(p!=s.begin()){
int pv=prev_val(p);
ST2::append(a,i,tmp+1,last[pv],-1,-i);
--p,s.erase(a);
}else ST2::append(a,i,tmp+1,i-1,-1,-i);
{
auto q=p;
++q;
if(q==s.end())ST2::append(a+1,i,1,tmp,-1,-i);
else ST2::append(*q,i,last[*q]+1,tmp,-1,-i);
}
int v=*p;++p;
if(p==s.end()){
while(1){
int la=v;
v=ST1::find_next(1,0,N,v,last[v]);
if(v==-1){
ST2::append(la+1,i,1,last[la],1,i);
break;
}
ST2::append(v,i,last[v]+1,last[la],1,i);
s.insert(v);
}
}else{
int t=*p;
while(1){
int la=v;
v=ST1::find_next(1,0,N,v,last[v]);
ST2::append(v,i,last[v]+1,last[la],1,i);
if(v==t)break;
s.insert(v);
}
}
}
if(a>=1)ST2::append(0,i,i,i,1,i);
printf("%lld\n",lastans=ST2::query(k,l,r));
}
fclose(stdin);
// fclose(stdout);
return 0;
}

浙公网安备 33010602011771号