[BZOJ2738]矩阵乘法 整体二分+树状数组

题意:

给你一个N*N的矩阵,不用算矩阵乘法,但是每次询问一个子矩形的第K小数。

输入:

第一行两个数N,Q,表示矩阵大小和询问组数;
接下来N行N列一共N*N个数,表示这个矩阵;
再接下来Q行每行5个数描述一个询问:x1,y1,x2,y2,k表示找到以(x1,y1)为左上角、以(x2,y2)为右下角的子矩形中的第K小数。

输出:

对于每组询问输出第K小的数。

Input:

2 2
2 1
3 4
1 2 1 2 1
1 1 2 2 3

Output:

1
3

思路:

如果求一个区间第\(k\)小,可以考虑二分,将\(<=mid\)的值赋成1,用树状数组维护求有多少个数小于等于\(mid\),在和\(k\)进行比较。

如果求一个矩阵中第\(k\)小,思路同上,这时候可以考虑用二维树状数组进行维护。

那么求多个矩阵中的,每都一次二分复杂度会炸,所以可以考虑用整体二分√

将当前\(k<=sum\)的放到左区间,\(k>=sum\)的放到右区间继续二分即可。

#include<bits/stdc++.h>
#define M 505
#define N 60005
#define P 330000
using namespace std;
int mx=0,mi=2e9,n,q,a[M][M],ans[N],tot;
int lowbit(int x) {
	return x&-x;
}
struct node {
	int lx,ly,rx,ry,k,id;//对于 加入 一个点来说 lx ly表示坐标 k表示大小  id=0表示是加点
} Q[P],B[P];
struct Tree {//二维树状数组用来记录矩阵中的和
	int cnt[M][M];
	void add(int x,int y,int v) {
		while(x<=n) {
			int j=y;
			while(j<=n)cnt[x][j]+=v,j+=lowbit(j);
			x+=lowbit(x);
		}
	}
	int sum(int x,int y) {
		int res=0;
		while(x) {
			int j=y;
			while(j)res+=cnt[x][j],j-=lowbit(j);
			x-=lowbit(x);
		}
		return res;
	}
} T;
int get(int lx,int ly,int rx,int ry) {//求矩阵中的和,求法和普通的数组相似
	return T.sum(rx,ry)-T.sum(rx,ly-1)-T.sum(lx-1,ry)+T.sum(lx-1,ly-1);
}
void erfen(int l,int r,int L,int R) {
	if(l>r||L>R)return;
	int l1=L,r1=R,mid=(l+r)>>1;
	for(int i=L; i<=R; i++) {
		if(Q[i].id!=0)continue;
		if(Q[i].k<=mid)B[l1++]=Q[i],T.add(Q[i].lx,Q[i].ly,1);//分到左区间
		else B[r1--]=Q[i];
	}
	for(int i=L; i<=R; i++) {
		if(Q[i].id==0)continue;
		int sum=get(Q[i].lx,Q[i].ly,Q[i].rx,Q[i].ry);
		if(sum>=Q[i].k)ans[Q[i].id]=mid,B[l1++]=Q[i];//更新答案
		else Q[i].k-=sum,B[r1--]=Q[i];//将右区间的减去这部分贡献
	}
	for(int i=L; i<l1; i++)if(B[i].id==0)T.add(B[i].lx,B[i].ly,-1);//回撤
	for(int i=L; i<=R; i++)Q[i]=B[i];
	erfen(l,mid-1,L,l1-1);
	erfen(mid+1,r,r1+1,R);
}
int main() {
	scanf("%d%d",&n,&q);
	tot=q;
	for(int i=1; i<=n; i++)for(int x,j=1; j<=n; j++)scanf("%d",&x),Q[++tot]=(node)<%i,j,0,0,x,0%>,mi=min(mi,x),mx=max(mx,x);
	for(int i=1; i<=q; i++)scanf("%d%d%d%d%d",&Q[i].lx,&Q[i].ly,&Q[i].rx,&Q[i].ry,&Q[i].k),Q[i].id=i;
	erfen(mi,mx,1,tot);
	for(int i=1; i<=q; i++)printf("%d\n",ans[i]);
	return 0;
}
posted @ 2019-07-08 10:12  季芊月  阅读(123)  评论(0编辑  收藏  举报