【BZOJ 1065】【Vijos 1826】【NOI 2008】奥运物流

http://www.lydsy.com/JudgeOnline/problem.php?id=1065
https://vijos.org/p/1826
好难的题啊TWT
∈我这辈子也想不出来系列~
题解详见2009年的论文徐源盛《对一类动态规划问题的研究》。


这道题我第一眼一脸懵逼。。。
然后持续懵逼。。。
最后看论文里的题解

我再来重复一下(论文绝对比我讲得详细QAQ)
首先\(R(i)=C_i+k\sum_\limits{j=1}^wR(p_j)\)
如果是单纯的一棵树,那么从孩子传到父亲非常方便。但是根节点1处会有向其他点的边形成环,这样怎么统计?
根据题解:把所有点的递推式写出来,然后迭代一下,把环上的点的R值全用1的R值和其他点的C值表示,然后把R(1)提出来得到:$$R(1)=\frac{\sum_\limits{j=1}nC_i*k{d(i,1)}}{1-k^{len}}$$
这样环长len确定时,只要分子尽可能大就可以了,这是第一步。
光第一步我就gg了QwQ

又因为\(k<1\),所以如果一个点的C值要对R(1)贡献尽可能的大,那么如果修改这个点,一定直接把这个点连向1,否则肯定不是最优的。
那么类似IOI2015河流那道题,子节点的状态的改变会对它的祖先的状态值有影响,需要新增状态来记录未来可能发生的情况。
设用\(f(i,j,d)\)表示以点i为根的树中,修改j次,且点i到1的距离为d的最大值。
为了方便转移令\(g(i,j,d)=max\{f(i,j,d+1),f(i,j,1)\}\)
\(f(i,j,d)=max\{g(s_1,j_1,d)+g(s_2,j_2,d)+\dots +g(s_t,j_t,d)\}+c(i)*k^d\)
\(s_1,s_2\dots s_t\)为i的t个儿子。
如果i节点不修改后继,\(j_1+j_2+\dots +j_t=j\)
如果i节点修改后继为1,\(j_1+j_2+\dots +j_t=j-1\),且\(d=1\)
为了避免转左儿子右兄弟,f的转移可以用FF数组来加速,\(FF(i,j)\)表示前i个孩子分配j次修改能得到的最大贡献。

套在最外面的一步是枚举环的长度。因为最终树的环上的点一定是原先的树中环上的点(环上的点是哪些不重要,重要的是环的长度,用原先环上的点可以节省修改次数给其他节点用)。
扫出以1为首节点的环(如图1,2,3,4,5,6)

然后枚举最终树上环的长度len,如图蓝线描出了最终的环,此时len为4

