分组(状压dp+技巧:快速枚举子集)
有n个物品,把n个物品分成若干组,如果第i个物品和第j个物品分在同一组的话,那么你会得到a[i][j]个金币((i,j)和(j,i)的贡献只算一次),问如何分组得到的金币最多。
输入格式
第一行一个数n
接下来n行,每行n个数,表示a[i][j]
1<=n<=16,-1e9<=a[i][j]<=1e9,a[i][i]=0,a[i][j]=a[j][i]
输出格式
一个整数
输入/输出例子1
输入:
3
0 10 20
10 0 -100
20 -100 0
输出:
20
输入/输出例子2
输入:
2
0 -10
-10 0
输出:
0
输入/输出例子3
输入:
4
0 1000000000 1000000000 1000000000
1000000000 0 1000000000 1000000000
1000000000 1000000000 0 -1
1000000000 1000000000 -1 0
输出:
4999999999
样例解释
无
一个小知识,也是考场上不会的,虽然对这题没影响:
比如一个状态S,S=0010,代表的是第二位取了,而不是第二位
一般是按倒序来的!!(毕竟代码里面也是 S & (1<<i)
4 3 2 1
0 0 1 0
代表2选了
这题跟前i个物品是没得关系的,我们可以直接定一个状态 S,表示整个序列选择了的物品。
那么状态 f[S]就出来了,表示 选的状态是S时的最大价值
为了方便,我们定 val(S),表示选的物品为S的时候的价值
转移也很简单,我们举个例看看
f(7)=f(4)+val1,2 = f(3)+val3 = f(1)+val2,3 =........
形象的:
f(111)=f(100)+val(011) = f(011)+val(100) = f(001)+val(110) =.....
也就是说,一个状态S,等于这个状态的子集S2,加上S有的物品S2却没有的物品的权值
那么S有的物品S2却没有,可以用异或表示:S^S2,因为只要S,S2同一位不同,就证明S有S2没有,那么为什么S一定有,而不是S2由S没用?因为S2是S的子集
那么现在我们要求出val
这个预处理即可,还是很好搞的。
也就是枚举val的第i位和第j位,然后i,j两个物品的权值相乘后加入val即可。
我们现在处理 f 数组,首先枚举一个S,然后枚举一个S2。
正常枚举,用判断子集的方法去枚举子集
判断子集:
s2 & s == s2
s2是s的子集
就是O(4^n),肯定是炸掉了(枚举S,2^n,枚举S2,2^n)
那么考虑优化
对于S,是必须枚举的,没得优化,我们关注S2
发现把S2所有状态都搞了,但很多都没用,可以在枚举上进行优化,尽量只枚举有用的
这里先给出
O(3^n)的枚举子集方法:S2=(S2-1)&S
感性分析下,首先最后按位与上S,确保了它一定是S的子集,然后S2不断递减,保证了不会漏掉每一个状态,肯定也不会重复一个状态
举个例:
s=7,s2变化过程:
111
110
101
100
011
010
001
发现s2每次都是s的子集且没有漏
复杂度的证明:
我们分类讨论:
对集合中枚举的每一个子集的1的个数进行讨论(这样其实就是讨论的每个子集是多少,因为确定了1后,也就确定了0)
假设枚举 1~2^n-1
对出现的子集进行分类:
举个例子,假设出现k个1,那么这k个1可以从n个1里面选,看看具体选的哪些1,选出来后,有k个位置给你放,也就是2^k,也就是 C(n, k)*2^k
0个1:C(n, 0)*2^0
1个1:C(n, 1)*2^1
2个1:C(n, 2)*2^2
.......
n个1:C(n, n)*2^n
累加后,也就是
C(n, 0)*2^0 + C(n, 1)*2^1 + C(n, 2)*2^2 + ... + C(n, n)*2^n
转换一下:
C(n, 0)*2^0 * 1^n + C(n, 1)*2^1 * 1^(n-1) + C(n, 2)*2^2 * 1^(n-2) + ... + C(n, n)*2^n * 1^(n-n)
根据二项式定理可得:
原式=(1+2)^n=3^n
理解完如何枚举子集后,就做出来了。
#include <bits/stdc++.h> using namespace std; const int N=20, M=135000; int n, a[N][N], m=0; long long g[M], f[M]; int main() { scanf("%d", &n); m=(1<<n)-1; for (int i=1; i<=n; i++) for (int j=1; j<=n; j++) scanf("%d", &a[i][j]); for (int s=1; s<=m; s++) for (int i=1; i<=n; i++) if (s & (1<<(i-1))) for (int j=i+1; j<=n; j++) if (s & (1<<(j-1))) g[s]+=a[i][j]; for (int s=1; s<=m; s++) f[s]=g[s]; for (int s=1; s<=m; s++) for (int s2=s; s2>=1; s2=(s2-1)&s) f[s]=max(f[s], f[s2]+g[s^s2]); printf("%lld", f[m]); return 0; }