整体二分学习笔记

引入

整体二分可以一次性地解决多次询问的二分,一般配合数据结构使用。

可以使用整体二分解决的题目需要满足以下性质:

  1. 询问的答案具有可二分性

  2. 修改对判定答案的贡献互相独立,修改之间互不影响效果

  3. 修改如果对判定答案有贡献,则贡献为一确定的与判定标准无关的值

  4. 贡献满足交换律,结合律,具有可加性

  5. 题目允许使用离线算法

——许昊然《浅谈数据结构题几个非经典解法》

先从普通的二分说起。

例1:查询全局第 \(k\)

显然,这可以通过简单的排序在 \(O(n\log n)\) 时间内解决。如果使用二分,则可以开桶维护,然后二分答案,每次查询有多少个数比当前答案要大来进行 \(check\)

例2:多次查询全局第 \(k\)

这里排序的复杂度是 \(O(n\log n+T)\) 的。如果每次询问都进行一次二分答案,总复杂度就是 \(O(n\log n+T\log ans)\) 的。

例3:多次查询区间第 \(k\) 大(主席树模板)

因为要维护区间信息,所以这题排序无法解决。同样地,直接二分会使 \(check\) 的复杂度爆炸。

于是我们可以使用整体二分来优化直接二分。

简介

整体二分就是一个把所有询问组成的集合按照答案大小划分的过程。

假设所有询问的答案为 \(Q_1,Q_2,Q_3,...,Q_n\),答案范围为 \([l,r]\)

二分答案,当前答案 \(mid=\lfloor\frac{l+r}{2}\rfloor\)

对于所有 \(Q_i,i\in[1,n]\) 进行 \(check\),将 \(\le mid\) 的答案归为一组,\(>mid\) 的答案归为另一组。

然后对两组 \([l,mid]\)\([mid+1,r]\) 递归地进行处理即可。

递归树高 \(\log(n)\),每层对于所有询问进行一次 \(check\)。设 \(check\) 所有询问的复杂度为 \(O(T)\),则总复杂度 \(O(T\log(n))\)。在例3中,\(O(T)=O(n\log(n))\),于是总复杂度为 \(O(n\log^2(n))\)

例题

洛谷P1527 [国家集训队] 矩阵乘法

考虑对每次询问二分答案。对于小于当前答案的数将其标记为 \(1\),大于等于的标记为 \(0\),则可以用二维树状数组统计矩阵内 \(1\) 的个数,即矩阵和。将所有询问整体二分即可。时间复杂度 \(O((N^2+Q)*\log ^3N)\)

#include<bits/stdc++.h>
using namespace std;
constexpr int N=5e2+2,M=6e4+4;
int n,m,tmp[N*N],len,ans[M];
struct Mat{int x,y,val;}p[N*N];
constexpr inline bool operator<(Mat x,Mat y){return x.val<y.val;}
struct Query{int x1,x2,y1,y2,k;}q[M];
struct BIT{
    #define lowbit(x) (x&-x)
    int sum[N][N];
    void update(int x,int y,int val){
        for(int i=x;i<=n;i+=lowbit(i)){
            for(int j=y;j<=n;j+=lowbit(j)){
                sum[i][j]+=val;
            }
        }
    }
    int pre(int x,int y){
        int ans=0;
        for(int i=x;i;i^=lowbit(i)){
            for(int j=y;j;j^=lowbit(j)){
                ans+=sum[i][j];
            }
        }
        return ans;
    }
    int query(int x1,int y1,int x2,int y2){return pre(x2,y2)-pre(x2,y1-1)-pre(x1-1,y2)+pre(x1-1,y1-1);}
}tr;
void solve(int l,int r,vector<int>nw){
    if(nw.empty())return;
    if(l==r){
        for(int x:nw)ans[x]=l;
        return;
    }
    const int mid=(l+r)>>1;
    for(int i=l;i<=mid;i++)tr.update(p[i].x,p[i].y,1);
    vector<int>ll,rr;
    for(int x:nw){
        int t=tr.query(q[x].x1,q[x].y1,q[x].x2,q[x].y2);
        if(t>=q[x].k)ll.push_back(x);
        else q[x].k-=t,rr.push_back(x);
    }
    for(int i=l;i<=mid;i++)tr.update(p[i].x,p[i].y,-1);
    solve(l,mid,ll);solve(mid+1,r,rr);
}
signed main(){
    ios::sync_with_stdio(false);
    cin.tie(0),cout.tie(0);
    cin>>n>>m;
    for(int i=1;i<=n*n;i++)cin>>p[i].val,p[i].x=(i-1)/n+1,p[i].y=(i-1)%n+1;
    for(int i=1;i<=n*n;i++)tmp[i]=p[i].val;
    sort(tmp+1,tmp+1+n*n);
    len=unique(tmp+1,tmp+1+n*n)-(tmp+1);
    for(int i=1;i<=n*n;i++)p[i].val=lower_bound(tmp+1,tmp+1+len,p[i].val)-tmp;
    sort(p+1,p+1+n*n);
    for(int i=1;i<=m;i++)cin>>q[i].x1>>q[i].y1>>q[i].x2>>q[i].y2>>q[i].k;
    vector<int>res;
    for(int i=1;i<=m;i++)res.push_back(i);
    solve(1,n*n,res);
    for(int i=1;i<=m;i++)cout<<tmp[p[ans[i]].val]<<'\n';
    return 0;
}

