Just Wait...

Dp优化1

数据结构优化 DP

前缀和

这个就不用多说了吧。

树状数组

类似树状数组优化 lis, 我们可以用其优化 dp。
当我们需要依靠二维偏序中的一维来确定是否可以转移时,我们可以用树状数组优化掉这一位的判断,即将这一维的 \(\mathcal O(n)\),优化成 \(\mathcal O(\log{n})\)

例题:ABC353_g

朴素DP:

我们可以用 \(f_i\) 表示当我们在第 \(i\) 个集市时有钱的最大数。那么其转移方程是非常显而易见:

\[f_i = max{f_j - C * |t_i - t_j|} + p_i \]

对于绝对值我们不好考虑,所以我们把绝对值拆成 \(t_i < t_j\)\(t_i > t_j\) 两种情况。 对于第一种情况,转移方程式如下:

\[f_i = max{f_j - C * t_j} + C * t_i + p_i \]

对于这个式子中有二元组:\(<t_i, f_i - C * t_i>\),我们将一维作为下标,维护二维最大值,使用树状数组,优化这个DP。

#include<bits/stdc++.h>

using namespace std;

#define int long long

const int N = 2e5 + 5;

int n, c, m, ans;
int tr1[N], tr2[N], f[N];

#define lowbit(x) x&(-x)

void add1(int x, int v){
	for(; x <= n; x += lowbit(x)) tr1[x] = max(tr1[x], v);
} 

void add2(int x, int v){
	for(; x; x -= lowbit(x)) tr2[x] = max(tr2[x], v);
}

int qry1(int x){
	int res = -2e18;
	for(; x; x -= lowbit(x)) res = max(res, tr1[x]);
	return res;
}

int qry2(int x){
	int res = -2e18;
	for(; x <= n; x += lowbit(x)) res = max(res, tr2[x]);
	return res;
} 

signed main(){
	ios::sync_with_stdio(0); cin.tie(nullptr), cout.tie(nullptr);
	cin>>n>>c>>m;
	memset(f, 0xcf, sizeof f);
	memset(tr1, 0xcf, sizeof tr1);
	memset(tr2, 0xcf, sizeof tr2);
	f[0] = 0;
	add1(1, c);
	add2(1, -c);
	for(int i = 1; i <= m; i ++){
		int t, p; cin>>t>>p;
		f[i] = max(qry1(t) - c * t + p, qry2(t) + c * t + p);
		add1(t, f[i] + c * t);
		add2(t, f[i] - c * t);
		ans = max(ans, f[i]);
	} 
	cout<<ans<<endl;
	return 0;
} 

线段树

使用线段树,优化决策点在某一段区间或转移到某段区间的时间复杂度。

例题1 CF 960F

先看朴素 DP。

我们用 \(f_i\) 表示枚举到第 \(i\) 条边时的 LIS,那么转移方程就是:\(f_i = f_j + 1 \hspace{0.3cm} (to_j = from_i \hspace{0.3cm} and \hspace{0.3cm} id_j < id_i)\),时间复杂度为 \(\mathcal O(n^2)\),可以拿 0pts

考虑如何优化。对于第一位枚举,我们没法压掉。但对于后一维的枚举,我们可以联想到 树状数组求LIS,的方法,我们需要上一条边的末节点一定是我们当前这条边的始节点,因此,我们可以将每一条边的 DP 值存在这条边的末节点里,使用线段树维护。最终时间复杂度:\(\mathcal O(n \log{n})\)

#include<bits/stdc++.h>

using namespace std;

const int N = 1e5 + 5;

int n, m;
int ls[N * 20 + 5], rs[N * 20 + 5], mx[N * 20 + 5], rt[N], idx;

int qry(int ql, int qr, int &u, int l, int r){
	if(!u || l > qr || r < ql) return 0;
	if(ql <= l && qr >= r) return mx[u];
	int mid = (l + r) >> 1;
	return max(qry(ql, qr, ls[u], l, mid), qry(ql, qr, rs[u], mid + 1, r));
} 

void mdf(int x, int v, int &u, int l, int r){
	if(r < x || l > x) return ;
	if(!u) u = ++idx;
	if(l == r) return mx[u] = max(mx[u], v), void();
	mdf(x, v, ls[u], l, (l + r) >> 1), mdf(x, v, rs[u], ((l + r) >> 1) + 1, r);
	mx[u] = max(mx[ls[u]], mx[rs[u]]); 
}

int main(){
	scanf("%d%d", &n, &m); 
	while(m --){
		int u, v, w; scanf("%d%d%d", &u, &v, &w);
		int t = qry(0, w - 1, rt[u], 0, 100000) + 1;
		mdf(w, t, rt[v], 0, 100000);
	}
	int ans = 0;
	for(int i = 1; i <= n; i ++) ans = max(ans, mx[rt[i]]);
	printf("%d\n", ans);
	return 0;
} 

例题2 CF 834D

简明题意

给你一个长度为 \(n\) 的序列 \(a\), 要将它分成 \(k\) 个连续子段。定义一个子段的价值为这个子段里的不同元素的个数。求所有子段价值之和的最大值。

做法

先看朴素 DP。

我们用 \(f_{i,\ j}\) 表示第 \(i\) 个时,有 \(k\) 个子段的最大值。
转移方程式为 \(f_{i, \ j} = max\{f_{k, \ j - 1}\}\)

直接优化并不好想,但是我们会发现,如果一个数 \(x\) 要对一段子区间有贡献,那么这个子区间的左端点一定在 \([pre_x, i - 1]\) 之间,其中,\(pre_x\) 表示 \(x\) 上一次出现的位置。

