ACM学习历程——HDU2227 Find the nondecreasing subsequences(线段树 && dp)
Description
Input
Output
Sample Input
Sample Output
这个题目要求的是上升子序列的个数。
若设sum[i]表示以i为最后一个数的上升子序列的个数。
首先可以得到的是sum[i] = 1 + ∑(sum[j]) (a[j] <= a[i])。(ps:加1是因为子序列可以只包含一个数)
但是遍历sum[j]这个操作的时间需要O(n),所以要对这个操作进行优化。
很容易想到的是区间和,但是这个操作只需要求在i之前比a[i]小的那些点的和。
于是可以先把所有点的值初始化为0,然后从值最小的那个数开始求解,这样在求小的数的时候,大的数对应的值是0,这样的话大的数的贡献就是0,相当于没有加入计算。
举例说明:
对于序列5 1 3 2 4
先求1这个数,那么sum[1]就是[1, 2]区间内val[i]的和加1,即0+1。此时val[1]更新为1,sum[1]为1。
再求2这个数,那么sum[2]就是[1, 4]区间内val[i]的和加1,即1+1。此时val[2]更新为2,sum[2]为2。
再求3这个数,那么sum[3]就是[1, 3]区间内val[i]的和加1,即1+1。此时val[3]更新为2,sum[3]为2。
再求4这个数,那么sum[4]就是[1, 5]区间内val[i]的和加1,即1+2+2+1。此时val[4]更新为5,sum[4]为6。
最后求5这个数,那么sum[5]就是[1, 1]区间内val[i]的和加1,即0+1。此时val[5]更新为1,sum[5]为1。
所以答案就是sum[i]的和,就是12。
由于可以边求边加,所以sum数组可以省去。
代码:
#include <iostream>
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <algorithm>
#define LL long long
#define N 1000000007
using namespace std;
//线段树
//区间每点增值,求区间和
const int maxn = 100005;
struct node
{
int lt, rt;
int val;
}tree[4*maxn];
//向上更新
void PushUp(int id)
{
tree[id].val = (tree[id<<1].val + tree[id<<1|1].val)%N;
}
//建立线段树
void Build(int lt, int rt, int id)
{
tree[id].lt = lt;
tree[id].rt = rt;
tree[id].val = 0;//每段的初值,根据题目要求
if (lt == rt)
{
//tree[id].val = 1;
return;
}
int mid = (lt + rt) >> 1;
Build(lt, mid, id<<1);
Build(mid+1, rt, id<<1|1);
//PushUp(id);
}
//增加区间内每个点固定的值
void Add(int lt, int rt, int id, int pls)
{
if (lt <= tree[id].lt && rt >= tree[id].rt)
{
tree[id].val += pls * (tree[id].rt-tree[id].lt+1);
tree[id].val %= N;
return;
}
int mid = (tree[id].lt + tree[id].rt) >> 1;
if (lt <= mid)
Add(lt, rt, id<<1, pls);
if (rt > mid)
Add(lt, rt, id<<1|1, pls);
PushUp(id);
}
//查询某段区间内的he
LL Query(int lt, int rt, int id)
{
if (lt <= tree[id].lt && rt >= tree[id].rt)
return tree[id].val;
int mid = (tree[id].lt + tree[id].rt) >> 1;
LL ans = 0;
if (lt <= mid)
ans += Query(lt, rt, id<<1);
if (rt > mid)
ans += Query(lt, rt, id<<1|1);
return ans%N;
}
struct node1
{
LL val;
int id;
}p[100005];
bool cmp(node1 a, node1 b)
{
if (a.val != b.val)
return a.val < b.val;
else
return a.id < b.id;
}
int n, len;
int sum;
LL ans;
int main()
{
//freopen("test.in", "r", stdin);
while (scanf("%d", &n) != EOF)
{
for (int i = 0; i < n; ++i)
{
p[i].id = i+1;
scanf("%I64d", &p[i].val);
}
sort(p, p+n, cmp);
ans = 0;
Build(1, n, 1);
for (int i = 0; i < n; ++i)
{
sum = Query(1, p[i].id, 1)+1;
Add(p[i].id, p[i].id, 1, sum);
ans += sum;
ans %= N;
}
printf("%I64d\n", ans);
}
return 0;
}
posted on 2015-05-01 19:50 AndyQsmart 阅读(200) 评论(0) 收藏 举报
浙公网安备 33010602011771号