【bzoj2738】矩阵乘法 整体二分+二维树状数组

题目描述

给你一个N*N的矩阵,不用算矩阵乘法,但是每次询问一个子矩形的第K小数。

输入

第一行两个数N,Q,表示矩阵大小和询问组数;
接下来N行N列一共N*N个数,表示这个矩阵;
再接下来Q行每行5个数描述一个询问:x1,y1,x2,y2,k表示找到以(x1,y1)为左上角、以(x2,y2)为右下角的子矩形中的第K小数。

输出

对于每组询问输出第K小的数。

样例输入

2 2
2 1
3 4
1 2 1 2 1
1 1 2 2 3

样例输出

1
3


题解

整体二分+二维树状数组

题目答案具有很明显的二分性质,所以可以考虑把询问离线,整体二分。(其实是看了题解才知道是整体二分的,一开始想了权值线段树套线段树套线段树)

整体二分是什么?和标准的二分几乎是一样的,只不过在判定的时候是把所有答案属于这个区间的询问都一一判定,根据是否满足条件,把询问分成答案属于[l,mid]和[mid+1,r]中的,然后对这两个子区间分别处理。最后当l=r时说明所有属于这个区间内的询问的答案都是l。

那么对于本题,我们可以按照矩阵内的数从大到小排序。然后把所有询问读进来,整体二分。

用solve(b,e,l,r)表示要解决询问区间在[b,e]内,点区间在[l,r]内的答案,那么当l=r时,ans[b~e]=v[l]。

当l≠r时,令mid=(l+r)/2。此时需要判定的就是是否有大于等于k个数在[l,mid]之间。

所以把[l,mid]内对应位置的权值+1,要求的就是矩形内的权值之和,可以使用二维树状数组来解决。

如果矩形内的数的个数大于等于k,则答案在[l,mid]内,把它放到对应的区间中;否则答案在[mid+1,r]中,并且把k值减去数的个数(因为这些数一定是前k小的)。

最后把二维树状数组清空(不能使用memset,需要动态清零),然后递归处理左右区间即可。

时间复杂度$O(q\log^3n)$

#include <cstdio>
#include <cstring>
#include <algorithm>
using namespace std;
struct POINT
{
	int x , y , v;
	bool operator<(const POINT a)const {return v < a.v;}
}a[250010];
struct QUERY
{
	int x1 , y1 , x2 , y2 , k , id;
}q[60010] , t[60010];
int f[510][510] , n , ans[60010];
void add(int x , int y , int a)
{
	int i , j;
	for(i = x ; i <= n ; i += i & -i)
		for(j = y ; j <= n ; j += j & -j)
			f[i][j] += a;
}
int query(int x , int y)
{
	int i , j , ans = 0;
	for(i = x ; i ; i -= i & -i)
		for(j = y ; j ; j -= j & -j)
			ans += f[i][j];
	return ans;
}
void solve(int b , int e , int l , int r)
{
	if(b > e) return;
	int mid = (l + r) >> 1 , i , tl = b , tr = e , c;
	if(l == r)
	{
		for(i = b ; i <= e ; i ++ ) ans[q[i].id] = a[l].v;
		return;
	}
	for(i = l ; i <= mid ; i ++ ) add(a[i].x , a[i].y , 1);
	for(i = b ; i <= e ; i ++ )
	{
		c = query(q[i].x2 , q[i].y2) - query(q[i].x1 - 1 , q[i].y2) - query(q[i].x2 , q[i].y1 - 1) + query(q[i].x1 - 1 , q[i].y1 - 1);
		if(c >= q[i].k) t[tl ++ ] = q[i];
		else q[i].k -= c , t[tr -- ] = q[i];
	}
	for(i = b ; i <= e ; i ++ ) q[i] = t[i];
	for(i = l ; i <= mid ; i ++ ) add(a[i].x , a[i].y , -1);
	solve(b , tr , l , mid) , solve(tl , e , mid + 1 , r);
}
int main()
{
	int m , i , j;
	scanf("%d%d" , &n , &m);
	for(i = 1 ; i <= n ; i ++ )
		for(j = 1 ; j <= n ; j ++ )
			scanf("%d" , &a[(i - 1) * n + j].v) , a[(i - 1) * n + j].x = i , a[(i - 1) * n + j].y = j;
	sort(a + 1 , a + n * n + 1);
	for(i = 1 ; i <= m ; i ++ ) scanf("%d%d%d%d%d" , &q[i].x1 , &q[i].y1 , &q[i].x2 , &q[i].y2 , &q[i].k) , q[i].id = i;
	solve(1 , m , 1 , n * n);
	for(i = 1 ; i <= m ; i ++ ) printf("%d\n" , ans[i]);
	return 0;
}

 

 

posted @ 2017-08-23 09:28  GXZlegend  阅读(372)  评论(0编辑  收藏  举报