那么,对于 DP 的决策区间,一定在这个区间以内。因此,我们可以将 DP 值对于这个区间覆盖最大值,每次 DP 值从前面找最大即可。时间复杂度 \(\mathcal O(n \log{n})\)

#include<bits/stdc++.h>

using namespace std;

const int N = 35010;

int n, k;
int dp[N], a[N];
int pre[N], pos[N];

struct segtree{
	#define ls u << 1
	#define rs u << 1 | 1
	#define mid (l + r >> 1)
	#define segroot int u = 1, int l = 1, int r = n
	#define lson ls, l, mid
	#define rson rs, mid + 1, r
	
	struct { int mx, add; } tr[N << 2];
	
	void up(int u){ tr[u].mx = max(tr[ls].mx, tr[rs].mx); }
	
	void build(segroot){
		tr[u].add = tr[u].mx = 0;
		if(l == r) return tr[u].mx = dp[l - 1], void();
		build(lson), build(rson), up(u);
	}
	
	void down(int u, int x){ tr[u].add += x, tr[u].mx += x; }
	
	void down(int u){ down(ls, tr[u].add), down(rs, tr[u].add), tr[u].add = 0; }
	
	void add(int ql, int qr, int x, segroot){
		if(l > qr or r < ql) return;
		if(l >= ql and r <= qr) return down(u, x);
		down(u), add(ql, qr, x, lson), add(ql, qr, x, rson), up(u);
	}
	
	int qry(int ql, int qr, segroot){
		if(l > qr or r < ql) return 0;
		if(l >= ql and r <= qr) return tr[u].mx;
		down(u);
		return max(qry(ql, qr, lson), qry(ql, qr, rson));
	}
} s;

int main(){
	ios::sync_with_stdio(false); cin.tie(nullptr), cout.tie(nullptr);
	cin>>n>>k;
	for(int i = 1; i <= n; i ++){
		cin>>a[i]; 
		pre[i] = pos[a[i]] + 1;
		pos[a[i]] = i;
	}
	for(int i = 1; i <= k; i ++){
		s.build();
		for(int j = 1; j <= n; j ++){
			s.add(pre[j], j, 1);
			if(i - 1 <= j) dp[j] = s.qry(i - 1, j);
		}
	}
	cout<<dp[n];
	return 0;
}

决策单调性分治优化 DP

若决策满足单调性(即转移一定比上一个更靠前/后),可用分治优化复杂度到 \(\mathcal O(n \log{n})\)

例题1 最远点

题意

给你一个N个点的凸多边形,求离每一个点最远的点。

分治做法

我们发现在这道题中,决策点一定具有单调性。我们可以证明,当任一点的决策区间一定在凸包上,并且当点顺时针旋转时,它的决策点一定不会向逆时针方向旋转。因此,这道题具有决策单调性。

下面是主要代码:

void solve(int l, int r, int ql, int qr){ //[l, r] 为当前代转移区间 |||| [ql, qr] 为当前决策区间
	if(l > r) return ;
	int mid = (l + r) >> 1, qmid = mid; // 找 mid 对应的 DP 值
	for(int i = max(mid + 1, ql); i <= qr; i ++) if(getdis(i, mid) > getdis(qmid, mid)) qmid = i; // 暴力找 mid 对应的 DP 值
	ans[mid] = (qmid - 1) % n + 1;
	solve(l, mid - 1, ql, qmid); // 递归二分搜索
	solve(mid + 1, r, qmid, qr); // 递归二分搜索
} 

时间复杂度分析:搜索树共 \(\log{n}\) 层,每层的遍历总数为 \(n\),因此,总时间复杂度为 \(\mathcal O(n \log{n})\)

例题2 CF 834D

同理,这道题的转移决策也具有单调性。对于一个点,它的决策点一定不比上一个决策点靠前,因此具有决策单调性。

直接看代码吧。

#include<bits/stdc++.h>

using namespace std;

const int N = 35010;

int n, k;
int dp[55][N], a[N], ct[N], tot, L = 1, R = 0;

void add(int x){
	ct[a[x]] ++;
	if(ct[a[x]] == 1) tot ++;
}

void del(int x){
	ct[a[x]] --;
	if(ct[a[x]] == 0) tot --;
}

int calc(int l, int r){ // 暴力的去找到当前区间的答案
	while(L < l) del(L ++);
	while(L > l) add(-- L);
	while(R < r) add(++ R);
	while(R > r) del(R --);
	return tot;
}

void solve(int l, int r, int ql, int qr, int t){ // 前几个参数如上,t 表示当前的转移的区间数
	if(l > r) return;
	int mid = (l + r) >> 1, qmid = ql;
	for(int i = ql; i <= min(mid, qr); i ++) { // 暴力找答案
		int nw = dp[t - 1][i - 1] + calc(i, mid);
		if(nw > dp[t][mid]){
			dp[t][mid] = nw;
			qmid = i;
		}
	}
	solve(l, mid - 1, ql, qmid, t); // 递归求答案
	solve(mid + 1, r, qmid, qr, t);
} 

int main(){
	ios::sync_with_stdio(false); cin.tie(nullptr), cout.tie(nullptr);
	cin>>n>>k;
	for(int i = 1; i <= n; i ++){
		cin>>a[i]; 
	}
	for(int i = 1; i <= k; i ++){
		solve(1, n, 1, n, i);
	}
	cout<<dp[k][n];
	return 0;
}
posted @ 2025-08-10 22:02  Hty111  阅读(8)  评论(0)    收藏  举报