Codeforces Round #511 (Div. 1) C. Region Separation(dp + 数论)

题意

一棵 \(n\) 个点的树,每个点有权值 \(a_i\) 。你想砍树。

你可以砍任意次,每次你选择一些边断开,需要满足砍完后每个连通块的权值和是相等的。求有多少种砍树方案。

\(n \le 10^6, a_i \le 10^9\)

题解

先假设只砍一次。令所有点权和为 \(S\) ,那么假设要砍成 \(k\) 个连通块,则每个连通块的权值和均为 \(\displaystyle \frac{S}{k}\)

考虑如何得到砍的方案,以 \(1\) 号点为根 \(dfs\) ,若当前点 \(i\) 的子树之和 \(\frac{S}{k} | \displaystyle sum_i\) ,则当前子树可以砍下来。若最后恰好砍了 \(k\) 次,那么就得到了一个合法的砍树方案。

其实这就等价于 \(\displaystyle \sum_{i=1}^{n} [\frac{S}{k} | sum_i] = k\)

不难看出这个对应且仅对应一种方案。如果不足 \(k\) ,那么就没有那么多个点可以分;多于 \(k\) 的情况是不可能的,因为总和不够分配。

这个式子还不够优秀,我们转化一下:

\[\begin{align} [\frac{S}{k}|sum_i] &= [S | k \times sum_i] \\ &= [\frac{S}{\gcd(S,sum_i)}|k \times \frac{sum_i}{\gcd(S,sum_i)}] \\ &\because \frac{S}{\gcd(S,sum_i)} \bot \frac{sum_i}{\gcd(S,sum_i)} \\ &= [\frac{S}{\gcd(S,sum_i)} | k] \end{align} \]

然后就变成

\[\sum_{i = 1}^{n} [\frac{S}{\gcd(S,sum_i)} | k] = k \]

显然这个我们可以枚举倍数在 \(O(n \ln n)\) 的时间内解决(注意 \(k \le n\)

那么如果砍多次呢?可以看出如果第一次砍成了 \(x\) 块,那么第二次砍成的块数 \(y\) 必须满足 \(x|y\)

因为你之后的权值只能比之前分的更多,且每个联通块的权值是之前的一个因子。

这部分也可以 \(O(n \ln n)\) 算。

总结

熟悉这种分成很多块有关于 \(O(\ln n)\) 复杂度的东西就行啦qwq

代码

#include <bits/stdc++.h>

#define For(i, l, r) for(register int i = (l), i##end = (int)(r); i <= i##end; ++i)
#define Fordown(i, r, l) for(register int i = (r), i##end = (int)(l); i >= i##end; --i)
#define Set(a, v) memset(a, v, sizeof(a))
#define Cpy(a, b) memcpy(a, b, sizeof(a))
#define debug(x) cout << #x << ": " << (x) << endl
#define DEBUG(...) fprintf(stderr, __VA_ARGS__)

using namespace std;

typedef long long ll;

template<typename T> inline bool chkmin(T &a, T b) {return b < a ? a = b, 1 : 0;}
template<typename T> inline bool chkmax(T &a, T b) {return b > a ? a = b, 1 : 0;}

inline int read() {
	int x(0), sgn(1); char ch(getchar());
	for (; !isdigit(ch); ch = getchar()) if (ch == '-') sgn = -1;
	for (; isdigit(ch); ch = getchar()) x = (x * 10) + (ch ^ 48);
	return x * sgn;
}

void File() {
#ifdef zjp_shadow
	freopen ("C.in", "r", stdin);
	freopen ("C.out", "w", stdout);
#endif
}

const int N = 1e6 + 1e3;

bitset<N> pass;

ll sum[N], dp[N]; int n, fa[N];

int main () {

	File();

	n = read();
	For (i, 1, n) sum[i] = read();
	For (i, 2, n) fa[i] = read();
	Fordown (i, n, 1) sum[fa[i]] += sum[i];

	For (i, 1, n) {
		ll tmp = sum[1] / __gcd(sum[1], sum[i]);
		if (tmp <= n) ++ dp[tmp];
	}

	Fordown (i, n, 1) if (dp[i])
		for (int j = i * 2; j <= n; j += i) dp[j] += dp[i];

	For (i, 1, n)
		pass[i] = (dp[i] == i && !(sum[1] % i)), dp[i] = 0;
	dp[1] = pass[1];

	ll ans = 0;
	For (i, 1, n) if (pass[i]) {
		for (int j = i * 2; j <= n; j += i) 
			if (pass[j]) dp[j] += dp[i];
		ans += dp[i];
	}
	printf ("%lld\n", ans);

	return 0;

}
posted @ 2018-10-11 16:46  zjp_shadow  阅读(506)  评论(0编辑  收藏  举报