题解:CF815D Karen and Cards
前言
有点困难题。
思路分析
考虑本质上是三维问题,考虑降维处理。
因为全部满足条件不好做,考虑正难则反,计算不合法的三元组个数。
首先对 \(a_i\) 排序,从大往小做扫描线,每次维护 \(b=i\) 时 \(c\) 的最大不合法值,那么对于每一个 \(a\),不合法的三元组数量等于全局和。
最开始,对于每个限制 \((a_i,b_i,c_i)\),设线段树维护的序列为 \(d\),相当于 \(\max_{j=1}^{b_i} d_j\le c_i\),因为只需要满足一个,所以在区间 \([1,b_i]\) 取 \(\max\) 即可。
每次扫到限制 \((a_i,b_i,c_i)\) 时,\([1,b_i]\) 的部分合法,而 \([b_i+1,q]\) 的部分需要满足 \(\max_{j=b_i+1}^{q} d_j \le c_i\),相当于两次区间取 \(\max\)。
为了区间取 \(\max\) 写 seg_beat 肯定不值当,注意到我们维护的序列 \(d\) 是有单调性的,所以直接线段树二分 + 区间推平就行。
总体复杂度 \(O(v \log (n+v))\)。
代码实现
#include<bits/stdc++.h>
#define int long long
using namespace std;
int n,p,q,r,ans,now,minb[500005],minc[500005],d[500005];
struct node{
int a,b,c;
}h[500005];
int val_sum[1000005],val_min[1000005],tag[1000005],ls[1000005],rs[1000005],dcnt,rt;
void pushup(int x){
val_sum[x]=val_sum[ls[x]]+val_sum[rs[x]];
val_min[x]=min(val_min[ls[x]],val_min[rs[x]]);
}
void pushdown(int l,int r,int x){
if(!tag[x]) return;
int mid=(l+r)>>1;
tag[ls[x]]=tag[rs[x]]=val_min[ls[x]]=val_min[rs[x]]=tag[x];
val_sum[ls[x]]=(mid-l+1)*tag[x];
val_sum[rs[x]]=(r-mid)*tag[x];
tag[x]=0;
}
void build(int l,int r,int &x){
x=++dcnt;
if(l==r) return;
int mid=(l+r)>>1;
build(l,mid,ls[x]);
build(mid+1,r,rs[x]);
pushup(x);
}
void modify(int l,int r,int ql,int qr,int k,int x){
if(ql<=l && r<=qr){
val_min[x]=tag[x]=k;
val_sum[x]=(r-l+1)*k;
return;
}
pushdown(l,r,x);
int mid=(l+r)>>1;
if(ql<=mid) modify(l,mid,ql,qr,k,ls[x]);
if(qr>=mid+1) modify(mid+1,r,ql,qr,k,rs[x]);
pushup(x);
}
int findl(int l,int r,int ql,int qr,int k,int x){
if(l==r){
if(val_min[x]>k) return -1;
else return l;
}
pushdown(l,r,x);
int mid=(l+r)>>1,ans=-1;
if(ql<=mid){
if(val_min[ls[x]]<=k) ans=findl(l,mid,ql,qr,k,ls[x]);
if(ans!=-1) return ans;
}
if(qr>=mid+1){
if(val_min[rs[x]]<=k) ans=findl(mid+1,r,ql,qr,k,rs[x]);
if(ans!=-1) return ans;
}
return ans;
}
void add(int l,int r,int k){
int p=findl(1,q,l,r,k,rt);
if(p!=-1) modify(1,q,p,r,k,rt);
}
void init(){
for(int i=1;i<=n;i++){
add(1,h[i].b,h[i].c);
}
}
bool cmp(node a,node b){
return a.a<b.a;
}
signed main(){
ios::sync_with_stdio(false);
cin.tie(0);
cout.tie(0);
cin>>n>>p>>q>>r;
for(int i=1;i<=n;i++){
cin>>h[i].a>>h[i].b>>h[i].c;
}
sort(h+1,h+1+n,cmp);
build(1,q,rt);
init();
now=n;
for(int i=p;i>=1;i--){
while(now && h[now].a>=i){
add(h[now].b+1,q,h[now].c);
add(1,h[now].b,r);
now--;
}
ans+=val_sum[1];
}
cout<<p*q*r-ans;
return 0;
}

浙公网安备 33010602011771号