HDU 4331 Image Recognition

这个题好纠结。。。比赛没思路,回头又看得出题报告。最后写出来的数状数组跑了1.5s。。。数状数组的思想还算好理解,就是对角线的控制上。。纠结了。

一个直观的想法是首先用N^2的时间预处理出每一个是1的点向上下左右四个方向能够延伸的1的最大长度,记为四个数组l, r, u, d。然后我们观察到正方形有一个特征是同一对角线上的两个顶点在原方阵的同一条对角线上。于是我们可以想到枚举原来方阵的每条对角线,然后我们对于每条对角线枚举对角线上所有是1的点i,那么我们可以发现可能和i构成正方形的点应该在该对角线的 [i, i + min(r[i], d[i]) – 1] 闭区间内, 而在这个区间内的点 j 只要满足 j – i + 1 <= min(l[j], u[j]) 也就是满足
j – min(l[j], u[j]) + 1 <= i
,这样的 (i, j) 就能构成一个正方形。

也就是说对于每条对角线,我们可以构造一个数组 a, 使得
a[i] = i – min(l[i], u[i]) + 1

然后对这个数组有若干次查询,每次查询的是区间 [i, i + min(r[i], d[i]) – 1]内有多少个数满足 a[j] <= i,所有这些问题答案的和就是该问题的结果。对于这个问题,我们可以通过离线算法,先保存所有查询的区间端点,并对所有端点排序。然后使用扫描线算法,如果扫描到的是第i次查询的左端点,就让当前结果减去当前扫描过的数中 <= i的个数,如果扫描到的是第i次查询的有短点,则让当前结果加上当前扫描过的数中 <= i的个数,最后所有结果相加即可。维护当前数出现的个数可以使用树状数组。这样对于每条对角线求结果的复杂度为O(nlogn),算法总的复杂度为O(n^2logn)

 

View Code
#include <iostream>
#include <cstdio>
#include <cmath>
#include <vector>
#include <cstring>
#include <algorithm>
#include <string>
#include <set>
#include <ctime>
#include <queue>
#include <map>
#include <sstream>

#define CL(arr, val)    memset(arr, val, sizeof(arr))
#define REP(i, n)       for((i) = 0; (i) < (n); ++(i))
#define FOR(i, l, h)    for((i) = (l); (i) <= (h); ++(i))
#define FORD(i, h, l)   for((i) = (h); (i) >= (l); --(i))
#define L(x)    (x) << 1
#define R(x)    (x) << 1 | 1
#define MID(l, r)   (l + r) >> 1
#define Min(x, y)   x < y ? x : y
#define Max(x, y)   x < y ? y : x
#define E(x)    (1 << (x))
#define iabs(x)  ((x) > 0 ? (x) : -(x))

typedef long long LL;
const double eps = 1e-8;
const int inf = ~0u>>2;

using namespace std;

const int N = 1024;

int mat[N][N];
int u[N][N], l[N][N], r[N][N], d[N][N];
int c[N], a[N];
int n;

struct point {
    int x, id;
    bool isleft;

    bool operator < (const point tmp) const {
        if(this->x == tmp.x)  return this->isleft;
        return x < tmp.x;
    }
} p[N<<1];

int lowbit(int i) {
    return i&(-i);
}

void add(int p) {
    while(p <= n) {
        c[p]++;
        p += lowbit(p);
    }
}

int sum(int x) {
    int res = 0;
    while(x > 0){
        res += c[x];
        x -= lowbit(x);
    }
    return res;
}

void init() {
    CL(u, 0); CL(d, 0);
    CL(l, 0); CL(r, 0);
    int i, j;

    for(i = 1; i <= n; ++i) {
        for(j = 1; j <= n; ++j) {
            if(mat[i][j] == 0)  u[i][j] = 0, l[i][j] = 0;
            else {
                u[i][j] = u[i-1][j] + 1;
                l[i][j] = l[i][j-1] + 1;
            }
        }
    }

    for(i = n; i >= 1; --i) {
        for(j = n; j >= 1; --j) {
            if(mat[i][j] == 0)  r[i][j] = 0, d[i][j] = 0;
            else {
                d[i][j] = d[i+1][j] + 1;
                r[i][j] = r[i][j+1] + 1;
            }
        }
    }
}

int Count(int m) {
    sort(p, p + m);
    CL(c, 0);
    int i, ans = 0;

    for(i = 0; i < m; ++i) {
        //printf("%d %d %d\n", p[i].isleft, p[i].x, p[i].id);
        if(p[i].isleft) {
            ans -= sum(p[i].id);
            add(a[p[i].id]);
        } else ans += sum(p[i].id);
    }
    return ans;
}

int solve() {
    int i, j, x, y;
    int res = 0, cnt = 0;

    for(i = 1; i <= n; ++i) {
        cnt = 0;
        for(j = 1; j <= i; ++j) {
            x = n - i + j; y = j;
            if(mat[x][y]) {
                a[y] = y - min(l[x][y], u[x][y]) + 1;
                p[cnt].isleft = true; p[cnt].x = y; p[cnt].id = y; cnt++;
                p[cnt].isleft = false; p[cnt].x = y + min(r[x][y], d[x][y]) - 1; p[cnt].id = y; cnt++;
            }
        }
        res += Count(cnt);
    }
    for(i = 2; i <= n; ++i) {
        cnt = 0;
        for(j = 1; j <= n - i + 1; ++j) {
            x = j, y = i + j - 1;
            if(mat[x][y]) {
                a[y] = y - min(l[x][y], u[x][y]) + 1;
                p[cnt].isleft = true; p[cnt].x = y; p[cnt].id = y; cnt++;
                p[cnt].isleft = false; p[cnt].x = y + min(r[x][y], d[x][y]) - 1; p[cnt].id = y; cnt++;
            }
        }
        res += Count(cnt);
    }
    return res;
}

int main() {
    //freopen("data.in", "r", stdin);

    int i, j, t, cas = 0;
    scanf("%d", &t);
    while(t--) {
        CL(mat, 0);
        scanf("%d", &n);

        for(i = 1; i <= n; ++i) {
            for(j = 1; j <= n; ++j) {
                scanf("%d", &mat[i][j]);
            }
        }

        init();
        int ans = solve();
        printf("Case %d: %d\n", ++cas, ans);
    }
    return 0;
}
posted @ 2012-08-05 21:39  AC_Von  阅读(248)  评论(0编辑  收藏  举报