4的后继边指向了1。
然后对1到4所有节点为根的外向树进行dp,注意此时5,6及其外向树也应算作1的外向树。
如图,此时最优答案为\(f(1,j_1,0)+f(2,j_2,3)+f(3,j_3,2)+f(4,j_4,1)\)
其中\(\sum j=m-1\),(如果len为原环长,\(\sum j=m\)
再用dp算一下当前环长下的正确答案就可以啦。
时间复杂度\(O(n^3m^2)\),虽然看起来不可过,不过可以过的~

#include<cstdio>
#include<cstring>
#include<iostream>
#include<algorithm>
using namespace std;
const int N = 63;

struct node {int nxt, to;} E[N];

int n, m, T[N], point[N], cnt, cir[N], ctot = 0, a[N], atot, deep[N];
double k, f[N][N][N], g[N][N][N], FF[N][N], C[N], ipow[N];
bool vis[N];

void ins(int u, int v) {E[++cnt] = (node) {point[u], v}; point[u] = cnt;}

void pre_dfs(int x) {
	for (int i = point[x]; i; i = E[i].nxt)
		if (!vis[E[i].to])
			deep[E[i].to] = deep[x] + 1, pre_dfs(E[i].to);
}

void dp(int x) {
	for (int i = point[x]; i; i = E[i].nxt)
		if (!vis[E[i].to])
			dp(E[i].to);
	
	atot = 0;
	for (int i = point[x]; i; i = E[i].nxt)
		if (!vis[E[i].to]) a[++atot] = E[i].to;
	
	for (int d = deep[x]; d >= 1; --d) {
		for (int i = 1; i <= atot; ++i)
			for (int l = 0; l <= m; ++l)
				FF[i][l] = 0;
		for (int i = 1; i <= atot; ++i)
			for (int l = 0; l <= m; ++l)
				for (int nowl = 0; nowl <= l; ++nowl)
					FF[i][l] = max(FF[i][l], FF[i - 1][l - nowl] + g[a[i]][nowl][d]);
		for (int j = 0; j <= m; ++j)
			f[x][j][d] = FF[atot][j] + C[x] * ipow[d];
	}
	
	if (x == 1) {
		for (int i = 1; i <= atot; ++i)
			for (int l = 0; l <= m; ++l)
				FF[i][l] = 0;
		for (int i = 1; i <= atot; ++i)
			for (int l = 0; l <= m; ++l)
				for (int nowl = 0; nowl <= l; ++nowl)
					FF[i][l] = max(FF[i][l], FF[i - 1][l - nowl] + g[a[i]][nowl][0]);
		for (int j = 0; j <= m; ++j)
			f[x][j][0] = FF[atot][j] + C[x];
	} else {
		for (int j = 1; j <= m; ++j)
		f[x][j][1] = max(f[x][j][1], FF[atot][j - 1] + C[x] * k);
	}
	
	for (int d = deep[x] - 1; d >= 0; --d) {
		g[x][0][d] = f[x][0][d + 1];
		for (int j = 1; j <= m; ++j)
			g[x][j][d] = max(f[x][j][d + 1], f[x][j - 1][1]);
	}
}

int main() {
	scanf("%d%d%lf", &n, &m, &k);
	for (int i = 1; i <= n; ++i) {
		scanf("%d", T + i);
		ins(T[i], i);
	}
	for (int i = 1; i <= n; ++i)
		scanf("%lf", C + i);
	ipow[0] = 1;
	for (int i = 1; i <= n; ++i)
		ipow[i] = ipow[i - 1] * k;
	
	int tmp = T[1]; cir[++ctot] = 1;
	while (tmp != 1) {
		cir[++ctot] = tmp;
		tmp = T[tmp];
	}
	
	double ans = 0, ret;
	vis[1] = true;
	for (int len = 2; len <= ctot; ++len) {
		vis[cir[len]] = true;
		pre_dfs(1);
		cnt = len;
		for (int i = 2; i <= len; ++i)
			deep[cir[i]] = --cnt, pre_dfs(cir[i]);
		memset(f, 0, sizeof(f));
		memset(g, 0, sizeof(g));
		
		for (int i = 1; i <= len; ++i)
			dp(cir[i]);
		
		cnt = len;
		memset(FF, 0, sizeof(FF));
		for (int i = 0; i <= m; ++i)
			FF[1][i] = f[1][i][0];
		for (int i = 2; i <= len; ++i) {
			--cnt;
			for (int j = 0; j <= m; ++j)
				for (int nowl = 0; nowl <= j; ++nowl)
					FF[i][j] = max(FF[i][j], FF[i - 1][j - nowl] + f[cir[i]][nowl][cnt]);
		}
		if (len < ctot) ans = max(ans, FF[len][m - 1] / (1.0 - ipow[len]));
		else ans = max(ans, FF[len][m] / (1.0 - ipow[len]));
//		printf("%.2lf\n", f[1][1][0]);
	}
	
	printf("%.2lf\n", ans);
	
//	for (int i = 1; i <= ctot; ++i)
//		printf("%d\n", cir[i]);
	
	return 0;
}
posted @ 2016-10-27 20:15  abclzr  阅读(250)  评论(0编辑  收藏  举报