P9387 [THUPC 2023 决赛] 巧克力 题解

这篇题解会只讲怎么 dp,所以我们这里跳过博弈论的部分。

Let's rephrase the problem statement as follows:

给定 \(n,m\),设 \(x=1\oplus 2\oplus\cdots\oplus n\oplus m\)。求有多少个有序三元组 \((a,b,c)\) 满足:

  • \(a+b+c\le n\)\(a+b+c=m\)(如果都满足需要算两遍)。
  • \((a+b+c)\oplus a\oplus c=x\)
  • \(a,c\ge 0,b\gt 0\)

答案对 \(10^9+7\) 取模。

注意到 \(\forall y,y\oplus(y+1)\oplus(y+2)\oplus(y+3)=0\),这样 \(x\) 就可以快速计算了。

首先我们显然只需要解决 \(a+b+c\le n\) 的问题。

对于这种涉及到二进制的计数 dp。我们一般采取 从低位到高位 的顺序 dp,并同时记录进了多少位。另一个经典例题:NOIP 数列。

\(f_{i,j,k,l}\) 表示:考虑了 \(a,b,c\) 的最低 \(i\) 位,第 \(i\) 位往第 \(i+1\) 位进位了 \(j\)\(b\) 的低 \(i\) 位是否是 \(0\)\(a+b+c\) 的低 \(i\) 位是否大于 \(n\) 的低 \(i\) 位。转移时枚举 \(a,b,c\) 的第 \(i\) 位是什么即可。答案是 \(f_{60,0,1,0}\)

代码:

auto dp = [&](ll n) {
  if (n < 0) return 0;
  memset(f, 0, sizeof(f));
  f[0][0][0][0] = 1;
  for (int i = 0; i < N; i++)
    for (int j = 0; j <= 2; j++)
      for (int k = 0; k <= 1; k++)
        for (int l = 0; l <= 1; l++) {
          for (int a = 0; a <= 1; a++)
            for (int b = 0; b <= 1; b++)
              for (int c = 0; c <= 1; c++) {
                int s = (j + a + b + c) & 1; // 计算 (a + b + c) 的第 i 位
                if ((s ^ a ^ c) != ((x >> i) & 1)) continue;
                int nl = (s > ((n >> i) & 1)) || (l && (s == ((n >> i) & 1)));
                add(f[i + 1][(j + a + b + c) / 2][k || b][nl], f[i][j][k][l]);
              }
        }
  return f[N][0][1][0];
};

但是这么写会 T 掉,我们需要将枚举 \(a,b,c\) 的循环展开。这里可以写一个生成代码的程序,然后手动合并能合并的代码。结果如下(没有任何可读性……所以我把两份代码都放到这里了):

auto dp = [&](ll n) {
  if (n < 0) return 0;
  memset(f, 0, sizeof(f));
  f[0][0][0][0] = 1;
  for (int i = 0; i < N; i++)
    for (int j = 0; j <= 2; j++)
      for (int k = 0; k <= 1; k++)
        for (int l = 0; l <= 1; l++) {
          int s, nl;

          s = (j + 0) & 1;
          nl = (s > ((n >> i) & 1)) || (l && (s == ((n >> i) & 1)));
          if ((s ^ 0) == ((x >> i) & 1)) {
            add(f[i + 1][(j + 0) >> 1][k][nl], f[i][j][k][l]);
            add(f[i + 1][(j + 2) >> 1][k][nl], f[i][j][k][l]);
          } else {
            add(f[i + 1][(j + 2) >> 1][1][nl], f[i][j][k][l]);
            add(f[i + 1][(j + 2) >> 1][1][nl], f[i][j][k][l]);
          }

          s = (j + 1) & 1;
          nl = (s > ((n >> i) & 1)) || (l && (s == ((n >> i) & 1)));
          if ((s ^ 0) == ((x >> i) & 1)) {
            add(f[i + 1][(j + 1) >> 1][1][nl], f[i][j][k][l]);
            add(f[i + 1][(j + 3) >> 1][1][nl], f[i][j][k][l]);
          } else {
            add(f[i + 1][(j + 1) >> 1][k][nl], f[i][j][k][l]);
            add(f[i + 1][(j + 1) >> 1][k][nl], f[i][j][k][l]);
          }
        }
  return f[N][0][1][0];
};

时间复杂度 \(\mathcal O(\log n)\)

posted @ 2023-07-29 19:38  registerGen  阅读(59)  评论(0编辑  收藏  举报