subsequnce----dp

subsequence

题意:给长度为\(n\), \(m\)的字符串\(s\), \(t\), 字符串由0~9的数字组成,问在十进制意义下\(s\)中比\(t\)串大的子序列个数。

\(m\leq n \leq{3000}\).

题解:考虑两种不同情况:子序列长度等于\(t\)串以及子序列长度大于\(t\)串。用\(len[i][j]\)维护\(s\)串中第\(i\)位以前长度为\(j\)的合法串(无前导零)个数,那么长度大于\(t\)串的个数为\(\sum_{i=m+1}^{n} len[n][i]\).

\(dp1[i][j]\)维护\(s\)串中第\(i\)位以前长度为\(j\)且严格大于\(t\)串中前\(j\)位的子序列个数, \(dp2[i][j]\)维护\(s\)串中第\(i\)位以前长度为\(j\)且大于等于\(t\)串前\(j\)位的子序列个数,很容易由\(s[i]\)\(t[j]\)的大小关系得出一系列转移方程。 于是长度等于\(t\)串的个数为\(dp1[n][m]\).

代码:

#include <bits/stdc++.h>
using namespace std;
const int mod = 998244353;
char a[3005], b[3005];
int n, m;
int dp1[3005][3005], len[3005][3005], dp2[3005][3005];
int main() {
    int T;
    cin >> T;
    while(T--) {
        scanf("%d%d", &n, &m);
        int ans1 = 0, ans2 = 0;
        int maxx = max(n, m);
        for(int i = 0; i <= maxx + 2; i++) for(int j = 0; j <= maxx + 2; j++) dp1[i][j] = dp2[i][j] = 0, len[i][j] = 0;
        scanf("%s%s", a + 1, b + 1);
        dp1[0][0] = 0;
        dp2[0][0] = 1;
        len[0][0] = 1;
        for(int i = 1; i <= n; i++) {
            len[i][0] = 1;
            for(int j = 1; j <= i; j++) len[i][j] = (len[i - 1][j - 1] + len[i - 1][j]) % mod;
            if(a[i] == '0') len[i][1] = len[i - 1][1];
            dp2[i][0] = 1;
            for(int j = 1; j <= i; j++) {
                dp1[i][j] = dp1[i - 1][j];
                dp2[i][j] = dp2[i - 1][j];
                if(a[i] <= b[j]) {
                    dp1[i][j] = (dp1[i][j] + dp1[i - 1][j - 1]) % mod;
                    if(a[i] == b[j]) dp2[i][j] = (dp2[i][j] + dp2[i - 1][j - 1]) % mod;
                    else dp2[i][j] = (dp2[i][j] + dp1[i - 1][j - 1]) % mod;
                }
                if(a[i] > b[j]) {
                    dp1[i][j] = (dp1[i][j] + dp2[i - 1][j - 1]) % mod;
                    dp2[i][j] = (dp2[i][j] + dp2[i - 1][j - 1]) % mod;
                }
                //dp2[i][j] += dp1[i - 1][j - 1];
            }
        }
        for(int i = m + 1; i <= n; i++) ans1 = (ans1 + len[n][i]) % mod;
        ans2 = dp1[n][m];
        int ans = (ans1 + ans2) % mod;
        printf("%d\n", ans);
    }
    return 0;
}
posted @ 2019-08-01 20:53  rain_star  阅读(192)  评论(0编辑  收藏  举报