51nod 1564 区间的价值 | 分治 尺取法

51nod 1564 区间的价值

题面

一个区间的价值是区间最大值×区间最小值。给出一个序列\(a\), 求出其中所有长度为k的子区间的最大价值。对于\(k = 1, 2, ..., n\)输出答案。
保证序列随机生成

题解

我的做法是\(O(n \log n)\)的!

对于一个区间[l, r],取其中的最大值,最大值的下标设为mid。对于[l, mid - 1]和[mid + 1, r]两个子区间内的点对,都可以递归处理,所以我们只需关注横跨mid的点对(左端点在[l, mid], 右端点在[mid, r])。

采用two pointers(尺取法?)来更新答案。设置两个指针pl, pr,分别在[l, mid]和[mid, r]中,表示当前点对的左右端点。初始pl, pr都是mid。因为我们的目标是区间价值最大,那么已知区间最大值和区间长度时,最小值越大越好,于是移动指针的时候选择pl - 1和pr + 1中值较小的那个数加入当前区间,并更新对应长度的答案。

因为数据随机,所以期望复杂度是\(O(n \log n)\)

核心代码:

void solve(int l, int r){
    int mid = l;
    for(int i = l; i <= r; i++)
	if(a[mid] < a[i]) mid = i;
    for(int pl = mid, pr = mid + 1, mi = a[mid]; pl >= l || pr <= r;){
	if(pl >= l && (pr > r || a[pl] > a[pr])){
	    mi = min(mi, a[pl--]);
	    ans[pr - pl - 1] = max(ans[pr - pl - 1], (ll)a[mid] * mi);
	}
	else{
	    mi = min(mi, a[pr++]);
	    ans[pr - pl - 1] = max(ans[pr - pl - 1], (ll)a[mid] * mi);
	}
    }
    if(l < mid) solve(l, mid - 1);
    if(mid < r) solve(mid + 1, r);
}

完整代码:

#include <cstdio>
#include <cstring>
#include <algorithm>
#include <cmath>
#define space putchar(' ')
#define enter putchar('\n')
#define INF 0x3f3f3f3f
using namespace std;
typedef long long ll;
template <class T>
void read(T &x){
    char c;
    bool op = 0;
    while(c = getchar(), c < '0' || c > '9')
	if(c == '-') op = 1;
    x = c - '0';
    while(c = getchar(), c >= '0' && c <= '9')
	x = x * 10 + c - '0';
    if(op) x = -x;
}
template <class T>
void write(T x){
    if(x < 0) putchar('-'), x = -x;
    if(x >= 10) write(x / 10);
    putchar('0' + x % 10);
}
const int N = 100005;
int n, a[N];
ll ans[N];
void solve(int l, int r){
    int mid = l;
    for(int i = l; i <= r; i++)
	if(a[mid] < a[i]) mid = i;
    for(int pl = mid, pr = mid + 1, mi = a[mid]; pl >= l || pr <= r;){
	if(pl >= l && (pr > r || a[pl] > a[pr])){
	    mi = min(mi, a[pl--]);
	    ans[pr - pl - 1] = max(ans[pr - pl - 1], (ll)a[mid] * mi);
	}
	else{
	    mi = min(mi, a[pr++]);
	    ans[pr - pl - 1] = max(ans[pr - pl - 1], (ll)a[mid] * mi);
	}
    }
    if(l < mid) solve(l, mid - 1);
    if(mid < r) solve(mid + 1, r);
}
int main(){
    read(n);
    for(int i = 1; i <= n; i++)
	read(a[i]);
    solve(1, n);
    for(int i = 1; i <= n; i++)
	write(ans[i]), enter;
    return 0;
}
posted @ 2017-11-22 15:43  胡小兔  阅读(487)  评论(0编辑  收藏  举报