题解: Luogu P8894 「UOI-R1」求和

题目链接:link

前言

我的一个学长在一次比赛中出了这道题,然后,我就把这道题切了
其实这道题还是比较简单的,然后我就介绍一下我的比赛时的思路和做法

30分做法

根据标签我们可以知道,这是一道dp题
我们先设\(dp[i][j]\)表示前\(i\)个区间中最大值为\(j\)时的次数
那么转移时就枚举之前的最大值:\(k\)
那么状态转移方程就是:
$dp[i][\max(j,k)]= \sum dp[i-1][k] $
然后就可以得到这样的一个暴力做法:

Code

#include<bits/stdc++.h>
using namespace std;
const int N=5001,mod=998244353;

int dp[N][N];
int n,m,ans,l,r;

int main()
{
	scanf("%d",&n);
	dp[0][0]=1;
	for (int i=1;i<=n;i++)
	{
		scanf("%d%d",&l,&r);
		for (int j=l;j<=r;j++)
			for (int k=0;k<=m;k++)
				dp[i][max(j,k)]+=dp[i-1][k],dp[i][max(j,k)]%=mod;
		m=max(m,r);
	}
	for (int i=1;i<=m;i++)
		ans+=1ll*dp[n][i]*i%mod,ans%=mod;
	printf("%d",ans);
	return 0;
}

再优化时间

那么这种做法的时间复杂度大约是\(O(nm^2)\)
这种方法肯定会超时.所以考虑优化时间
那么转移的时候我们就只枚举一个j:表示i个区间中的最大值
然后再分类讨论:
\(r<j\)时,最大值j就一定是从前\(i-1\)个区间来的,所以\(dp[i][j]=dp[i-1][j]\times(r-l+1)\)
否则,最大值就也有可能是从第i个区间来的,那么\(dp[i][j]=dp[i-1][j]\times(j-l+1)+\sum_{k=1}^{j-1} dp[i-1][k]\)
状态转移方程就是:

\[dp[i][j]= \begin{cases} dp[i-1][j]\times(r-l+1) & r<j \\ dp[i-1][j]\times(j-r+1)+\sum_{k=1}^{j-1} dp[i-1][k] & r\ge j \\ \end{cases} \]

\(\sum_{k=1}^{j-1} dp[i-1][k]\)时,显然就可以用前缀和来优化
那么这样时间复杂度就变成了\(O(nm)\),然后这样就可以AC了

Code

#include<bits/stdc++.h>
using namespace std;
const int N=5001,mod=998244353;

int dp[N][N];
int n,m,ans,s[N];

int main()
{
	scanf("%d",&n);
	dp[0][0]=1;
	for (int i=1,l,r;i<=n;i++)
	{	
		scanf("%d%d",&l,&r);
		m=max(m,r);
		
		memcpy(s,dp[i-1],sizeof dp[i-1]);
		for (int i=1;i<=m;i++)
			s[i]+=s[i-1],s[i]%=mod;
		
		for (int j=l;j<=m;j++)
		{
			dp[i][j]=1ll*dp[i-1][j]*(min(j,r)-l+1)%mod;
			if (r>=j) dp[i][j]+=s[j-1],dp[i][j]%=mod;
		}
	}
	for (int i=1;i<=m;i++)
		ans+=1ll*dp[n][i]*i%mod,ans%=mod;
	printf("%d",ans);
	return 0;
}

继续优化:

既然都求了s数组了,那么之前的状态就可以用s数组来求出来,那么dp数组就可以只开一维了
然后还可以再优化时间:每次再存一个ml:每个区间的l的最大值
但是需要注意:
优化成一维后,dp数组某些不符合条件的位置,还会再存之前位置的次数
所以每次转移前,需要先清空一下dp数组

Code

#include<bits/stdc++.h>
using namespace std;
const int N=5001,mod=998244353;

int dp[N];
int n,ml,mr,ans,s[N];

int main()
{
	scanf("%d",&n);
	dp[0]=1;
	for (int i=0,l,r;i<n;i++)
	{
		scanf("%d%d",&l,&r);
		ml=max(ml,l),mr=max(mr,r);
		memcpy(s,dp,sizeof dp);
		for (int j=1;j<=mr;j++) s[j]+=s[j-1],s[j]%=mod;
		memset(dp,0,sizeof dp);
		for (int j=ml;j<=mr;j++)
		{
			dp[j]=1ll*((s[j]-s[j-1])%mod+mod)%mod*(min(j,r)-l+1)%mod;
			if (r>=j) dp[j]+=s[j-1],dp[j]%=mod;
		}
	}
	for (int i=ml;i<=mr;i++)
		ans+=1ll*dp[i]*i%mod,ans%=mod;
	printf("%d",ans);
	return 0;
}
posted @ 2023-01-07 18:45  懵逼自动机  阅读(82)  评论(5)    收藏  举报