洛谷 P5056 【模板】插头dp

题面：https://www.luogu.com.cn/problem/P5056

Source:

写法一：用 tr1::unordered_map 存储状态：

#include <iostream>
#include <tr1/unordered_map>
#include <cstdio>
#include <cstring>

using namespace std;

// Plug-DP globals for the unordered_map version.

// 64-bit counter type; a typedef (not a macro) so it follows normal scoping rules.
typedef long long LL;

char mp[20][20];            // input grid: '*' = obstacle, '.' = free cell
int c[4] = {0, -1, 1, 0};   // bracket-depth change per plug code used by getl/getr: 1 '(' -> -1, 2 ')' -> +1
int n, m, ex, ey;           // grid size; (ex, ey) = last free cell in row-major order
std::tr1::unordered_map<int, LL> Ht[2], T;  // Ht: rolling DP tables (state -> count); T: scratch for the per-row key shift

int now;                    // index (0/1) of the table in Ht that holds the current states
// Contour state layout: plug position k occupies bits 2k..2k+1
// (0 = no plug, 1 = '(' plug, 2 = ')' plug).
int set(int state, int x, int val) // return `state` with position x replaced by val
{
    const int shift = x * 2;
    const int mask = 3 << shift;
    return (state & ~mask) | (val << shift);
}
int get(int state, int x) // return the 2-bit plug code stored at position x
{
    const int shift = x * 2;
    return (state >> shift) & 3;
}
// Walk left from the ')' plug at position x until the bracket depth balances;
// the stop position is the matching '(' plug.
int getl(int state, int x)
{
    int pos = x;
    for (int depth = 1; depth != 0; )
        depth += c[get(state, --pos)];
    return pos;
}
// Walk right from the '(' plug at position x until the bracket depth balances;
// the stop position is the matching ')' plug.
int getr(int state, int x)
{
    int pos = x;
    for (int depth = -1; depth != 0; )
        depth += c[get(state, ++pos)];
    return pos;
}

// Push one DP state across cell (x, y): read the plug codes p (position y) and
// q (position y + 1) of `state`, then add `val` to every successor state in
// Ht[now ^ 1] (push-style DP, split by case).
// As the transitions below show, before the cell position y is the plug
// entering from the left and y + 1 the plug entering from above; after it,
// position y is the plug leaving downward and y + 1 the plug leaving rightward.
void update(int x, int y, int state, LL val) {
    int p = get(state, y), q = get(state, y + 1);
    if (mp[x][y] == '*') {// obstacle cell
        if (p == 0 && q == 0) Ht[now ^ 1][state] += val;// only a plug-free state passes through, unchanged
        return;
    }

    if (p == 0 && q == 0) {// no incoming plug: open a new segment as a '(' (down) / ')' (right) pair
        if (x == n - 1 || y == m - 1) return;// the new segment needs both a lower and a right neighbour
        int newst = set(state, y, 1);
        newst = set(newst, y + 1, 2);
        Ht[now ^ 1][newst] += val;
        return;
    }
    
    if (p == 0 || q == 0) {// exactly one incoming plug: extend it rightward or downward
        if (y < m - 1) {// extend rightward: the plug moves to position y + 1
            int newst = set(state, y, 0);
            newst = set(newst, y + 1, p + q);
            Ht[now ^ 1][newst] += val;
        }
        if (x < n - 1) {// extend downward: the plug moves to position y
            int newst = set(state, y, p + q);
            newst = set(newst, y + 1, 0);
            Ht[now ^ 1][newst] += val;
        }
        return;
    }
    
    int newst = set(state, y, 0); newst = set(newst, y + 1, 0);// two incoming plugs: join them inside this cell

    if (p == 1 && q == 1) //both '(': the ')' that matched the '(' at y + 1 becomes a '('
        newst = set(newst, getr(state, y + 1), 1);
    else if (p == 2 && q == 2) //both ')': the '(' that matched the ')' at y becomes a ')'
        newst = set(newst, getl(state, y), 2);
    else if (p == 1 && q == 2 && (x != ex || y != ey)) return;//adjacent '(' and ')' are the two ends of one segment; joining them closes the circuit, allowed only at the last free cell
    // Reaching here: the two cases above, p == 2 && q == 1 (two segments merge), or the closing move at (ex, ey).

    Ht[now ^ 1][newst] += val;
}

