插头DP

前言

图片来源:我老师的PDF因为我不会画图

前排膜拜一波(

插头dp虽然模版难度就是黑,但是我认为并不难。我认为dp的难度排序:

ddp(动态DP,P4719)>分治dp>插头dp>决策单调性优化dp>网络流(类dp)>斜率优化dp>暴力dp

模版/推荐题目

P5056 【模版】插头dp

P2289 邮递员

P2337 喵星人的入侵

目的

给定一个点集$S$,求图上的一条线,要求包含$S$里包括的所有点,并且不包含图上非$S$集的点,且不能重复经过一个点。

可能不是很好理解(**不会看题吗),比如下图两条线都满足条件(白色的格子表示$S$集)。

问能画出来多少种线。上图情况答案就是2,因为除了这两种画不出别的了。

实现

如果没做过轮廓线DP,可以先做一下P4363,那是很基础的一道轮廓线DP题。

勾勒一个轮廓线,意义与P4363相同,即为已经dp过的点的下轮廓,并给轮廓线的每条边加上插头。

上图为正在计算点(3,3)的时候的轮廓线。

但是只记录是否空插头还是不够的,因为一条线既可以是区间左端,也可以是右端。

如何存储:状压括号序列即可。0表示空插头,1表示向上,2表示向下。

正确性证明:很显然,如果括号序列是交叉的,形如$\color{red}(\color{green}(\color{red})\color{green})$,那么它是一个不合法的序列(重复经过)。

例:

上图插头们的括号序列是$1120212$。

为了方便,将上图挨着正在考虑的点$5,5$的两个插头称作左插头(插头5)和上插头(插头6)。

接下来,对于点$(i,j)$,分类讨论情况。

1. 障碍:别动,跳过就行。
2. 上插头和左插头都是空:因为任何一个点都要经过,所以只能向下向右都连边。
3. 只有上插头空:下和右随便连边,因为已经连了一个了。连下、连右、连下&右都合法。(当然都不连不合法)
4. 只有左插头空:同3,只是要改方向。
5. 左、上都是起始(状态都是1):因为不能重复,这时候只能合并插头。但是合并后其他插头可能会有改变,具体实现详见代码。

6. 左、上都是终结(状态都是2):同5,只是要改方向。
7. 回路:如果已经走完了(即回路形成点在$(n,m)$),那么加进答案,否则丢弃。

到这里,插头dp就已经可以实现了。

优化

1. 一次转移显然只和上一次的答案有关,因此可以用滚动数组优化空间。
2. 优化两个状态之间的转移:哈希表。

代码

另外,哈希建议一个质数,可以减少冲突。1e6左右我找的是1e6+3。

#include <bits/stdc++.h>
using namespace std;
typedef long long ll;
const int mod = 1e6 + 3;
int n, m, ex, ey, bits[15], state[mod + 5][2];
bool mp[15][15];
ll f[mod + 5][2], ans;
int head[mod + 5], to[mod + 5], nxt[mod + 5], sz[2], qwq, tot;
inline void link(int u, int v)
{
    to[tot] = v;
    nxt[tot] = head[u];
    head[u] = tot++;
}
inline void add(int x, ll k)
{
    int key = x % mod;
    for (int i = head[key]; ~i; i = nxt[i])
        if (state[to[i]][qwq] == x)
        {
            f[to[i]][qwq] += k;
            return;
        }
    state[++sz[qwq]][qwq] = x;
    f[sz[qwq]][qwq] = k;
    link(key, sz[qwq]);
}
int main()
{
    ios::sync_with_stdio(false);
    cin.tie(nullptr);
    cin >> n >> m;
    for (int i = 1; i <= n; i++)
        for (int j = 1; j <= m; j++)
        {
            char c;
            cin >> c;
            if (c == '.')
            {
                mp[i][j] = true;
                ex = i, ey = j;
            }
        }
    for (int i = 1; i <= 12; i++)
        bits[i] = i << 1;
    sz[qwq] = 1;
    f[1][qwq] = 1, state[1][qwq] = 0;
    for (int i = 1; i <= n; i++)
    {
        for (int j = 1; j <= sz[qwq]; j++)
            state[j][qwq] <<= 2;
        for (int j = 1; j <= m; j++)
        {
            tot = 0;
            memset(head, -1, sizeof(head));
            qwq ^= 1;
            sz[qwq] = 0;
            for (int k = 1; k <= sz[qwq ^ 1]; k++)
            {
                int stt = state[k][qwq ^ 1];
                int up = (stt >> bits[j]) % 4, left = (stt >> bits[j - 1]) % 4;
                ll val = f[k][qwq ^ 1];
                if (!mp[i][j])
                    add(stt, val);
                else if (!up && !left)
                {
                    if (mp[i + 1][j] && mp[i][j + 1])
                        add(stt | 1 << bits[j - 1] | 2 << bits[j], val);
                }
                else if (left && !up)
                {
                    if (mp[i + 1][j])
                        add(stt, val);
                    if (mp[i][j + 1])
                        add(stt - left * (1 << bits[j - 1]) + left * (1 << bits[j]), val);
                }
                else if (!left && up)
                {
                    if (mp[i][j + 1])
                        add(stt, val);
                    if (mp[i + 1][j])
                        add(stt - up * (1 << bits[j]) + up * (1 << bits[j - 1]), val);
                }
                else if (left == 1 && up == 1)
                {
                    int cnt = 1;
                    for (int p = j + 1; p <= m; p++)
                    {
                        if ((stt >> bits[p]) % 4 == 1) // left plug
                            cnt++;
                        else if ((stt >> bits[p]) % 4 == 2) // right plug
                            cnt--;
                        if (!cnt) // matched
                        {
                            add(stt - (1 << bits[p]) - (1 << bits[j]) - (1 << bits[j - 1]), val);
                            break;
                        }
                    }
                }
                else if (left == 2 && up == 2)
                {
                    int cnt = 1;
                    for (int p = j - 2; p >= 0; p--)
                    {
                        if ((stt >> bits[p]) % 4 == 1) // left plug
                            cnt--;
                        if ((stt >> bits[p]) % 4 == 2) // right plug
                            cnt++;
                        if (!cnt) // matched
                        {
                            add(stt - (2 << bits[j]) - (2 << bits[j - 1]) + (1 << bits[p]), val);
                            break;
                        }
                    }
                }
                else if (left == 2 && up == 1)
                    add(stt ^ 2 << bits[j - 1] ^ 1 << bits[j], val);
                else if (left == 1 && up == 2 && i == ex && j == ey)
                    ans += val;
            }
        }
    }
    cout << ans;
    return 0;
}

 

posted @ 2022-11-29 19:56  creation_hy  阅读(264)  评论(0编辑  收藏  举报