HDU 4331 Image Recognition

本题题目大意在一个01方阵中找出四条边全都是1的正方形的个数,对于正方形内部则没有要求。

一个直观的想法是首先用N^2的时间预处理出每一个是1的点向上下左右四个方向能够延伸的1的最大长度,记为四个数组l, r, u, d。然后我们观察到正方形有一个特征是同一对角线上的两个顶点在原方阵的同一条对角线上。于是我们可以想到枚举原来方阵的每条对角线,然后我们对于每条对角线枚举对角线上所有是1的点i,那么我们可以发现可能和i构成正方形的点应该在该对角线的 [i, i + min(r[i], d[i]) – 1] 闭区间内, 而在这个区间内的点 j 只要满足 j – i + 1 <= min(l[j], u[j]) 也就是满足j – min(l[j], u[j]) + 1 <= i,这样的 (i, j) 就能构成一个正方形。也就是说对于每条对角线,我们可以构造一个数组 a, 使得a[i] = i – min(l[i], u[i]) + 1


然后对这个数组有若干次查询,每次查询的是区间 [i, i + min(r[i], d[i]) – 1]内有多少个数满足 a[j] <= i,所有这些问题答案的和就是该问题的结果。对于这个问题,我们可以通过离线算法,先保存所有查询的区间端点,并对所有端点排序。然后使用扫描线算法,如果扫描到的是第i次查询的左端点,就让当前结果减去当前扫描过的数中 <= i的个数,如果扫描到的是第i次查询的有短点,则让当前结果加上当前扫描过的数中 <= i的个数,最后所有结果相加即可。


维护当前数出现的个数可以使用树状数组。这样对于每条对角线求结果的复杂度为O(nlogn),算法总的复杂度为O(n^2logn)。

View Code
View Code 
#include<iostream>
#include<cstdio>
#include<cstdlib>
#include<algorithm>
#include<cmath>
#include<queue>
#include<set>
#include<cstring>
#include<vector>
#include<string>
#define LL long long
using namespace std;
int map[1024][1024],l[1024][1024],u[1024][1024],d[1024][1024],r[1024][1024],a[2024],c[2024];
class Node{
public:
      bool left;
      int x,id;    
}p[2024];
bool cmp( Node a, Node b ){
    if( a.x == b.x ) return a.left;
    return a.x < b.x;    
}
int lowbit( int x ){
    return x&(-x);    
}
void Updata( int x, int n ){
    for( int i = x ; i <= n ; i += lowbit(i) )
          c[i] ++;    
}
int Query( int x ){
    int ans = 0;
    for( int i = x; i > 0 ; i -= lowbit(i) )
         ans += c[i];
    return ans;    
}
int res( int n, int m ){
    int ans = 0;
    memset( c , 0 , sizeof( c ) );
    sort( p , p + m , cmp );
    for( int i = 0 ; i < m ; i ++ )
         if( p[i].left ) {
                ans -= Query( p[i].id );
                Updata( a[p[i].x] ,n );
         }
         else ans +=Query( p[i].id );
//    printf( "__%d\n",ans );
    return ans;
}
int Solve( int n )
{
    int ans=0;
    for( int i =1 ; i <= n ; i ++ ){
        int m = 0;
        for( int j = 1 ; j <= i ; j ++ ){
                int x=n-i+j,y=j;         
                if( map[x][y] == 1 ){
                    a[y] = y - min( l[x][y],u[x][y] ) + 1;
                    p[m].left = true;p[m].id=y;p[m].x=y,p[m].id=y;
                    m++;
                    p[m].left=false;p[m].x=y+min( r[x][y],d[x][y] )-1;p[m].id=y;
                    m++;
                    }        
            }   
        ans += res( n ,m ); 
//        printf( "ans=%d\n",ans );   
      }    
      for( int i =2 ; i <= n ; i ++ ){
        int m = 0;
        for( int j = 1 ; j <= n - i + 1 ; j ++ ){
                  int x=j,y=i+j-1;             
                if( map[x][y] == 1 ){
                    a[y] = y - min( l[x][y],u[x][y] ) + 1;
                    p[m].left = true;p[m].id=y;p[m].x=y;
                    m++;
                    p[m].left=false;p[m].x=y+min( r[x][y],d[x][y] )-1;p[m].id=y;
                    m++;
                    }        
            }  
        ans += res( n ,m );   
//        printf( "ans=%d\n",ans ); 
      }    
      return ans;
}
int main(  ){
    int T,n;
    while( scanf( "%d",&T )==1 ){
        for( int cas = 1 ; cas <= T ;cas++ ){
            memset( u , 0 , sizeof( u ) );
            memset( d , 0 , sizeof( d ) );
            memset( l , 0 , sizeof( l ) );
            memset( r , 0 , sizeof( r ) );
             scanf( "%d",&n );
             LL cnt = 0,ans=0;
             for( int i = 1 ; i <= n ; i ++ )
                  for( int j = 1; j <= n ; j ++ ){
                       scanf( "%d",&map[i][j] ); 
                       if( map[i][j] == 0 ) u[i][j] = l[i][j] = 0;
                      else{
                        u[i][j] = u[i-1][j]  + 1;
                        l[i][j] = l[i][j-1] + 1;
                     }
                  }
            for( int i = n ; i >0  ; i -- ){
                 for( int j = n ;j > 0 ; j -- ){
                        if( map[i][j] == 0 ) d[i][j] = r[i][j] = 0;
                        else{
                        d[i][j] = d[i+1][j] + 1;
                        r[i][j] = r[i][j+1] + 1;
                       }
                     }
                }
            printf( "Case %d: %d\n",cas,Solve( n ) );
       }
   }
    //system( "pause" );
    return 0;
}

 

 



posted @ 2012-08-03 20:19  wutaoKeen  阅读(346)  评论(0)    收藏  举报