chuninsane

poj 2482 Stars in Your Window (线段树:区间更新)

题目链接:http://poj.org/problem?id=2482

读完题干不免有些心酸(🐶🐶🐶)
题意:有n个星星(星星i的坐标为xi, yi,亮度为ci),给你一个W*H的矩形,让你求得矩形能覆盖的星星的亮度和最大为多少

思路:矩形大小是固定的,所以可以换个方向思考,把矩形看成一个点(坐标为矩形中心),每个星星的影响区间范围为W*H的矩形(星星原来的坐标为矩形中心),该区间的亮度增加c。该问题就变成了求哪个点的亮度最大。坐标范围太大,直接暴力枚举不可能,所以需要预先进行离散化。在纵向上建立一个线段树,然后枚举横坐标,每次在线段树上将需要更新的纵向坐标区间进行更新,然后进行查询。

#include <iostream>
#include <cstring>
#include <algorithm>
#define maxn 10010
#define inf 1000000000
#define LL(x) x<<1
#define RR(x) x<<1|1
using namespace std;

typedef long long LL;

//variable define

struct line
{
    int x, y1, y2, light;
};

struct tree
{
    int l, r;
    int add;
    LL ma;
};

tree node[maxn * 8];
LL W, H, starx[maxn], stary[maxn], xx[maxn*2], yy[maxn*2];
int light[maxn], n;
line li[maxn * 4];

//function define

void push_down(int x);

void push_up(int x);

void build_tree(int left, int right, int x);

LL query(int left, int right, int x);

void update_add(int left, int right, int x, LL val);

bool line_compare( line l1, line l2);

int main(void)
{
    while (scanf("%d %lld %lld", &n, &W, &H) != EOF)
    {
        for (int i = 0; i < n; ++i)
        {
            scanf("%lld %lld %d", &starx[i], &stary[i], &light[i]);
            starx[i] *= 2;
            stary[i] *= 2;
        }
        for (int i = 0; i < n; ++i)
        {
            xx[i*2] = starx[i] - W;
            xx[i*2 + 1] = starx[i] + W;
            yy[i*2] = stary[i] - H;
            yy[i*2 + 1] = stary[i] + H - 1;
        }
        sort( xx, xx + 2*n);
        sort( yy, yy + 2*n);
        for (int i = 0; i < n; ++i)
        {
            int x, y1, y2;
            x = (int)(lower_bound( xx, xx + 2*n, starx[i] - W) - xx);
            y1 = (int)(lower_bound( yy, yy + 2*n, stary[i] - H) - yy);
            y2 = (int)(lower_bound( yy, yy + 2*n, stary[i] + H - 1) - yy);
            li[i*2].x = x;
            li[i*2].y1 = y1;
            li[i*2].y2 = y2;
            li[i*2].light = light[i];
            
            x = (int)(lower_bound( xx, xx + 2*n, starx[i] + W) - xx);
            li[i*2 + 1].x = x;
            li[i*2 + 1].y1 = y1;
            li[i*2 + 1].y2 = y2;
            li[i*2 + 1].light = -1*light[i];
        }
        build_tree( 1, 2*n, 1);
        LL ans = 0;
        sort( li, li + 2*n, line_compare);
        for (int i = 0; i < 2*n; ++i)
        {
            update_add( li[i].y1 + 1, li[i].y2 + 1, 1, li[i].light);
            ans = max( ans, query( li[i].y1 + 1, li[i].y2 + 1, 1));
        }
        printf("%lld\n", ans);
    }
    return 0;
}

void build_tree(int left, int right, int x)
{
    node[x].l = left;
    node[x].r = right;
    node[x].add = node[x].ma = 0;
    
    if (left == right)
        return;
    
    int lx = LL(x);
    int rx = RR(x);
    int mid = left + (right - left)/2;
    build_tree(left, mid, lx);
    build_tree(mid + 1, right, rx);
    push_up(x);
}

void push_up(int x)
{
    if (node[x].l >= node[x].r)
        return;
    
    int lx = LL(x);
    int rx = RR(x);
    node[x].ma = max( node[lx].ma, node[rx].ma);
}

void push_down(int x)
{
    if (node[x].l >= node[x].r)
        return;
    int lx = LL(x);
    int rx = RR(x);
    
    if (node[x].add != 0)
    {
        node[lx].add += node[x].add;
        node[rx].add += node[x].add;
        node[lx].ma += node[x].add;
        node[rx].ma += node[x].add;
    }
}

void update_add(int left, int right, int x, LL val)
{
    if (node[x].l == left && node[x].r == right)
    {
        node[x].add += val;
        node[x].ma += val;
        return;
    }
    push_down( x);
    node[x].add = 0;
    int lx = LL(x);
    int rx = RR(x);
    int mid = node[x].l + (node[x].r - node[x].l)/2;
    
    if (right <= mid)
        update_add(left, right, lx, val);
    else if (left > mid)
        update_add(left, right, rx, val);
    else
    {
        update_add(left, mid, lx, val);
        update_add(mid + 1, right, rx, val);
    }
    push_up(x);
}

LL query(int left, int right, int x)
{
    if (node[x].l == left && node[x].r == right)
    {
        return node[x].ma;
    }
    
    push_down(x);
    node[x].add = 0;
    
    int mid = node[x].l + (node[x].r - node[x].l)/2;
    int lx = LL(x);
    int rx = RR(x);
    LL result;
    if (right <= mid)
        result = query(left, right, lx);
    else if (left > mid)
        result = query(left, right, rx);
    else
        result = max( query(left, mid, lx), query(mid + 1, right, rx));
    push_up(x);
    return result;
}

bool line_compare(line l1, line l2)
{
    if (l1.x != l2.x)
        return l1.x < l2.x;
    return l1.light < l2.light;
}

 

posted on 2015-11-02 17:56  chuninsane  阅读(209)  评论(0编辑  收藏  举报

导航