P7424 [THUPC 2017] 天天爱射击

板子。等价于给出一些区间,问这些区间中第 k 个子弹的编号。还是二分然后将权值赋为 \(0\)\(1\) 即可。

#include <bits/stdc++.h>
using namespace std;
constexpr int N=2e5+5;
int n,m,ans[N],f[N],b[N];
struct Query{int l,r,k;}a[N];
struct BIT{
    int sum[N];
    void update(int p,int val){for(;p<=n;p+=p&-p)sum[p]+=val;}
    int query(int p){int ans=0;for(;p;p^=p&-p)ans+=sum[p];return ans;}
}tr;
void solve(vector<int>res,int pl,int pr){
    if(res.empty())return;
    if(pl==pr){
        for(int x:res)ans[x]=pl;
        return;
    }
    const int mid=(pl+pr)>>1;
    for(int i=pl;i<=mid;i++)tr.update(b[i],1);
    vector<int>l,r;
    for(int x:res){
        int ans=tr.query(a[x].r)-tr.query(a[x].l-1);
        if(a[x].k<=ans)l.push_back(x);
        else a[x].k-=ans,r.push_back(x);
    }
    for(int i=pl;i<=mid;i++)tr.update(b[i],-1);
    solve(l,pl,mid);solve(r,mid+1,pr);
}
signed main(){
    ios::sync_with_stdio(false);
    cin.tie(0),cout.tie(0);
    cin>>n>>m;
    for(int i=1;i<=n;i++)cin>>a[i].l>>a[i].r>>a[i].k;
    for(int i=1;i<=m;i++)cin>>b[i];
    vector<int>res(n);
    for(int i=0;i<n;i++)res[i]=i+1;
    solve(res,1,m+1);
    for(int i=1;i<=n;i++)f[ans[i]]++;
    for(int i=1;i<=m;i++)cout<<f[i]<<'\n';
    return 0;
}

P3242 [HNOI2015] 接水果

考虑一条路径 \(x,y\) 被另一条路径 \(u,v\) 覆盖可以表示为以下情况:

  1. \(\text{lca}(x,y)=x\),此时 \(dfn_y\le dfn_u\le dfn_y+size_y-1\)\(dfn_v\le dfn_x\)\(dfn_v>dfn_x+szie_x-1\)
  2. \(\text{lca}(x,y)\not=x\),此时 \(dfn_y\le dfn_u\le dfn_y+size_y-1\)\(dfn_x\le dfn_v\le dfn_x+size_x-1\)

这样就把路径覆盖转换成了 \(dfn\) 序之间的偏序关系,相当于二维平面内矩形第 \(k\) 小。于是扫描线加整体二分就做完了。