int main() {
#ifndef ONLINE_JUDGE
    freopen("BZOJ1814.in", "r", stdin);
#endif
    cin >> n >> m;
    for (int i = 0; i < n; ++ i) cin >> mp[i];
    for (int i = 0; i < n; ++ i)
        for (int j = 0; j < m; ++ j)
            if (mp[i][j] == '.') ex = i, ey = j;
    now = 0;
    Ht[now].clear();
    Ht[now][0] = 1;//别忘了
    for (int i = 0; i < n; ++ i) {
        //下面一部分是转移到下一行时的key<<=2
        T.clear();
        for (tr1::unordered_map<int, LL>::iterator it = Ht[now].begin(); it != Ht[now].end(); ++ it) 
            T[it->first<<2] = it->second;
        swap(T, Ht[now]);

        for (int j = 0; j < m; ++ j) {
            Ht[now ^ 1].clear();//记得转移之前清除
            for (tr1::unordered_map<int, LL>::iterator it = Ht[now].begin(); it != Ht[now].end(); ++ it) 
                update(i, j, it->first, it->second);
            now ^= 1;
        }
    }
    cout << Ht[now][0] << endl;//最后的轮廓线状态就是0
}

写法二：手写哈希表（Hash_Table）存储状态：

#include <iostream>
#include <cstdio>
#include <cstdlib>
#include <cstring>

using namespace std;

const int maxM = 200005;    // hash bucket count and node-array size of Hash_List (node 0 is the null link)

// 64-bit counter type; a typedef (not a macro) so it follows normal scoping rules.
typedef long long LL;

char mp[20][20];            // input grid: '*' = obstacle, '.' = free cell
int c[4] = {0, -1, 1, 0};   // bracket-depth change per plug code used by getl/getr: 1 '(' -> -1, 2 ')' -> +1
int n, m, ex, ey;           // grid size; (ex, ey) = last free cell in row-major order

// Chained hash table mapping a bit-packed plug state to its accumulated count.
// Nodes live in data[1..cnt] (index 0 terminates a chain), so callers can also
// walk every stored state linearly by index.
struct Hash_List {
    struct Node {
        int key, nxt;   // key: packed state; nxt: next node in the same bucket (0 ends the chain)
        LL val;         // number of ways accumulated for this state
    } data[maxM];
    int head[maxM], cnt;    // head[b]: first node of bucket b (0 = empty); cnt: nodes in use

    // Bucket index for key. The unsigned cast keeps the index in range even for
    // a negative key, where signed % would produce a negative array index.
    int bucket(int key) const { return static_cast<int>(static_cast<unsigned>(key) % maxM); }

    // Empty the table.
    void init() { cnt = 0; memset(head, 0, sizeof head); }

    // Add val to the count stored for key, creating the entry if it is new.
    void insert(int key, LL val) {
        int x = bucket(key);
        for (int i = head[x]; i; i = data[i].nxt) 
            if (data[i].key == key) {
                data[i].val += val;
                return;
            }
        // data[0] is reserved as the null link, so at most maxM - 1 nodes fit;
        // going past that would silently overrun the array.
        if (cnt + 1 >= maxM) {
            fprintf(stderr, "Hash_List: capacity %d exceeded\n", maxM - 1);
            abort();
        }
        // Plain aggregate initialisation: the compound literal `(Node){...}` is a
        // GNU extension, not standard C++.
        Node node = { key, head[x], val };
        data[++cnt] = node;
        head[x] = cnt;
    }

    // Return the count stored for key, or 0 when key is absent.
    LL getval(int key) const {
        for (int i = head[bucket(key)]; i; i = data[i].nxt)
            if (data[i].key == key)
                return data[i].val;
        return 0;
    }
}DP[2];

int now;    // index (0/1) of the DP table holding the current states

// Overwrite the 2-bit slot at position x of `state` with val
// (slot k = bits 2k..2k+1; 0 = no plug, 1 = '(' plug, 2 = ')' plug).
int set(int state, int x, int val) 
{
    int offset = x + x;
    int cleared = state & ~(3 << offset);
    return cleared | (val << offset);
}
// Read the 2-bit slot at position x of `state`.
int get(int state, int x) 
{
    int offset = x + x;
    return 3 & (state >> offset);
}
// Matching '(' for the ')' plug at position x: step left until the depth balances.
int getl(int state, int x) 
{
    int depth = 1;
    while (depth != 0) {
        --x;
        depth += c[get(state, x)];
    }
    return x;
}
// Matching ')' for the '(' plug at position x: step right until the depth balances.
int getr(int state, int x) 
{
    int depth = -1;
    while (depth != 0) {
        ++x;
        depth += c[get(state, x)];
    }
    return x;
}

