整体二分学习笔记
引入
整体二分可以一次性地解决多次询问的二分,一般配合数据结构使用。
可以使用整体二分解决的题目需要满足以下性质:
询问的答案具有可二分性
修改对判定答案的贡献互相独立,修改之间互不影响效果
修改如果对判定答案有贡献,则贡献为一确定的与判定标准无关的值
贡献满足交换律,结合律,具有可加性
题目允许使用离线算法
——许昊然《浅谈数据结构题几个非经典解法》
先从普通的二分说起。
例1:查询全局第 \(k\) 大
显然,这可以通过简单的排序在 \(O(n\log n)\) 时间内解决。如果使用二分,则可以开桶维护,然后二分答案,每次查询有多少个数比当前答案要大来进行 \(check\)。
例2:多次查询全局第 \(k\) 大
这里排序的复杂度是 \(O(n\log n+T)\) 的。如果每次询问都进行一次二分答案,总复杂度就是 \(O(n\log n+T\log ans)\) 的。
例3:多次查询区间第 \(k\) 大(主席树模板)
因为要维护区间信息,所以这题排序无法解决。同样地,直接二分会使 \(check\) 的复杂度爆炸。
于是我们可以使用整体二分来优化直接二分。
简介
整体二分就是一个把所有询问组成的集合按照答案大小划分的过程。
假设所有询问的答案为 \(Q_1,Q_2,Q_3,...,Q_n\),答案范围为 \([l,r]\)。
二分答案,当前答案 \(mid=\lfloor\frac{l+r}{2}\rfloor\)。
对于所有 \(Q_i,i\in[1,n]\) 进行 \(check\),将 \(\le mid\) 的答案归为一组,\(>mid\) 的答案归为另一组。
然后对两组 \([l,mid]\) 和 \([mid+1,r]\) 递归地进行处理即可。
递归树高 \(\log(n)\),每层对于所有询问进行一次 \(check\)。设 \(check\) 所有询问的复杂度为 \(O(T)\),则总复杂度 \(O(T\log(n))\)。在例3中,\(O(T)=O(n\log(n))\),于是总复杂度为 \(O(n\log^2(n))\)。
例题
洛谷P1527 [国家集训队] 矩阵乘法
考虑对每次询问二分答案。对于小于当前答案的数将其标记为 \(1\),大于等于的标记为 \(0\),则可以用二维树状数组统计矩阵内 \(1\) 的个数,即矩阵和。将所有询问整体二分即可。时间复杂度 \(O((N^2+Q)*\log ^3N)\)
#include<bits/stdc++.h>
using namespace std;
constexpr int N=5e2+2,M=6e4+4;
int n,m,tmp[N*N],len,ans[M];
struct Mat{int x,y,val;}p[N*N];
constexpr inline bool operator<(Mat x,Mat y){return x.val<y.val;}
struct Query{int x1,x2,y1,y2,k;}q[M];
struct BIT{
#define lowbit(x) (x&-x)
int sum[N][N];
void update(int x,int y,int val){
for(int i=x;i<=n;i+=lowbit(i)){
for(int j=y;j<=n;j+=lowbit(j)){
sum[i][j]+=val;
}
}
}
int pre(int x,int y){
int ans=0;
for(int i=x;i;i^=lowbit(i)){
for(int j=y;j;j^=lowbit(j)){
ans+=sum[i][j];
}
}
return ans;
}
int query(int x1,int y1,int x2,int y2){return pre(x2,y2)-pre(x2,y1-1)-pre(x1-1,y2)+pre(x1-1,y1-1);}
}tr;
void solve(int l,int r,vector<int>nw){
if(nw.empty())return;
if(l==r){
for(int x:nw)ans[x]=l;
return;
}
const int mid=(l+r)>>1;
for(int i=l;i<=mid;i++)tr.update(p[i].x,p[i].y,1);
vector<int>ll,rr;
for(int x:nw){
int t=tr.query(q[x].x1,q[x].y1,q[x].x2,q[x].y2);
if(t>=q[x].k)ll.push_back(x);
else q[x].k-=t,rr.push_back(x);
}
for(int i=l;i<=mid;i++)tr.update(p[i].x,p[i].y,-1);
solve(l,mid,ll);solve(mid+1,r,rr);
}
signed main(){
ios::sync_with_stdio(false);
cin.tie(0),cout.tie(0);
cin>>n>>m;
for(int i=1;i<=n*n;i++)cin>>p[i].val,p[i].x=(i-1)/n+1,p[i].y=(i-1)%n+1;
for(int i=1;i<=n*n;i++)tmp[i]=p[i].val;
sort(tmp+1,tmp+1+n*n);
len=unique(tmp+1,tmp+1+n*n)-(tmp+1);
for(int i=1;i<=n*n;i++)p[i].val=lower_bound(tmp+1,tmp+1+len,p[i].val)-tmp;
sort(p+1,p+1+n*n);
for(int i=1;i<=m;i++)cin>>q[i].x1>>q[i].y1>>q[i].x2>>q[i].y2>>q[i].k;
vector<int>res;
for(int i=1;i<=m;i++)res.push_back(i);
solve(1,n*n,res);
for(int i=1;i<=m;i++)cout<<tmp[p[ans[i]].val]<<'\n';
return 0;
}
P7424 [THUPC 2017] 天天爱射击
板子。等价于给出一些区间,问这些区间中第 k 个子弹的编号。还是二分然后将权值赋为 \(0\) 或 \(1\) 即可。
#include <bits/stdc++.h>
using namespace std;
constexpr int N=2e5+5;
int n,m,ans[N],f[N],b[N];
struct Query{int l,r,k;}a[N];
struct BIT{
int sum[N];
void update(int p,int val){for(;p<=n;p+=p&-p)sum[p]+=val;}
int query(int p){int ans=0;for(;p;p^=p&-p)ans+=sum[p];return ans;}
}tr;
void solve(vector<int>res,int pl,int pr){
if(res.empty())return;
if(pl==pr){
for(int x:res)ans[x]=pl;
return;
}
const int mid=(pl+pr)>>1;
for(int i=pl;i<=mid;i++)tr.update(b[i],1);
vector<int>l,r;
for(int x:res){
int ans=tr.query(a[x].r)-tr.query(a[x].l-1);
if(a[x].k<=ans)l.push_back(x);
else a[x].k-=ans,r.push_back(x);
}
for(int i=pl;i<=mid;i++)tr.update(b[i],-1);
solve(l,pl,mid);solve(r,mid+1,pr);
}
signed main(){
ios::sync_with_stdio(false);
cin.tie(0),cout.tie(0);
cin>>n>>m;
for(int i=1;i<=n;i++)cin>>a[i].l>>a[i].r>>a[i].k;
for(int i=1;i<=m;i++)cin>>b[i];
vector<int>res(n);
for(int i=0;i<n;i++)res[i]=i+1;
solve(res,1,m+1);
for(int i=1;i<=n;i++)f[ans[i]]++;
for(int i=1;i<=m;i++)cout<<f[i]<<'\n';
return 0;
}
P3242 [HNOI2015] 接水果
考虑一条路径 \(x,y\) 被另一条路径 \(u,v\) 覆盖可以表示为以下情况:
- \(\text{lca}(x,y)=x\),此时 \(dfn_y\le dfn_u\le dfn_y+size_y-1\),\(dfn_v\le dfn_x\) 或 \(dfn_v>dfn_x+szie_x-1\)。
- \(\text{lca}(x,y)\not=x\),此时 \(dfn_y\le dfn_u\le dfn_y+size_y-1\),\(dfn_x\le dfn_v\le dfn_x+size_x-1\)。
这样就把路径覆盖转换成了 \(dfn\) 序之间的偏序关系,相当于二维平面内矩形第 \(k\) 小。于是扫描线加整体二分就做完了。
```plaintext
#include <bits/stdc++.h>
using namespace std;
constexpr int N=4e4+4;
int n,m,q,dfn[N],rk[N],size[N],top[N],s[N],dep[N],fa[N],cnt,tot;
vector<int>edge[N];
struct Node{int x,l,r,val;}a[N<<2];
struct Query{int x,y,k,id;}b[N];
int ans[N];
void dfs1(int p,int f){
dep[p]=dep[fa[p]=f]+1;
size[p]=1;
for(int x:edge[p]){
if(x==f)continue;
dfs1(x,p);
size[p]+=size[x];
if(size[x]>size[s[p]])s[p]=x;
}
}
void dfs2(int p,int t){
top[p]=t;
dfn[rk[p]=++cnt]=p;
if(s[p])dfs2(s[p],t);
for(int x:edge[p]){
if(x==fa[p]||x==s[p])continue;
dfs2(x,x);
}
}
int lca(int x,int y,bool t=0){
int _y=y;
while(top[x]!=top[y]){
if(dep[top[x]]<dep[top[y]])swap(x,y);
if(fa[top[x]]==_y&&t)return top[x];
x=fa[top[x]];
}
if(dep[x]>dep[y])swap(x,y);
if(t)return s[x];
return x;
}
struct BIT{
int sum[N];
void update(int p,int val){for(;p<=n;p+=p&-p)sum[p]+=val;}
int query(int p){int ans=0;for(;p;p^=p&-p)ans+=sum[p];return ans;}
}tr;
void solve(vector<int>res,vector<int>c,int pl,int pr){
if(res.empty())return;
if(pl==pr){
for(int x:res)ans[b[x].id]=pl;
return;
}
const int mid=(pl+pr)>>1;
vector<int>l,r,lc,rc;
for(int x:c){
if(abs(a[x].val)<=mid)lc.push_back(x);
else rc.push_back(x);
}
sort(lc.begin(),lc.end(),[](int x,int y){return a[x].x<a[y].x;});
sort(res.begin(),res.end(),[](int x,int y){return b[x].x<b[y].x;});
int p=0;
for(int x:res){
while(p<lc.size()&&a[lc[p]].x<=b[x].x){
// cout<<a[lc[p]].l<<' '<<a[lc[p]].r<<' '<<(a[lc[p]].val<0?-1:1)<<'\n';
tr.update(a[lc[p]].l,1*(a[lc[p]].val<0?-1:1));
tr.update(a[lc[p]].r+1,-1*(a[lc[p]].val<0?-1:1));
++p;
}
int ans=tr.query(b[x].y);
// cout<<x<<' '<<b[x].x<<' '<<b[x].y<<' '<<pl<<' '<<pr<<' '<<ans<<'\n';
if(b[x].k<=ans)l.push_back(x);
else b[x].k-=ans,r.push_back(x);
}
for(int i=0;i<p;i++){
tr.update(a[lc[i]].l,-1*(a[lc[i]].val<0?-1:1));
tr.update(a[lc[i]].r+1,1*(a[lc[i]].val<0?-1:1));
}
solve(l,lc,pl,mid);solve(r,rc,mid+1,pr);
}
int tmp[N];
signed main(){
ios::sync_with_stdio(false);
cin.tie(0),cout.tie(0);
cin>>n>>m>>q;
for(int i=1;i<n;i++){
int x,y;cin>>x>>y;
edge[x].push_back(y);
edge[y].push_back(x);
}
dfs1(5,0);dfs2(5,5);
for(int i=1;i<=m;i++){
int x,y;cin>>x>>y>>tmp[i];
if(rk[x]<rk[y])swap(x,y);
if(lca(x,y)==y){
int z=lca(x,y,1);
if(rk[z]>1){
a[++tot]={rk[x],1,rk[z]-1,tmp[i]};
a[++tot]={rk[x]+size[x],1,rk[z]-1,-tmp[i]};
}
if(rk[z]+size[z]-1<n){
a[++tot]={rk[z]+size[z],rk[x],rk[x]+size[x]-1,tmp[i]};
a[++tot]={n+1,rk[x],rk[x]+size[x]-1,-tmp[i]};
}
}
else{
a[++tot]={rk[x],rk[y],rk[y]+size[y]-1,tmp[i]};
a[++tot]={rk[x]+size[x],rk[y],rk[y]+size[y]-1,-tmp[i]};
}
}
sort(tmp+1,tmp+m+1);
int len=unique(tmp+1,tmp+m+1)-tmp-1;
for(int i=1;i<=tot;i++)a[i].val=(lower_bound(tmp+1,tmp+len+1,abs(a[i].val))-tmp)*(a[i].val<0?-1:1);
// for(int i=1;i<=tot;i++)cout<<a[i].x<<' '<<a[i].l<<' '<<a[i].r<<' '<<a[i].val<<'\n';
for(int i=1;i<=q;i++){
cin>>b[i].x>>b[i].y>>b[i].k;
b[i].id=i;b[i].x=rk[b[i].x],b[i].y=rk[b[i].y];
if(b[i].x<b[i].y)swap(b[i].x,b[i].y);
}
vector<int>res,c;
for(int i=1;i<=q;i++)res.push_back(i);
for(int i=1;i<=tot;i++)c.push_back(i);
solve(res,c,1,len);
for(int i=1;i<=m;i++)cout<<tmp[ans[i]]<<'\n';
return 0;
}