【LeetCode 1220】统计元音字母序列的数目

题目描述

原题链接: LeetCode.1220 统计元音字母序列的数目

解题思路

  • 定义DP数组dp[i][j]含义为长度为i+1且以j字符结尾的字符串有多少个, j从0到4依次代表('a', 'e', 'i', 'o', 'u')这5个元音字符, dp[0][0~4]长度为1时的初始个数都为1;
  • dp[i][j]对应字符串末尾字符已经由j确定, 对应个数就要看dp[i-1]中有哪些字符结尾的字符串能追加1个j字符得到, 按照题意总结递推公式:
    • 只有['e', 'i', 'u']的后面能跟着'a', 所以\(dp[i][0] = dp[i-1][1] + dp[i-1][2] + dp[i-1][4]\)
    • 只有['a', 'i']的后面能跟着'e', 所以\(dp[i][1] = dp[i-1][0] + dp[i-1][2]\)
    • 只有['e', 'o']的后面能跟着'i', 所以\(dp[i][2] = dp[i-1][1] + dp[i-1][3]\)
    • 只有['i']的后面能跟着'o', 所以\(dp[i][3] = dp[i-1][2]\)
    • 只有['i', 'o']的后面能跟着'u', 所以\(dp[i][4] = dp[i-1][2] + dp[i-1][3]\)
  • 上述5维DP解法的时间复杂度仍然有\(O(n)\)规模, 对于这种公式固定的线性递推可以借助矩阵快速幂技巧得到更快速的解法:
    • 设dp[i]对应5个序列数目值的矩阵为\(A_i=\begin{pmatrix}dp_i[0]&dp_i[1]&dp_i[2]&dp_i[3]&dp_i[4]\end{pmatrix}\), 存在矩阵B满足\(A_{i-1} * B = A_i\), 求dp[i]的值就转变为由\(A_0 * B^{i-1}\)也就是求B矩阵幂的问题;
    • 按行或者按列都能快速确定本题中的递推矩阵为\(B = \begin{pmatrix}0&1&0&0&0\\1&0&1&0&0\\1&1&0&1&1\\0&0&1&0&1\\1&0&0&0&0\end{pmatrix}\)
    • 由题意可以按行轻松确定递推矩阵的值:
      • 'a'后面只能跟着'e'可以确定第一行只有第二列为1;
      • 'e'后面只能跟着'a'或'i'可以确定第二行只有第一和第三列为1;
      • 'i'后面只禁止跟着'i'可以确定第三行除了第二列全是1;
      • 'o'后面只能跟着'i'或'u'可以确定第四行只有第三和第五列为1;
      • 'u'后面只能跟着'a'可以确定第五行只有第一列为1。
    • 由上述k维1阶递推公式也可以按列直接写出递推矩阵,每列的值都对应公式中的常数项1或0, 这里直接给出矩阵值对照理解一下:\(\begin{pmatrix} 0&1&0&0&0\\ 1&0&1&0&0\\ 1&1&0&1&1\\ 0&0&1&0&1\\ 1&0&0&0&0 \end{pmatrix}\)
    • 时间复杂度可以优化到\(O(5^3 * log_2n)\)

解题代码

  • 朴素5维1阶动态规划解法

      final int MOD = 1_000_000_007;
    
      /**
       * 按照规则总结5维DP的递推公式求解
       * 执行用时: 19 ms , 在所有 Java 提交中击败了 41.18% 的用户
       * 内存消耗: 43.46 MB , 在所有 Java 提交中击败了 36.98% 的用户
       */
      public int countVowelPermutation(int n) {
          // dp[i][j]: 表示长度为i+1且以j字母结尾的字符串的个数, 其中j=0, 1, 2, 3, 4依次表示'a', 'e', 'i', 'o', 'u'字母
          int[][] dp = new int[n][5];
          Arrays.fill(dp[0], 1);
          /*
          * 根据题意中规则, 总结这5个元音字母结尾时合规的上一位字母有哪些
          * dp[i]['a']: dp[i-1]['e'] + dp[i-1]['i'] + dp[i-1]['u']
          * dp[i]['e']: dp[i-1]['a'] + dp[i-1]['i']
          * dp[i]['i']: dp[i-1]['e'] + dp[i-1]['o']
          * dp[i]['o']: dp[i-1]['i']
          * dp[i]['u']: dp[i-1]['i'] + dp[i-1]['o']
          */
          for (int i = 1; i < n; i++) {
              int[] pre = dp[i - 1];
              dp[i][0] = ((pre[1] + pre[2]) % MOD + pre[4]) % MOD;
              dp[i][1] = (pre[0] + pre[2]) % MOD;
              dp[i][2] = (pre[1] + pre[3]) % MOD;
              dp[i][3] = pre[2];
              dp[i][4] = (pre[2] + pre[3]) % MOD;
          }
          int res = 0;
          for (int i = 0; i < 5; i++) {
              res = (res + dp[n - 1][i]) % MOD;
          }
          return res;
      }
    
  • 时间复杂度最优的矩阵快速幂解法

      final int MOD = 1_000_000_007;
    
      /**
       * 矩阵快速幂加速解法
       * 执行用时: 1 ms , 在所有 Java 提交中击败了 100.00% 的用户
       * 内存消耗: 39.81 MB , 在所有 Java 提交中击败了 58.06% 的用户
       */
      public int countVowelPermutation(int n) {
          // 长度为1的元音字母序列个数各有1种
          int[][] start = {{1, 1, 1, 1, 1}};
          // 递推矩阵可以直接由题意按行确认, 每个字母后面允许跟着的字母对应列的值为1组成一行
          int[][] transfer = {
                  {0, 1, 0, 0, 0},
                  {1, 0, 1, 0, 0},
                  {1, 1, 0, 1, 1},
                  {0, 0, 1, 0, 1},
                  {1, 0, 0, 0, 0}};
          int[][] resMatrix = multiply(start, power(transfer, n - 1, MOD), MOD);
          long ans = 0;
          for (int num : resMatrix[0]) {
              ans = (ans + num) % MOD;
          }
          return (int) ans;
      }
    
      public int[][] power(int[][] base, int n, int mod) {
          int row = base.length;
          int[][] res = new int[row][row];
          for (int i = 0; i < row; i++) {
              res[i][i] = 1;
          }
          while (n > 0) {
              if ((n & 1) == 1) {
                  res = multiply(res, base, mod);
              }
              base = multiply(base, base, mod);
              n >>= 1;
          }
          return res;
      }
    
      public int[][] multiply(int[][] a, int[][] b, int mod) {
          int m = a.length, n = b[0].length;
          int[][] res = new int[m][n];
          int k = b.length;
          for (int aRow = 0; aRow < m; aRow++) {
              for (int bCol = 0; bCol < n; bCol++) {
                  long sum = 0;
                  for (int i = 0; i < k; i++) {
                      sum = (sum + (long) a[aRow][i] * b[i][bCol]) % mod;
                  }
                  res[aRow][bCol] = (int) sum;
              }
          }
          return res;
      }
    

参考链接:
左程云大佬的B站算法讲解视频01:36:00开始

posted on 2024-03-23 16:49  真不如烂笔头  阅读(2)  评论(0编辑  收藏  举报

导航