DP 优化杂谈
关于 DP 优化
DP 需要优化什么?优化转移的时间复杂度。
我们可以将一个 \(O(N)\) 的转移优化到 \(O(\log N)\) 或者 \(O(1)\) 的。
DP 优化只是一个手段,我们其实可以理解为换了一个转移方式,但是转移的内容不变。
所以遇到需要 DP 优化的题目,不要先去想怎么优化。打出暴力之后,很快就会看出可以优化之处。
DP 优化——前缀和优化
形如这样的 DP 转移式子:
这时候如果我们暴力求和,时间复杂度会增加一个 \(O(N)\),对于较大的数据显然不行。
我们发现 \(j-R\sim j-L\) 是一个连续的区间,而我们要求的是和,很容易想到前缀和来优化。
于是有了:
有一些细节需要注意:
在前缀和中,求区间 \(l\sim r\) 的和是 \(s_r-s_{l-1}\)。
但是有时候,\(s_0\) 也有实际的值。当 \(l\) 取到 0 时,\(l-1=-1\),此时肯定会 RE。我们需要避免下标取到负数,但是肯定不能这样写:
ans=s[r]-s[max(0,l-1)];//l 此时等于 0
因为 \(s_0\) 也有实际的值,本来应该求 \(0\sim r\) 的和,却求成了 \(1\sim r\) 的和。
为了避免这种情况,应该这么写:
ans=s[r]-(l-1<0?0:s[l-1]);
题目
力扣1871 跳跃游戏
没有传送门
题意简述
从点 0 开始,每次可以跳 \(L\sim R\) 格,有些点无法落脚,求是否能到点 \(n-1\)。
思路
首先考虑暴力的 DP,很容易有:
\(dp_i\) 表示是否可以到达点 \(i\)。\(dp_i=0\) 表示无法到达,\(dp_i>0\) 表示可以到达。
此时完全符合前缀和优化 DP 的板子,每走一步计算当前的前缀和即可。
code
#include<bits/stdc++.h>
//#define lc p<<1
//#define rc p<<1|1
#define endl putchar('\n')
#define psp putchar(' ')
using namespace std;
typedef unsigned long long ull;
typedef long long ll;
int read(){
int x=0,f=1;
char c=getchar();
while(c<'0'||c>'9'){if(c=='-')f=-1;c=getchar();}
while(c>='0'&&c<='9')x=(x<<3)+(x<<1)+c-'0',c=getchar();
return x*f;
}
void print(int x){
if(x<0)putchar('-'),x=-x;
if(x<10){putchar(x+'0');return;}
print(x/10);
putchar(x%10+'0');
}
void putstr(string s){
for(int i=0;i<s.size();i++)putchar(s[i]);
}
int lowbit(int x){
return x&-x;
}
int n,m,k;
int T;
int L,R;
int s[1000005];
signed main(){
//ios::sync_with_stdio(0);
cin>>L>>R;
string t;
cin>>t;
n=t.length();
t=" "+t;
s[1]=1;
for(int i=2;i<=n;i++){
if(s[max(0,i-L)]-s[max(0,i-R-1)]>0&&t[i]=='0')s[i]=1;
s[i]+=s[i-1];
}
if(s[n]-s[n-1])cout<<"Yes";
else cout<<"No";
}
AT_dp_m Candies
题意简述
给 \(n\) 个人分 \(k\) 个糖,第 \(i\) 个人只能分 \(0\sim a_i\) 个。求刚好分完的方案数。
思路
用 \(dp_{i,j}\) 表示考虑完了前 \(i\) 个人,共分了 \(j\) 个糖的方案数。
那么有:
依然是一段连续的区间求和,可以使用前缀和,只是要注意在前文中提到的细节。
code
#include<bits/stdc++.h>
#define int long long
//#define lc p<<1
//#define rc p<<1|1
#define endl putchar('\n')
#define psp putchar(' ')
using namespace std;
typedef unsigned long long ull;
typedef long long ll;
const int N=105;
const int M=1e5+5;
const int mod=1e9+7;
int read(){
int x=0,f=1;
char c=getchar();
while(c<'0'||c>'9'){if(c=='-')f=-1;c=getchar();}
while(c>='0'&&c<='9')x=(x<<3)+(x<<1)+c-'0',c=getchar();
return x*f;
}
void print(int x){
if(x<0)putchar('-'),x=-x;
if(x<10){putchar(x+'0');return;}
print(x/10);
putchar(x%10+'0');
}
void putstr(string s){
for(int i=0;i<s.size();i++)putchar(s[i]);
}
int lowbit(int x){
return x&-x;
}
int n,m,k;
int T;
int a[N];
int dp[N][M];
int cnt[N][M];
signed main(){
//ios::sync_with_stdio(0);
n=read(),k=read();
for(int i=1;i<=n;i++)a[i]=read();
for(int i=0;i<=k;i++)cnt[0][i]=1;
dp[0][0]=1;
for(int i=1;i<=n;i++){
dp[i][0]=dp[i-1][0];
cnt[i][0]=dp[i-1][0];
for(int j=1;j<=k;j++){
dp[i][j]=((cnt[i-1][j]-(j-a[i]-1<0?0:cnt[i-1][j-a[i]-1]))%mod+mod)%mod;
cnt[i][j]=(cnt[i][j-1]+dp[i][j])%mod;
}
}
print(dp[n][k]%mod);
}
LOJ6077 [2017 山东一轮集训 Day7]逆序对(弱化版)
没有传送门
题意简述
求 \(1\sim n\) 的全排列中逆序对数量恰好为 \(k\) 的排列数量。
数据范围修改为 \(1\le n,k\le 5000\)
思路
状态设计:\(dp_{i,j}\) 表示 \(1\sim i\) 的排列中逆序对数量为 \(k\) 的数量。
接着考虑转移。
我们每次向现在的排列中插入一个数,为了方便,我们从 \(1\) 到 \(i\) 开始插入。
首先明确:插入数字不会打乱原来的顺序,即插入数字不会影响原序列中数字两两之间的逆序对数量。
当 \(i\) 被插入时,\(1\sim i-1\) 都已经被插入了,并且有 \(x\in [1,i-1]<i\)。
若把 \(i\) 放在序列的末尾,不会增加逆序对。
若把 \(i\) 放在序列的倒数第一位之前,则一定会增加一个逆序对。
若把 \(i\) 放在序列的倒数第二位之前,则会增加两个逆序对。
……
若把 \(i\) 放在序列第一位,则会增加 \(i-1\) 个逆序对。
所以,可以得出,状态 \(dp_{i,j}\) 可以由 \(dp_{i-1,k(j-i+1\le k\le j)}\) 转移而来,即:
依旧可以前缀和优化,注意细节。
code
#include<bits/stdc++.h>
//#define int long long
//#define lc p<<1
//#define rc p<<1|1
#define endl putchar('\n')
#define psp putchar(' ')
using namespace std;
typedef unsigned long long ull;
typedef long long ll;
const int N=5005;
const int mod=1e9+7;
int read(){
int x=0,f=1;
char c=getchar();
while(c<'0'||c>'9'){if(c=='-')f=-1;c=getchar();}
while(c>='0'&&c<='9')x=(x<<3)+(x<<1)+c-'0',c=getchar();
return x*f;
}
void print(int x){
if(x<0)putchar('-'),x=-x;
if(x<10){putchar(x+'0');return;}
print(x/10);
putchar(x%10+'0');
}
void putstr(string s){
for(int i=0;i<s.size();i++)putchar(s[i]);
}
int lowbit(int x){
return x&-x;
}
int n,m,k;
int T;
int dp[N][N];
int cnt[N][N];
signed main(){
n=read(),k=read();
dp[1][0]=1;
for(int i=0;i<=k;i++)cnt[1][i]=1;
for(int i=2;i<=n;i++){
for(int j=0;j<=k;j++){
dp[i][j]=((cnt[i-1][j]-(j-i<0?0:cnt[i-1][j-i]))%mod+mod)%mod;
cnt[i][j]=(cnt[i][j-1]+dp[i][j])%mod;
}
}
print(dp[n][k]);
}
AT_abc253_e [ABC253E] Distance Sequence
题意简述
要求一个长度为 \(N\) 的序列满足元素在 \(1\sim M\) 范围内,且相邻元素相差至少为 \(K\),求序列数量。
思路
用 \(dp_{i,j}\) 表示考虑前 \(i\) 个位置,第 \(i\) 个位置选择 \(j\) 的方案数。
先考虑暴力:
这个看似不连续,但其实可以将 \(|k-j|\geq K\) 分解成 \(k-j\geq K\) 或 \(j-k\geq K\),即:\(k\geq K+j\) 或 \(k\le j-K\),所以最后只需要求 \(K+j\sim M\) 和 \(1\sim j-K\) 这两个区间得和就行了。
code
#include<bits/stdc++.h>
#define int long long
//#define lc p<<1
//#define rc p<<1|1
#define endl putchar('\n')
#define psp putchar(' ')
using namespace std;
typedef unsigned long long ull;
typedef long long ll;
const int N=1005;
const int M=5005;
const int mod=998244353;
int read(){
int x=0,f=1;
char c=getchar();
while(c<'0'||c>'9'){if(c=='-')f=-1;c=getchar();}
while(c>='0'&&c<='9')x=(x<<3)+(x<<1)+c-'0',c=getchar();
return x*f;
}
void print(int x){
if(x<0)putchar('-'),x=-x;
if(x<10){putchar(x+'0');return;}
print(x/10);
putchar(x%10+'0');
}
void putstr(string s){
for(int i=0;i<s.size();i++)putchar(s[i]);
}
int lowbit(int x){
return x&-x;
}
int n,m,k;
int T;
int dp[N][M];
int cnt[N][M];
signed main(){
//ios::sync_with_stdio(0);
n=read(),m=read(),k=read();
for(int i=1;i<=m;i++)dp[1][i]=1,cnt[1][i]=(cnt[1][i-1]+dp[1][i])%mod;
if(k==0){
int ans=1;
for(int i=1;i<=n;i++)ans=(ans*m)%mod;
print(ans);
return 0;
}
for(int i=2;i<=n;i++){
for(int j=1;j<=m;j++){
dp[i][j]=((dp[i][j]+cnt[i-1][max(0ll,j-k)]-cnt[i-1][0])%mod+mod)%mod;
dp[i][j]=((dp[i][j]+cnt[i-1][m]-cnt[i-1][min(m,j+k-1)])%mod+mod)%mod;
cnt[i][j]=(cnt[i][j-1]+dp[i][j])%mod;
}
}
int ans=0;
for(int i=1;i<=m;i++)ans=(ans+dp[n][i])%mod;
print(ans);
}
P1107 [BJWC2008] 雷涛的小猫
题意简述
有 \(N\) 棵树,高度为 \(H\),不同高度处有柿子。最开始选择任意一棵,有两种操作:
- 在当前树上向下滑一格。
- 跳到另一棵树,向下掉 \(Delta\) 格。
到达柿子处即可吃掉,问到达高度 0 之前最多可以吃到多少柿子。
思路
有两种选择:
- 树不变,向下走一格。
- 树变,向下走 \(Delta\) 格。
用 \(dp_{i,j}\) 表示到达第 \(i\) 棵树的 \(j\) 高度时最多可以吃到多少柿子。
那么有两种转移到 \(dp_{i,j}\) 的方式:
- 由 \(dp_{i,j+1}\) 转移而来。
- 由 \(dp_{1\ldots n,j+Delta}\) 转移而来。
第一种方式显然可以直接转移,而第二种方法需要求出 \(\max_{1\le k\le n}dp_{k,j+Delta}\)。
现在需要的是最大值,我们没有听说过前缀最大值的求法(线段树走开)。
但是我们发现,第一维的范围始终都是 \(1\sim n\),覆盖了所有树。
所以我们只需要知道高度 \(j\) 的条件下,所有树的最大价值即可。
code
#include<bits/stdc++.h>
#define int long long
//#define lc p<<1
//#define rc p<<1|1
#define endl putchar('\n')
#define psp putchar(' ')
using namespace std;
typedef unsigned long long ull;
typedef long long ll;
const int N=2005;
const int M=5005;
int read(){
int x=0,f=1;
char c=getchar();
while(c<'0'||c>'9'){if(c=='-')f=-1;c=getchar();}
while(c>='0'&&c<='9')x=(x<<3)+(x<<1)+c-'0',c=getchar();
return x*f;
}
void print(int x){
if(x<0)putchar('-'),x=-x;
if(x<10){putchar(x+'0');return;}
print(x/10);
putchar(x%10+'0');
}
void putstr(string s){
for(int i=0;i<s.size();i++)putchar(s[i]);
}
int lowbit(int x){
return x&-x;
}
int n,m,k;
int val[N][N];
int mx[N];
int h;
int d;
int T;
int dp[N][N];
signed main(){
//ios::sync_with_stdio(0);
n=read(),h=read(),d=read();
for(int i=1;i<=n;i++){
T=read();
while(T--){
int v=read();
val[i][v]++;
}
}
for(int j=h;j>=0;j--){
for(int i=1;i<=n;i++){
dp[i][j]=val[i][j]+max(dp[i][j+1],mx[j+d]);
mx[j]=max(mx[j],dp[i][j]);
}
}
print(mx[0]);
}
牛客33634G 小人国的粉刷匠
题目传送门坏了
题意简述
有 \(n\) 个点,\(m\) 种颜色,给点 \(i\) 染颜色 \(j\) 的代价是 \(a_{i,j}\)。有 \(q\) 个要求,要求点 \(u_i\) 必须染成颜色 \(v_i\)。要求所有点必须被分成 \(k\) 块连续相同颜色的点,求最小染色代价。
思路
题目非常复杂啊。
首先,找到重要的信息:当前点,颜色,分的块数。
所以有 \(dp_{i,j,l}\) 表示已考虑前 \(i\) 个点,第 \(i\) 个点的颜色是 \(j\),已经分成 \(l\) 块的最小代价。
然后考虑转移。
若点 \(i-1\) 的颜色是 \(j\),则点 \(i\) 和点 \(i-1\) 在同一块内,分的块数不变。
若点 \(i-1\) 的颜色不是 \(j\),则两点无法分到一块,前一个状态分的块数少一。
这个时候第二种转移的完整区间被截断了,但是断点前的区间是连续的,断点后的也如此,并且两个区间的其中一个端点都在边界上。这个时候可以做一个前缀最小值和一个后缀最小值,这样就可以避免计算 \(j\) 了。
code
#include<bits/stdc++.h>
#define int unsigned long long
//#define lc p<<1
//#define rc p<<1|1
#define endl putchar('\n')
#define psp putchar(' ')
using namespace std;
typedef unsigned long long ull;
typedef long long ll;
const int N=105;
int read(){
int x=0,f=1;
char c=getchar();
while(c<'0'||c>'9'){if(c=='-')f=-1;c=getchar();}
while(c>='0'&&c<='9')x=(x<<3)+(x<<1)+c-'0',c=getchar();
return x*f;
}
void print(int x){
if(x<0)putchar('-'),x=-x;
if(x<10){putchar(x+'0');return;}
print(x/10);
putchar(x%10+'0');
}
void putstr(string s){
for(int i=0;i<s.size();i++)putchar(s[i]);
}
int lowbit(int x){
return x&-x;
}
int n,m,k;
int T;
int c[N];
int a[N][N];
int dp[N][N][N];
int L[N][N][N],R[N][N][N];
signed main(){
//ios::sync_with_stdio(0);
n=read(),m=read(),k=read(),T=read();
while(T--){
int u=read(),v=read();
c[u]=v;
}
for(int i=1;i<=n;i++){
for(int j=1;j<=m;j++){
a[i][j]=read();
}
}
memset(dp,0x3f,sizeof(dp));
// dp[0][0][0]=0;
memset(L,0x3f,sizeof(L));
memset(R,0x3f,sizeof(R));
for(int i=1;i<=m;i++){
dp[1][i][1]=a[1][i];
if(c[1]&&c[1]!=i)dp[1][i][1]=1e12;
}
for(int i=1;i<=m;i++){
L[1][i][1]=min(L[1][i-1][1],dp[1][i][1]);
}
for(int i=m;i>=1;i--){
R[1][i][1]=min(R[1][i+1][1],dp[1][i][1]);
}
for(int i=2;i<=n;i++){
for(int j=1;j<=m;j++){
if(c[i]&&c[i]!=j){
for(int l=1;l<=k;l++){
dp[i][j][l]=1e12;
}
continue;
}
for(int l=1;l<=k;l++){
dp[i][j][l]=min(dp[i-1][j][l],min(L[i-1][j-1][l-1],R[i-1][j+1][l-1]))+a[i][j];
}
}
for(int l=1;l<=k;l++){
for(int j=1;j<=m;j++){
L[i][j][l]=min(L[i][j-1][l],dp[i][j][l]);
}
for(int j=m;j>=1;j--){
R[i][j][l]=min(R[i][j+1][l],dp[i][j][l]);
}
}
}
int ans=1e12;
for(int i=1;i<=m;i++){
ans=min(ans,dp[n][i][k]);
}
if(ans>=1e12)putstr("-1");
else print(ans);
}
P10741 [SEERC 2020] Fence Job
题意简述
有 \(n\) 个互不相同的数 \(a_1\ldots a_n\),每次操作可以选择一个区间 \([l,r]\),将区间内所有数改成 \(\min_{i\in[l,r]}a_i\)。可以进行若干次操作(可以不操作)。求可以得到多少种不同的区间。
思路
我们发现,每一个数都有一个影响的范围。
如序列:
1 5 6 7 4 8 3 2
序列中要将一个区间内的元素改成 4,若选择区间 \([2,6]\) 就可以完成,选择区间 \([1,6]\) 则不行,因为区间 \([1,6]\) 中的最小值是 1。
我们用 \(l_i\) 表示从第 \(i\) 个数出发向左最后一个大于 \(a_i\) 的数的位置,\(r_i\) 表示从第 \(i\) 个数出发向右最后一个大于 \(a_i\) 的数的位置,那么第 \(i\) 个数可以影响到的范围就是区间 \([l_i,r_i]\)。
用 \(dp_i\) 表示前 \(i\) 个数可以得到多少种序列。
如何计算?
对于点 \(i\),我们枚举 \(l_i\le j\le r_i\),令 \([j,r_i]\) 全部覆盖为 \(a_i\),\([l_i,j-1]\) 全部不覆盖 \(a_i\)。
当 \(j>i\) 时,如何让 \([l_i,j-1]\) 全部不覆盖 \(a_i\)?
首先,肯定有 \(a_{l_i-1}<a_i\),所以 \([l_i,j-1]\) 可以被 \(a_{l_i-1}\) 及其之前更小的数覆盖。
code
#include<bits/stdc++.h>
using namespace std;
const int N=3005;
const int mod=1e9+7;
int read(){
int x=0;
char c=getchar();
while(c<'0'||c>'9'){c=getchar();}
while(c>='0'&&c<='9')x=(x<<3)+(x<<1)+c-'0',c=getchar();
return x;
}
void print(int x){
char buf[50];
int cnt=0;
while(x){
buf[++cnt]=x%10+'0';
x/=10;
}
while(cnt)putchar(buf[cnt--]);
}
int n;
int h[N];
int dp[N];
signed main(){
//ios::sync_with_stdio(0);
n=read();
for(int i=1;i<=n;i++)h[i]=read();
dp[0]=1;
for(int i=1;i<=n;i++){
int l=i,r=i;
while(h[l-1]>=h[i])l--;
while(h[r+1]>=h[i])r++;
for(int j=l;j<=r;j++){
dp[j]=(dp[j]+dp[j-1])%mod;
}
}
print(dp[n]);
}
DP 优化——线段树优化
总结前面的前缀和优化 DP,有几个限制条件:
- 修改是单点。
- 修改的点在这次循环中不会用到,即用在循环 \(i\) 中用循环 \(i-1\) 的答案。
当我们遇到一个区间但是有多个边界的限制时,区间的连续性将不是那么明显,可能无法用现在的下表表示,前缀和就无能为力了。
形如:
这时候既要满足 \(j<i\),又要满足 \(a_j<a_i\)。虽然下标仍是连续的区间,但是 \(a_i\) 无序,无法用前缀和。
这时候就需要用到线段树。
可以用到线段树优化的结构还有:
这时候,若 \(val(i,j)\) 按照一定的顺序增加/减少,可以使用线段树维护,将转移从 \(O(N^2)\) 降到 \(O(N\log N)\)。
题目
严格上升子序列的数量
没找到
题意简述
求长度为 \(n\) 序列 \(a\) 中严格上升子序列的数量,只要下标不同即为不同。
\(1\le n\le 10^5,1\le a_i\le 10^9\)
思路
首先,肯定是考虑朴素做法。
\(dp_i\) 表示必选第 \(i\) 个数的答案数。
那么转移有:
这是时时间复杂度是 \(O(N^2)\) 级别,考虑优化。
转移的过程就是求 \(1\sim i-1\) 中值在 \(1\sim a_i-1\) 的数之和。
这明显就能想到权值线段树。
code
#include<bits/stdc++.h>
#define int long long
#define lc f[p].l
#define rc f[p].r
#define endl putchar('\n')
#define psp putchar(' ')
using namespace std;
typedef unsigned long long ull;
typedef long long ll;
const int N=1e5+5;
const int mod=1e9+7;
int read(){
int x=0,f=1;
char c=getchar();
while(c<'0'||c>'9'){if(c=='-')f=-1;c=getchar();}
while(c>='0'&&c<='9')x=(x<<3)+(x<<1)+c-'0',c=getchar();
return x*f;
}
void print(int x){
if(x<0)putchar('-'),x=-x;
if(x<10){putchar(x+'0');return;}
print(x/10);
putchar(x%10+'0');
}
void putstr(string s){
for(int i=0;i<s.size();i++)putchar(s[i]);
}
int lowbit(int x){
return x&-x;
}
int n,m,k;
int T;
int a[N];
struct node{
int l,r,ans;
}f[30*N];
int rt;
int idx;
void update(int &p,int L,int R,int x,int k){
if(!p)p=++idx;
if(x<L||x>R)return;
if(L==R){
f[p].ans=(f[p].ans+k)%mod;
return;
}
int mid=L+R>>1;
update(lc,L,mid,x,k);
update(rc,mid+1,R,x,k);
f[p].ans=(f[lc].ans+f[rc].ans)%mod;
}
int query(int p,int L,int R,int l,int r){
if(l<=L&&R<=r)return f[p].ans;
int mid=L+R>>1;
int res=0;
if(l<=mid)res=(res+query(lc,L,mid,l,r))%mod;
if(r>mid)res=(res+query(rc,mid+1,R,l,r))%mod;
return res;
}
int dp[N];
int ans=0;
signed main(){
//ios::sync_with_stdio(0);
n=read();
for(int i=1;i<=n;i++)a[i]=read();
for(int i=1;i<=n;i++){
dp[i]=query(rt,0,1e9,0,a[i]-1)+1;
update(rt,0,1e9,a[i],dp[i]);
ans=(ans+dp[i])%mod;
}
print(ans);
}
LOJ 6077.「2017 山东一轮集训 Day7」逆序对
这个题在前文前缀和优化中出现过。
当时可以使用前缀和优化,是因为当时使用的是填表法[1]。
而如果使用刷表法[2],就没法用前缀和了,因为使用填表法每次修改是一个区间。
先写出朴素的转移:
for(int l=j;l<=min(m,j+i);l++)dp[i+1][l]+=dp[i][j];
发现修改是区间的,考虑用线段树。
线段树 code
#include<bits/stdc++.h>
#define int long long
#define lc p<<1
#define rc p<<1|1
#define endl putchar('\n')
#define psp putchar(' ')
using namespace std;
typedef unsigned long long ull;
typedef long long ll;
const int mod=1e9+7;
int read(){
int x=0,f=1;
char c=getchar();
while(c<'0'||c>'9'){if(c=='-')f=-1;c=getchar();}
while(c>='0'&&c<='9')x=(x<<3)+(x<<1)+c-'0',c=getchar();
return x*f;
}
void print(int x){
if(x<0)putchar('-'),x=-x;
if(x<10){putchar(x+'0');return;}
print(x/10);
putchar(x%10+'0');
}
void putstr(string s){
for(int i=0;i<s.size();i++)putchar(s[i]);
}
int lowbit(int x){
return x&-x;
}
int n,m,k;
int T;
int dp[5005][5005];
struct node{
int l,r,ans,add;
}f[4*5005];
void build(int p,int l,int r){
f[p]=node{l,r,0,0};
if(l==r)return;
int mid=l+r>>1;
build(lc,l,mid);
build(rc,mid+1,r);
f[p].ans=f[lc].ans+f[rc].ans;
}
void down(int p){
if(f[p].add){
f[lc].add=(f[lc].add+f[p].add)%mod;
f[rc].add=(f[rc].add+f[p].add)%mod;
f[lc].ans=(f[lc].ans+f[p].add*(f[lc].r-f[lc].l+1)%mod)%mod;
f[rc].ans=(f[rc].ans+f[p].add*(f[rc].r-f[rc].l+1)%mod)%mod;
f[p].add=0;
}
}
void update(int p,int l,int r,int k){
if(l<=f[p].l&&f[p].r<=r){
f[p].ans=(f[p].ans+k*(f[p].r-f[p].l+1)%mod)%mod;
f[p].add=(f[p].add+k)%mod;
return;
}
down(p);
int mid=f[p].l+f[p].r>>1;
if(l<=mid)update(lc,l,r,k);
if(r>mid)update(rc,l,r,k);
f[p].ans=(f[lc].ans+f[rc].ans)%mod;
}
int query(int p,int x){
if(f[p].l==f[p].r)return f[p].ans;
down(p);
int mid=f[p].l+f[p].r>>1;
if(x<=mid)return query(lc,x);
return query(rc,x);
}
signed main(){
//ios::sync_with_stdio(0);
n=read(),m=read();
build(1,0,m);
update(1,0,0,1);
for(int i=1;i<=n;i++){
for(int j=0;j<=m;j++){
dp[i][j]=query(1,j);
}
build(1,0,m);
for(int j=0;j<=m;j++){
update(1,j,min(m,j+i),dp[i][j]);
}
}
print(dp[n][m]);
}
但是我们发现,虽然修改是区间的,但是最终的查询是单点的,并且在循环 \(i\) 中的修改只会对循环 \(i+1\) 有影响。
这是什么?很多次区间修改但是最终每个点只查询一遍:差分。
对比线段树少了一个 \(O(\log N)\),线段树似乎没用!?
差分 code
#include<bits/stdc++.h>
#define int long long
#define endl putchar('\n')
#define psp putchar(' ')
using namespace std;
typedef unsigned long long ull;
typedef long long ll;
const int mod=1e9+7;
int read(){
int x=0,f=1;
char c=getchar();
while(c<'0'||c>'9'){if(c=='-')f=-1;c=getchar();}
while(c>='0'&&c<='9')x=(x<<3)+(x<<1)+c-'0',c=getchar();
return x*f;
}
void print(int x){
if(x<0)putchar('-'),x=-x;
if(x<10){putchar(x+'0');return;}
print(x/10);
putchar(x%10+'0');
}
void putstr(string s){
for(int i=0;i<s.size();i++)putchar(s[i]);
}
int lowbit(int x){
return x&-x;
}
int n,m,k;
int T;
int dp[5020][5020];
signed main(){
//ios::sync_with_stdio(0);
n=read(),m=read();
dp[1][0]=1;
for(int i=1;i<=n;i++){
for(int j=0;j<=m;j++){
dp[i+1][j]=(dp[i+1][j]+dp[i][j])%mod;
if(j+i+1<=m)dp[i+1][j+i+1]=((dp[i+1][j+i+1]-dp[i][j])%mod+mod)%mod;
}
for(int j=1;j<=m;j++){
dp[i+1][j]=(dp[i+1][j]+dp[i+1][j-1])%mod;
}
}
print((dp[n][m]%mod+mod)%mod);
}
那么为什么这道题可以不用线段树,而上一道题却需要??
上一道题中我们修改的是 \(dp_i\),而 \(dp_i\) 可能下一步就会在 \(dp_{i+1}\) 的转移中被查询。
但是这道题不同。
通过上面两道题道题,我们可以总结出一些必须要用线段树优化的条件:
- 区修区查(显而易见)。
- 修改和查询没有分开。
P3970 [TJOI2014] 上升子序列
题意简述
求一个序列中长度至少为 2,数值不同的严格上升子序列数量。
思路
这道题和第一题的区别在于:第一题中,序列只要下标不同即为不同,而本题需要数值不同。
我们可以这样考虑:第一题中我们用 \(dp_i\) 表示必选第 \(i\) 个数的答案。
而这道题中,我们为了避免值的重复,用 \(dp_i\) 表示必选值为 \(i\) 的数的答案。
所以我们时刻都要保证值 \(i\) 只被每种选过一次。
有一点需要注意,我们计算时还是要计算长度为 1 的子序列,因为之后可能会有数与其组成长度为 2 的,也会计入答案。最后再将长度为 1 的减掉就行了。
code
#include<bits/stdc++.h>
//#define int long long
#define lc f[p].l
#define rc f[p].r
#define endl putchar('\n')
#define psp putchar(' ')
using namespace std;
typedef unsigned long long ull;
typedef long long ll;
const int N=1e5+5;
const int mod=1e9+7;
int read(){
int x=0,f=1;
char c=getchar();
while(c<'0'||c>'9'){if(c=='-')f=-1;c=getchar();}
while(c>='0'&&c<='9')x=(x<<3)+(x<<1)+c-'0',c=getchar();
return x*f;
}
void print(int x){
if(x<0)putchar('-'),x=-x;
if(x<10){putchar(x+'0');return;}
print(x/10);
putchar(x%10+'0');
}
void putstr(string s){
for(int i=0;i<s.size();i++)putchar(s[i]);
}
int lowbit(int x){
return x&-x;
}
int n,m,k;
int T;
int a[N];
int bit[N];
void update(int x,int k){
while(x<=n){
bit[x]=(bit[x]+k)%mod;
x+=lowbit(x);
}
}
int query(int x){
int res=0;
while(x){
res=(res+bit[x])%mod;
x-=lowbit(x);
}
return res;
}
int dp[N];
int ans=0;
int b[N];
int deepseek(int x){
int l=1,r=m;
while(l<r){
int mid=l+r>>1;
if(b[mid]>=x)r=mid;
else l=mid+1;
}
return r;
}
signed main(){
//ios::sync_with_stdio(0);
n=read();
for(int i=1;i<=n;i++)a[i]=read();
for(int i=1;i<=n;i++)b[i]=a[i];
sort(b+1,b+1+n);
b[0]=b[1]-1;
for(int i=1;i<=n;i++)if(b[i]!=b[i-1])b[++m]=b[i];
for(int i=1;i<=n;i++)a[i]=deepseek(a[i]);
for(int i=1;i<=n;i++){
update(a[i],-(query(a[i])-query(a[i]-1)));//清空避免重复
update(a[i],query(a[i]-1)+1);//加一是加他自己
}
print(((query(m)-m)%mod+mod)%mod);
}
CF597C Subsequences
题意简述
求一个序列中长度为 \(k+1\) 的递增子序列数量。
思路
首先依旧考虑朴素思路。
\(dp_{i,j}\) 表示必选第 \(i\) 个数时,长度为 \(j\) 的递增子序列数量。
考虑转移,长度为 \(j\) 的子序列只能由长度为 \(j-1\) 的子序列加上一个数得到,而这个长度为 \(j-1\) 的子序列最后一个元素一定要小于 \(a_i\)。
有:
我们发现只会用到 \(j\) 和 \(j-1\),可以优化掉一个维度。
剩下的就是板子。(题目太良心了,不用离散化就可以写树状数组)
code
#include<bits/stdc++.h>
#define int long long
//#define lc p<<1
//#define rc p<<1|1
#define endl putchar('\n')
#define psp putchar(' ')
using namespace std;
typedef unsigned long long ull;
typedef long long ll;
const int N=1e5+5;
const int mod=1e9+7;
int read(){
int x=0,f=1;
char c=getchar();
while(c<'0'||c>'9'){if(c=='-')f=-1;c=getchar();}
while(c>='0'&&c<='9')x=(x<<3)+(x<<1)+c-'0',c=getchar();
return x*f;
}
void print(int x){
if(x<0)putchar('-'),x=-x;
if(x<10){putchar(x+'0');return;}
print(x/10);
putchar(x%10+'0');
}
void putstr(string s){
for(int i=0;i<s.size();i++)putchar(s[i]);
}
int lowbit(int x){
return x&-x;
}
int n,m,k;
int T;
int dp[N];
int a[N];
int bit[N];
void update(int x,int k){
while(x<=n){
bit[x]+=k;
x+=lowbit(x);
}
}
int query(int x){
int res=0;
while(x){
res+=bit[x];
x-=lowbit(x);
}
return res;
}
int ans=0;
signed main(){
n=read(),m=read();
for(int i=1;i<=n;i++)a[i]=read(),dp[i]=1;
for(int i=1;i<=m;i++){
for(int j=1;j<=n;j++)bit[j]=0;
for(int j=1;j<=n;j++){
update(a[j],dp[j]);
dp[j]=query(a[j]-1);
}
}
int ans=0;
for(int i=1;i<=n;i++)ans+=dp[i];
print(ans);
}
CF833B The Bakery
题意简述
将一个序列分成 \(k\) 组,每组的价值为组内数字的种类数,总价值为每组价值之和,求最大价值。
思路
依旧考虑朴素解法。
\(dp_{i,j}\) 表示前 \(i\) 个数,分成 \(j\) 组的最大价值。
有转移式子:
时间复杂度 \(O(N^2K)\),无法接受。
上面的转移式子与前文中第二种可以使用线段树优化的转移式子相似,再看看 \(val(k+1,i)\) 是否具有规律。
序列:\([3,2,1,3,2,4]\)
列出 \(i,k\) 以及 \(val(k+1,i)\) 的表格:
| \(i\) 的取值\ \(k\) 的取值 | 0 | 1 | 2 | 3 | 4 | 5 |
|---|---|---|---|---|---|---|
| 1 | 1 | |||||
| 2 | 2 | 1 | ||||
| 3 | 3 | 2 | 1 | |||
| 4 | 3 | 3 | 2 | 1 | ||
| 5 | 3 | 3 | 3 | 2 | 1 | |
| 6 | 4 | 4 | 4 | 3 | 2 | 1 |
通过观察,可以得出两个结论:
- 每一次最多增加 1。
- 每一次增加的区间连续。
再注意力惊人一下,就会发现:每一次增加的范围是 \(a_i\) 上一次出现的位置到 \(i-1\)。
这下就找到连续的区间了。
若我们先枚举 \(j\),则 \(j\) 对于内部为定值,所以我们可以线段树找到任意 \(i\) 时 \(val(k+1,i)\) 的最大值。
code
#include<bits/stdc++.h>
#define int long long
#define lc p<<1
#define rc p<<1|1
#define endl putchar('\n')
#define psp putchar(' ')
using namespace std;
typedef unsigned long long ull;
typedef long long ll;
const int N=35005;
int read(){
int x=0,f=1;
char c=getchar();
while(c<'0'||c>'9'){if(c=='-')f=-1;c=getchar();}
while(c>='0'&&c<='9')x=(x<<3)+(x<<1)+c-'0',c=getchar();
return x*f;
}
void print(int x){
if(x<0)putchar('-'),x=-x;
if(x<10){putchar(x+'0');return;}
print(x/10);
putchar(x%10+'0');
}
void putstr(string s){
for(int i=0;i<s.size();i++)putchar(s[i]);
}
int lowbit(int x){
return x&-x;
}
int n,m,k;
int T;
int a[N];
int pre[N];
struct node{
int l,r,mx,add;
}f[4*N];
int dp[55][N];
void build(int p,int l,int r,int last){
f[p]=node{l,r,dp[last][l],0};
if(l==r)return;
int mid=l+r>>1;
build(lc,l,mid,last);
build(rc,mid+1,r,last);
f[p].mx=max(f[lc].mx,f[rc].mx);
}
void down(int p){
if(f[p].add){
f[lc].add+=f[p].add;
f[rc].add+=f[p].add;
f[lc].mx+=f[p].add;
f[rc].mx+=f[p].add;
f[p].add=0;
}
}
void update(int p,int l,int r){
if(l<=f[p].l&&f[p].r<=r){
f[p].mx++;
f[p].add++;
return;
}
down(p);
int mid=f[p].l+f[p].r>>1;
if(l<=mid)update(lc,l,r);
if(r>mid)update(rc,l,r);
f[p].mx=max(f[lc].mx,f[rc].mx);
}
int query(int p,int l,int r){
if(l<=f[p].l&&f[p].r<=r)return f[p].mx;
int mid=f[p].l+f[p].r>>1;
int res=0;
down(p);
if(l<=mid)res=max(res,query(lc,l,r));
if(r>mid)res=max(res,query(rc,l,r));
return res;
}
signed main(){
//ios::sync_with_stdio(0);
n=read(),k=read();
for(int i=1;i<=n;i++)a[i]=read();
for(int i=1;i<=k;i++){
build(1,0,n,i-1);
for(int j=1;j<=n;j++)pre[a[j]]=0;
for(int j=1;j<=n;j++){
update(1,pre[a[j]],j-1);
pre[a[j]]=j;
dp[i][j]=query(1,i-1,j-1);
}
}
print(dp[k][n]);
}
AT_dp_w Intervals
题意简述
有 \(m\) 种得分方式,若字符串区间 \([l_i,r_i]\) 中有 1,则会获得 \(a_i\) 分,构造字符串求最大得分。
思路
朴素思路。
\(dp_i\) 表示考虑到前 \(i\) 为时的最大得分。
则有:
式子和模板一模一样,考虑找 \(dp_j+val(i)\) 的规律。
首先看到得分方式。
- \(j<le l_i\le r_i\le r\),这个得不到分。
- \(j<l_i\le i\le r_i\),这时有新的得分。
- \(l_i\le j\le i\le r_i\),这时得分不会新增,因为在 \(dp_j\) 时就被计算了。
综上,只有 \(i\) 在范围内,而 \(j\) 不在时才会得分。
这样我们就找到了在 \(i\) 固定时, \(j\) 在不同位置的得分。
用线段树优化即可。
code
#include<bits/stdc++.h>
#define int long long
#define lc p<<1
#define rc p<<1|1
#define endl putchar('\n')
#define psp putchar(' ')
using namespace std;
typedef unsigned long long ull;
typedef long long ll;
const int N=2e5+5;
int read(){
int x=0,f=1;
char c=getchar();
while(c<'0'||c>'9'){if(c=='-')f=-1;c=getchar();}
while(c>='0'&&c<='9')x=(x<<3)+(x<<1)+c-'0',c=getchar();
return x*f;
}
void print(int x){
if(x<0)putchar('-'),x=-x;
if(x<10){putchar(x+'0');return;}
print(x/10);
putchar(x%10+'0');
}
void putstr(string s){
for(int i=0;i<s.size();i++)putchar(s[i]);
}
int lowbit(int x){
return x&-x;
}
int n,m,k;
int T;
struct line{
int id,x;
};
vector<line>L[N],R[N];
struct node{
int l,r,ans,add;
}f[4*N];
void build(int p,int l,int r){
f[p]=node{l,r,0,0};
if(l==r)return;
int mid=l+r>>1;
build(lc,l,mid);
build(rc,mid+1,r);
}
void down(int p){
if(f[p].add){
f[lc].add+=f[p].add;
f[rc].add+=f[p].add;
f[lc].ans+=f[p].add;
f[rc].ans+=f[p].add;
f[p].add=0;
}
}
void update(int p,int l,int r,int k){
if(l<=f[p].l&&f[p].r<=r){
f[p].ans+=k;
f[p].add+=k;
return;
}
down(p);
int mid=f[p].l+f[p].r>>1;
if(l<=mid)update(lc,l,r,k);
if(r>mid)update(rc,l,r,k);
f[p].ans=max(f[lc].ans,f[rc].ans);
}
int query(int p,int l,int r){
if(l<=f[p].l&&f[p].r<=r){
return f[p].ans;
}
down(p);
int mid=f[p].l+f[p].r>>1;
int res=-1e18;
if(l<=mid)res=max(res,query(lc,l,r));
if(r>mid)res=max(res,query(rc,l,r));
return res;
}
int dp[N];
int ans=0;
signed main(){
//ios::sync_with_stdio(0);
n=read(),m=read();
for(int i=1;i<=m;i++){
int l=read(),r=read(),x=read();
L[l].push_back({l,x});
R[r+1].push_back({l,-x});
}
build(1,0,n);
for(int i=1;i<=n;i++){
for(line v:L[i])update(1,0,v.id-1,v.x);//i 在 l[i]~r[i] 的范围内,j<l[i] 时即可得分。
for(line v:R[i])update(1,0,v.id-1,v.x);//i 超过了 r[i],不可能以这个得分方式得到新的分。
dp[i]=query(1,0,i-1);
ans=max(ans,dp[i]);
update(1,i,i,dp[i]);
}
print(ans);
}
浙公网安备 33010602011771号