// Push one DP state across cell (x, y): read the plug codes p (position y) and
// q (position y + 1) of `state`, then insert every successor state with weight
// `val` into DP[now ^ 1].
// As the transitions below show, before the cell position y is the plug
// entering from the left and y + 1 the plug entering from above; after it,
// position y is the plug leaving downward and y + 1 the plug leaving rightward.
void update(int x, int y, int state, LL val) {
    int p = get(state, y), q = get(state, y + 1);
    if (mp[x][y] == '*') {// obstacle cell
        if (p == 0 && q == 0) DP[now ^ 1].insert(state, val);// only a plug-free state passes through, unchanged
        return;
    }

    if (p == 0 && q == 0) {// no incoming plug: open a new segment as a '(' (down) / ')' (right) pair
        if (x == n - 1 || y == m - 1) return;// the new segment needs both a lower and a right neighbour
        int newst = set(state, y, 1);
        newst = set(newst, y + 1, 2);
        DP[now ^ 1].insert(newst, val);
        return;
    }
    
    if (p == 0 || q == 0) {// exactly one incoming plug: extend it rightward or downward
        if (y < m - 1) {// extend rightward: the plug moves to position y + 1
            int newst = set(state, y, 0);
            newst = set(newst, y + 1, p + q);
            DP[now ^ 1].insert(newst, val);
        }
        if (x < n - 1) {// extend downward: the plug moves to position y
            int newst = set(state, y, p + q);
            newst = set(newst, y + 1, 0);
            DP[now ^ 1].insert(newst, val);
        }
        return;
    }
    
    int newst = set(state, y, 0); newst = set(newst, y + 1, 0);// two incoming plugs: join them inside this cell

    // The three (p, q) conditions below are mutually exclusive.
    if (p == 1 && q == 1) // both '(': the ')' that matched the '(' at y + 1 becomes a '('
        newst = set(newst, getr(state, y + 1), 1);
    if (p == 2 && q == 2) // both ')': the '(' that matched the ')' at y becomes a ')'
        newst = set(newst, getl(state, y), 2);
    if (p == 1 && q == 2 && (x != ex || y != ey)) return; // adjacent '(' and ')' close the circuit: allowed only at the last free cell
    // Reaching here: the two cases above, p == 2 && q == 1 (two segments merge), or the closing move at (ex, ey).

    DP[now ^ 1].insert(newst, val);
}

// Reads an n x m grid ('*' = obstacle, '.' = free) and prints the number of
// Hamiltonian circuits through all free cells, using cell-by-cell plug DP with
// bracket-encoded contour states stored in two rolling hand-written hash tables.
int main() {
#ifndef ONLINE_JUDGE
    // freopen() closes stdin when it fails to open the file, so only redirect
    // when the local test file actually exists.
    if (FILE *probe = fopen("BZOJ1814.in", "r")) {
        fclose(probe);
        freopen("BZOJ1814.in", "r", stdin);
    }
#endif
    cin >> n >> m;
    for (int i = 0; i < n; ++ i) cin >> mp[i];
    bool hasFree = false;
    for (int i = 0; i < n; ++ i)
        for (int j = 0; j < m; ++ j)
            if (mp[i][j] == '.') ex = i, ey = j, hasFree = true;
    // Without any free cell there is no circuit; the untouched initial state 0
    // would otherwise survive every obstacle and the program would print 1.
    if (!hasFree) {
        cout << 0 << endl;
        return 0;
    }
    now = 0;
    DP[now].init(); DP[now].insert(0, 1);// start: empty contour, reached in exactly one way
    for (int i = 0; i < n; ++ i) {
        // Entering a new row: shift every stored state up by one position (2 bits).
        // Keys are rewritten in place without rehashing. That is safe only because
        // this table is next read purely by index (data[1..cnt]) and is then
        // re-init()ed before anything is looked up in it by key.
        for (int j = 1; j <= DP[now].cnt; ++ j) DP[now].data[j].key <<= 2;
        for (int j = 0; j < m; ++ j) {
            DP[now ^ 1].init();// the target table must start empty for this cell
            for (int k = 1; k <= DP[now].cnt; ++ k) 
                update(i, j, DP[now].data[k].key, DP[now].data[k].val);
            now ^= 1;
        }
    }
    // A finished circuit leaves no plugs, so the answer is the count of state 0.
    // The final table was never key-shifted, so its buckets are consistent.
    cout << DP[now].getval(0) << endl;
}

posted @ 2019-02-23 15:57  茶Tea  阅读(105)  评论(0)  编辑  收藏  举报