P1433 吃奶酪(旅行商问题,压缩状态dp)

传送门

题目描述:

房间里放着 n 块奶酪。一只小老鼠要把它们都吃掉,问至少要跑多少距离?老鼠一开始在 (0,0)(0,0) 点处。

做题思路:

是所有点遍历求最短路径的裸题,我只想到暴力dfs,玄学优化,就是把距离最短的先访问,然后找到一条路径后大概率是最短的,

  这样得到的距离就能用来剪枝了,可以优化掉许多生成树,但是被最后一个测试点卡了,所以无奈只有看题解啦,

题解思路:状态压缩dp,

转移方程:

dp[s][i]=min(dp[s][i],dp[s|(1<<j)][j]+dis[i][j])

s表示还没有走过的点的集合,i表示当前位置,dp[s][j]表示以j为起点,还有s集合中的位置没有走

把点的集合s压缩成二进制表示,就能进行dp了,复杂度O(n^2)*(2^n),比暴力O(n!)快的多!

为什么一般s表示没有被选择的点,而不是表示选择了的点,也就是为什么一般从(1<<n-2)开始到0推,而不从1到(1<<n-1)推,

理论上是一样的,但是实际操作起来却有一些区别,假设我们取第一种(常用的),假设递推到s是010011(假设六个点),这种状态时,

i是左边第三个0,j是左边第四个个0,此时我们需要得到011011这个状态,即得到011011这个数,在我们的方程种可以看出我们只需要

s|(1<<j)就能得到了,那假如是第二种方式呢,同样此时s是010011,(1表示选择了的点),(注意这两个s虽然值一样,代表的意义完全不一样,不要认为是一样的,

我只是用来帮助我们理解这两个方式的优劣)那么递推的时候,就需要把s状态的j这个点在二进制种的位置由1边为0,

我们假设i是左边第一个1,j是左边第二个1,按照方程,我们现在需要得到010001这个数,才能进行递推,也就是从010011得到010001,然后我们发现

只需要s-(1<<j)就行了。。。对,也很简单!,我刚以为第一种大家都在用是因为它位运算比较方便,结果我自己写着写着发现直接相减不就能行了嘛,

然后还去把这道题用第二种方式做了一下,果然A了,但是减法确实会慢一些,但区别也不大,哈哈哈,,估计你看到这里想打我的冲动都有了吧

先不要慌动手,两种代码段我都贴下面了,看了具体实现再说

暴力+剪枝90分代码:

#include<iostream>
#include<string.h>
#include<cmath>
#include<set>
#include<map>
#include<string>
#include<queue>
#include<stack>
#include<vector>
#include<bitset>
#include<algorithm>
using namespace std;
typedef long long ll;
inline int read() {
    int sum = 0, f = 1;
    char p = getchar();
    for (; !isdigit(p); p = getchar()) if (p == '-')f = -1;
    for (; isdigit(p); p = getchar())  sum = sum * 10 + p - 48;
    return sum * f;
}
const int maxn = 16;
struct node {
    double d;
    int id;
    bool operator<(node b) {
        return d < b.d;
    }
};
node v[maxn][maxn];
struct point {
    double x, y;
}p[maxn];
bool vis[maxn];
double res = 9999999999;
int n;
void dfs(int now, int last, double len) {
    if (len > res)return;//剪枝
    if (last == 0) {
        res = len;
        return;
    }
    vis[now] = 1;
    for (int i = 1; i <= n; i++) {
        if (!vis[v[now][i].id]) {
            dfs(v[now][i].id, last - 1, len + v[now][i].d);
        }
    }
    vis[now] = 0;
}
int main() {
    //freopen("test.txt", "r", stdin);
    n = read();
    for (int i = 1; i <= n; i++) {
        scanf("%lf%lf", &p[i].x, &p[i].y);
    }
    p[0].x = p[0].y = 0;
    for (int i = 0; i <= n; i++) {
        for (int j = i; j <=n; j++) {
            double d = sqrt((p[i].y - p[j].y) *(p[i].y - p[j].y) + (p[i].x - p[j].x) * (p[i].x - p[j].x));
            v[i][j] = { d,j };
            v[j][i] = { d,i };
        }
    }
    for (int i = 0; i <= n; i++) {//拍个序,正常情况前面dfs树中求出的res更小一些,这样后面的dfs树就会被剪掉许多
        sort(v[i] + 1, v[i] + n + 1);
    }
    dfs(0, n, 0);
    printf("%.2lf\n", res);
    return 0;
}

 

压状AC代码:

#include<iostream>
#include<string.h>
#include<cmath>
#include<set>
#include<map>
#include<string>
#include<queue>
#include<stack>
#include<vector>
#include<bitset>
#include<algorithm>
using namespace std;
typedef long long ll;
inline int read() {
    int sum = 0, f = 1;
    char p = getchar();
    for (; !isdigit(p); p = getchar()) if (p == '-')f = -1;
    for (; isdigit(p); p = getchar())  sum = sum * 10 + p - 48;
    return sum * f;
}
const int maxn =17;
int n;
double x[maxn], y[maxn], dp[1 << maxn][maxn], dis[maxn][maxn];
void init() {
    n = read();
    for (int i = 1; i <= n; i++) {
        scanf("%lf%lf", &x[i], &y[i]);
    }
    x[0] = y[0] = 0;
    for (int i = 0; i <= n; i++) {
        for (int j =i+1; j <= n; j++) {
            dis[i][j] =dis[j][i]= sqrt((x[i] - x[j]) * (x[i] - x[j]) + (y[i] - y[j]) * (y[i] - y[j]));
        }
    }
    memset(dp, 127, sizeof(dp));
    for (int i = 0; i <= n; i++) {
        //状态表示法2的初始化
        dp[0][i] = 0;
        //状态表示法1的初始化
        //dp[(1 << (n + 1)) - 1][i] = 0;
    }
}
double solve() {
    //状态表示法2
    for (int s =1; s<= ((1 << (n + 1)) - 1); ++s) {
        for (int i = 0; i <= n; i++) {
            for (int j = 0; j <= n; j++) {
                //if (i == j)continue;
                if (s & (1 << j))
                    dp[s][i] = min(dp[s][i], dp[s-(1 << j)][j] + dis[i][j]);
            }
        }
    }
    //状态表示法1
    /*for (int s = ((1 << (n + 1)) - 2); ~s; --s) {
        for (int i = 0; i <= n; i++) {
            for (int j = 0; j <= n; j++) {
                if (!(s & (1 << j)))
                    dp[s][i] = min(dp[s][i], dp[s | (1 << j)][j] + dis[i][j]);
            }
        }
    }
    return dp[0][0];*/
    return dp[(1<<(n+1))-1][0];
}
int main() {
    //freopen("test.txt", "r", stdin);
    init();
    printf("%.2lf\n", solve());
    return 0;
}

 

 

 

 

posted @ 2021-03-14 12:37  cono奇犽哒  阅读(210)  评论(0)    收藏  举报