[ARC 058 - E]Iroha and Haiku
解题步骤
首先可以发现题目范围非常小,尤其是\(X,Y,Z\),所以考虑类似状压、数位dp、双向搜索等算法。
官方题解中给的是数位dp,那我这里就讲讲状压了
对于\(N \leq 40\),很明显不能对其进行状压并且没意义,那么对于\(X,Y,Z\)呢?因为题目要求连续一段数满足要求,且\(X+Y+Z \leq 17,a_i \geq 1\),故只用考虑连续一段长度不超过17的的区间是否满足要求即可。
但是对于具体数值怎么状压呢?我们发现因为和不超过\(X+Y+Z\),所以我们可以对当前第\(i\)位前一段和不超过\(X+Y+Z\)的区间的每一个数\(a_j\)通过\(2^{a_j}\)表示,同时按顺序依次乘\(2^{a_i}\),如此既没改变相对顺序,又表示出数值,实现如下:
int tot = (1 << X + Y + Z) - 1;
int s = ((msk << j) | (1 << j - 1)) & tot;
/*
tot为全集
j为枚举当前位填入的元素
msk即为前一个状态元素集合
s则为加上当前元素j后的元素集合
*/
到这里,状压思路已经很清晰了,最后一个难点为判定满足条件。因为我们对于状压状态,发现一段前缀的和在不断位移的操作下变成了某一个元素在\(msk\)对应的单个\(1\),故目标可表示为(1<<X-1)|(1<<X+Y-1)|(1<<X+Y+Z-1),对于当前集合判断包不包含目标集合,如果包含,满足条件,反之不满足。
int targ = (1 << X - 1) | (1 << X + Y - 1) | (1 << X + Y + Z - 1), tot = (1 << X + Y + Z) - 1;
int s = ((msk << j) | (1 << j - 1)) & tot;
if ((s & targ) != targ)
    dp[i][s] = (dp[i][s] + dp[i - 1][msk]) % mod;
/*
targ为目标集合
dp[i][msk]表示考虑第i位,前一段和不超过X+Y+Z的区间为msk的情况下,有多少种情况不满足条件
*/
最后做个解释,为什么dp[i][msk]要表示不满足条件的区间。因为对于初始状态,不满足条件的区间都由dp[0][0]演化而来,而满足条件的区间千变万化,可以从很多种情况转移,不好把握。
最终实现
#include <bits/stdc++.h>
using namespace std;
typedef long long LL;
const int N = 45, TN = 17;
const LL mod = 1e9 + 7;
LL dp[N][1 << TN];
int X, Y, Z, n;
LL ans;
inline LL power(LL a, LL k) //可要可不要,直接O(N)把ans乘一遍也只是40
{
    LL ret = 1;
    while (k)
    {
        if (k & 1)
            ret = ret * a % mod;
        a = a * a % mod;
        k >>= 1;
    }
    return ret;
}
int main()
{
    scanf("%d%d%d%d", &n, &X, &Y, &Z);
    dp[0][0] = 1;
    ans = power(10ll, n);
    int targ = (1 << X - 1) | (1 << X + Y - 1) | (1 << X + Y + Z - 1), tot = (1 << X + Y + Z) - 1;
    //targ:目标集合 tot:全集
    for (int i = 1; i <= n; i++)
        for (int msk = 0; msk <= tot; msk++)
            for (int j = 1; j <= 10; j++)
            {
                int s = ((msk << j) | (1 << j - 1)) & tot;
                //加入当前元素后集合
                if ((s & targ) != targ) //不满足条件
                    dp[i][s] = (dp[i][s] + dp[i - 1][msk]) % mod;
            }
    //dp第一层循环一定是1~n,因为第i位是该dp的一个阶段,所有第i层的状态都依赖于上一层
    for (int msk = 0; msk <= tot; msk++)
        ans = (ans - dp[n][msk] + mod) % mod;
    printf("%lld", ans);
    return 0;
}

                
            
        
浙公网安备 33010602011771号