```plaintext
#include <bits/stdc++.h>
using namespace std;
constexpr int N=4e4+4;
int n,m,q,dfn[N],rk[N],size[N],top[N],s[N],dep[N],fa[N],cnt,tot;
vector<int>edge[N];
struct Node{int x,l,r,val;}a[N<<2];
struct Query{int x,y,k,id;}b[N];
int ans[N];
void dfs1(int p,int f){
    dep[p]=dep[fa[p]=f]+1;
    size[p]=1;
    for(int x:edge[p]){
        if(x==f)continue;
        dfs1(x,p);
        size[p]+=size[x];
        if(size[x]>size[s[p]])s[p]=x;
    }
}
void dfs2(int p,int t){
    top[p]=t;
    dfn[rk[p]=++cnt]=p;
    if(s[p])dfs2(s[p],t);
    for(int x:edge[p]){
        if(x==fa[p]||x==s[p])continue;
        dfs2(x,x);
    }
}
int lca(int x,int y,bool t=0){
    int _y=y;
    while(top[x]!=top[y]){
        if(dep[top[x]]<dep[top[y]])swap(x,y);
        if(fa[top[x]]==_y&&t)return top[x];
        x=fa[top[x]];
    }
    if(dep[x]>dep[y])swap(x,y);
    if(t)return s[x];
    return x;
}
struct BIT{
    int sum[N];
    void update(int p,int val){for(;p<=n;p+=p&-p)sum[p]+=val;}
    int query(int p){int ans=0;for(;p;p^=p&-p)ans+=sum[p];return ans;}
}tr;
void solve(vector<int>res,vector<int>c,int pl,int pr){
    if(res.empty())return;
    if(pl==pr){
        for(int x:res)ans[b[x].id]=pl;
        return;
    }
    const int mid=(pl+pr)>>1;
    vector<int>l,r,lc,rc;
    for(int x:c){
        if(abs(a[x].val)<=mid)lc.push_back(x);
        else rc.push_back(x);
    }
    sort(lc.begin(),lc.end(),[](int x,int y){return a[x].x<a[y].x;});
    sort(res.begin(),res.end(),[](int x,int y){return b[x].x<b[y].x;});
    int p=0;
    for(int x:res){
        while(p<lc.size()&&a[lc[p]].x<=b[x].x){
            // cout<<a[lc[p]].l<<' '<<a[lc[p]].r<<' '<<(a[lc[p]].val<0?-1:1)<<'\n';
            tr.update(a[lc[p]].l,1*(a[lc[p]].val<0?-1:1));
            tr.update(a[lc[p]].r+1,-1*(a[lc[p]].val<0?-1:1));
            ++p;
        }
        int ans=tr.query(b[x].y);
        // cout<<x<<' '<<b[x].x<<' '<<b[x].y<<' '<<pl<<' '<<pr<<' '<<ans<<'\n';
        if(b[x].k<=ans)l.push_back(x);
        else b[x].k-=ans,r.push_back(x);
    }
    for(int i=0;i<p;i++){
        tr.update(a[lc[i]].l,-1*(a[lc[i]].val<0?-1:1));
        tr.update(a[lc[i]].r+1,1*(a[lc[i]].val<0?-1:1));
    }
    solve(l,lc,pl,mid);solve(r,rc,mid+1,pr);
}
int tmp[N];
signed main(){
    ios::sync_with_stdio(false);
    cin.tie(0),cout.tie(0);
    cin>>n>>m>>q;
    for(int i=1;i<n;i++){
        int x,y;cin>>x>>y;
        edge[x].push_back(y);
        edge[y].push_back(x);
    }
    dfs1(5,0);dfs2(5,5);
    for(int i=1;i<=m;i++){
        int x,y;cin>>x>>y>>tmp[i];
        if(rk[x]<rk[y])swap(x,y);
        if(lca(x,y)==y){
            int z=lca(x,y,1);
            if(rk[z]>1){
                a[++tot]={rk[x],1,rk[z]-1,tmp[i]};
                a[++tot]={rk[x]+size[x],1,rk[z]-1,-tmp[i]};
            }
            if(rk[z]+size[z]-1<n){
                a[++tot]={rk[z]+size[z],rk[x],rk[x]+size[x]-1,tmp[i]};
                a[++tot]={n+1,rk[x],rk[x]+size[x]-1,-tmp[i]};
            }
        }
        else{
            a[++tot]={rk[x],rk[y],rk[y]+size[y]-1,tmp[i]};
            a[++tot]={rk[x]+size[x],rk[y],rk[y]+size[y]-1,-tmp[i]};
        }
    }
    sort(tmp+1,tmp+m+1);
    int len=unique(tmp+1,tmp+m+1)-tmp-1;
    for(int i=1;i<=tot;i++)a[i].val=(lower_bound(tmp+1,tmp+len+1,abs(a[i].val))-tmp)*(a[i].val<0?-1:1);
    // for(int i=1;i<=tot;i++)cout<<a[i].x<<' '<<a[i].l<<' '<<a[i].r<<' '<<a[i].val<<'\n';
    for(int i=1;i<=q;i++){
        cin>>b[i].x>>b[i].y>>b[i].k;
        b[i].id=i;b[i].x=rk[b[i].x],b[i].y=rk[b[i].y];
        if(b[i].x<b[i].y)swap(b[i].x,b[i].y);
    }
    vector<int>res,c;
    for(int i=1;i<=q;i++)res.push_back(i);
    for(int i=1;i<=tot;i++)c.push_back(i);
    solve(res,c,1,len);
    for(int i=1;i<=m;i++)cout<<tmp[ans[i]]<<'\n';
    return 0;
}
posted @ 2025-04-28 21:37  Tachanka233  阅读(14)  评论(0)    收藏  举报