[HAOI2010]最长公共子序列
题目大意
给定两个字符串 \(a,b\),求出 \(a,b\) 的最长公共子序列长度以及 \(a,b\) 的公共子序列个数。
题目分析
先考虑第一问,令 \(dp1[i][j]\) 表示 \(a\) 前 \(i\) 和 \(b\) 前 \(j\) 个字符的最长公共子序列长度。
若 \(a[i]=b[j]\),即当前字符相同,那么一定要选这儿,\(dp1[i][j]=dp1[i-1][j-1]+1\)。
否则 \(dp1[i][j]=\max(dp1[i-1][j],dp1[i][j-1])\)。
参考代码:
for (int i = 1;i <= n; ++ i){
for (int j = 1;j <= m; ++ j){
dp1[i][j] = a[i] == b[j] ? dp1[i - 1][j - 1] + 1 : max(dp1[i - 1][j],dp1[i][j - 1]);
}
}
printf("%d\n",dp1[n][m]);
难点是第二问:求 \(a,b\) 的公共子序列个数。
令 \(dp2[i][j]\) 表示 \(a\) 前 \(i\) 和 \(b\) 前 \(j\) 个字符的公共子序列个数。
这里用 \(\operatorname{lcm}(i,j)\) 表示 \(a\) 的前 \(i\) 个和 \(b\) 的前 \(j\) 个中的最长公共子序列。
我们分情况讨论:
-
\(a[i]\in \operatorname{lcm}(i,j),b[j]\in \operatorname{lcm}(i,j)\)
-
\(a[i]\in \operatorname{lcm}(i,j),b[j]\not \in \operatorname{lcm}(i,j)\)
-
\(a[i]\not\in \operatorname{lcm}(i,j),b[j]\in \operatorname{lcm}(i,j)\)
-
\(a[i]\not\in \operatorname{lcm}(i,j),b[j]\not\in\operatorname{lcm}(i,j)\)
若 \(a[i]=b[j]\),则 \(dp1[i][j]=dp1[i-1][j-1]+1\),即情况 \(1\)。
于是有 \(dp2[i][j]\gets dp2[i][j]+dp2[i-1][j-1]\)。
若 \(dp1[i][j]=dp1[i][j-1]\),此时对应情况 \(2\) 和 \(4\)。
于是有 \(dp2[i][j]\gets dp2[i][j]+dp2[i][j-1]\)。
若 \(dp1[i][j]=dp1[i-1][j]\),此时对应情况 \(3\) 和 \(4\)。
于是有 \(dp2[i][j]\gets dp2[i][j]+dp2[i-1][j]\)。
若 \(dp1[i][j]=dp1[i-1][j-1]\),此时对应情况 \(4\)。满足该条件,则必然满足前两种情况,根据容斥原理,于是有 \(dp2[i][j]\gets dp2[i][j]-dp2[i-1][j-1]\)。
令 \(n\) 为 \(a\) 的长度,\(m\) 为 \(b\) 的长度,则:
时间复杂度为 \(O(nm)\),空间复杂度为 \(O(nm)\),空间只有 \(\verb!125MB!\),并且发现转移之和前一种情况有关,使用滚动数组优化。
注意事项:
-
不开
_________
见_______
。 -
方案数取差时注意负数情况。
代码
//2021/12/9
#define _CRT_SECURE_NO_WARNINGS
#include <iostream>
#include <cstdio>
#include <climits>//need "INT_MAX","INT_MIN"
#include <algorithm>
#include <cstring>
#define int long long
#define enter() putchar(10)
#define debug(c,que) cerr<<#c<<" = "<<c<<que
#define cek(c) puts(c)
#define blow(arr,st,ed,w) for(register int i=(st);i<=(ed);i++)cout<<arr[i]<<w;
#define speed_up() cin.tie(0),cout.tie(0)
#define endl "\n"
#define Input_Int(n,a) for(register int i=1;i<=n;i++)scanf("%d",a+i);
#define Input_Long(n,a) for(register long long i=1;i<=n;i++)scanf("%lld",a+i);
#define mst(a,k) memset(a,k,sizeof(a))
namespace Newstd
{
inline int read()
{
int x=0,k=1;
char ch=getchar();
while(ch<'0' || ch>'9')
{
if(ch=='-')
{
k=-1;
}
ch=getchar();
}
while(ch>='0' && ch<='9')
{
x=(x<<1)+(x<<3)+ch-'0';
ch=getchar();
}
return x*k;
}
inline void write(int x)
{
if(x<0)
{
putchar('-');
x=-x;
}
if(x>9)
{
write(x/10);
}
putchar(x%10+'0');
}
}
using namespace Newstd;
using namespace std;
const int mod=1e8,ma=5005;
char a[ma],b[ma];
int dp1[2][ma],dp2[2][ma];
//dp2[i][j]:a 的前 i 个,b 的前 j 个中所有最长公共子序列的个数
int n,m;
#undef int
int main(void)
{
#define int long long
scanf("%s%s",a+1,b+1);
n=strlen(a+1)-1,m=strlen(b+1)-1;
int st=0;
for(register int i=0;i<=m;i++)
{
dp2[0][i]=1;
}
dp2[1][0]=1;
for(register int i=1;i<=n;i++)
{
for(register int j=1;j<=m;j++)
{
dp2[st^1LL][j]=0;
if(a[i]==b[j])
{
dp1[st^1LL][j]=dp1[st][j-1]+1;
}
else
{
dp1[st^1LL][j]=max(dp1[st][j],dp1[st^1LL][j-1]);
}
if(a[i]==b[j])
{
dp2[st^1LL][j]+=dp2[st][j-1];//dp2[i][j]+=dp2[i-1][j-1]
}
if(dp1[st^1LL][j]==dp1[st^1LL][j-1])//dp1[i][j]==dp1[i][j-1]
{
dp2[st^1LL][j]+=dp2[st^1LL][j-1];//dp2[i][j]+=dp2[i][j-1]
}
if(dp1[st^1LL][j]==dp1[st][j])//dp1[i][j]==dp1[i-1][j]
{
dp2[st^1LL][j]+=dp2[st][j];//dp2[i][j]+=dp2[i-1][j]
}
if(dp1[st^1LL][j]==dp1[st][j-1])//dp1[i][j]==dp1[i-1][j-1]
{
dp2[st^1LL][j]-=dp2[st][j-1];//dp2[i][j]-=dp2[i-1][j-1]
}
dp2[st^1LL][j]=(dp2[st^1LL][j]+mod)%mod;
}
st^=1LL;
}
printf("%lld\n%lld\n",dp1[st][m],dp2[st][m]);
return 0;
}