51nod 1601 完全图的最小生成树计数

从高到低考虑每一位,把当前位为 \(1\) 的点集称为 \(S\), 当前位为 \(0\) 的点集称为 \(T\),那么最小生成树就是 \(S\) 的生成树 + \(T\) 的生成树 + \(S\)\(T\) 之间连一条最短的边。
前两个部分递归处理,最后一部分把 \(T\) 插入字典树,枚举 \(S\) 中的点在字典树中查找即可。
当有一些点权值相同时,就是完全图的生成树计数,根据prufer序列就是 \(n^{n-2}\) 个。

#include <bits/stdc++.h>
#define pii pair<int, int>
#define ll long long

const int N = 1e5 + 7;
const int MOD = 1e9 + 7;
const int INF = 0x3f3f3f3f;

int qp(int a, int b) {
	int ans = 1;
	while (b) {
		if (b & 1) ans = 1LL * ans * a % MOD;
		b >>= 1;
		a = 1LL * a * a % MOD;
	} 
	return ans % MOD;
}

struct Trie {
	static const int _N = N * 30;
	int ch[_N][2], tol, cnt[_N];
	void init() {
		for (int i = 0; i <= tol; i++)
			ch[i][0] = ch[i][1] = cnt[i] = 0;
		tol = 0;
	}
	void insert(int x) {
		int rt = 0;
		for (int i = 30; ~i; i--) {
			int id = x >> i & 1;
			if (!ch[rt][id]) ch[rt][id] = ++tol;
			rt = ch[rt][id];
		}
		cnt[rt]++;
	}
	std::pii query(int x) {
		int ans = 0, rt = 0;
		for (int i = 30; ~i; i--) {
			int id = x >> i & 1;
			if (ch[rt][id]) rt = ch[rt][id];
			else rt = ch[rt][id ^ 1], ans += (1 << i);
		}
		return std::pii(ans, cnt[rt]);
	}
} trie;

int n, a[N], cnt = 1, one[N], zero[N];
ll ans;

inline void M(int &a) {
	if (a >= MOD) a -= MOD;
	if (a < 0) a += MOD;
}

void solve(int l, int r, int dep) {
	if (l >= r) return;
	if (dep < 0) {
		if (r - l - 1 >= 1) cnt = 1LL * cnt * qp(r - l + 1, r - l - 1) % MOD;
		return;
	}
	int cnt1 = 0, cnt2 = 0;
	for (int i = l; i <= r; i++) {
		if (a[i] >> dep & 1) one[cnt1++] = a[i];
		else zero[cnt2++] = a[i];
	}
	trie.init();
	for (int i = 0; i < cnt1; i++)
		a[i + l] = one[i];
	for (int i = 0; i < cnt2; i++)
		a[i + l + cnt1] = zero[i], trie.insert(zero[i]);
	std::pii p(INF, 0);
	for (int i = 0; i < cnt1; i++) {
		std::pii q = trie.query(one[i]);
		if (q.first < p.first)
			p = q;
		else if (q.first == p.first)
			M(p.second += q.second);
	}
	if (p.first != INF) ans += p.first, cnt = 1LL * cnt * p.second % MOD;
	solve(l, l + cnt1 - 1, dep - 1);
	solve(l + cnt1, r, dep - 1);
}

int main() {
	scanf("%d", &n);
	for (int i = 1; i <= n; i++)
		scanf("%d", a + i);
	solve(1, n, 30);
	printf("%lld\n%d\n", ans, cnt);
	return 0;
}
posted @ 2020-02-05 21:48  Mrzdtz220  阅读(132)  评论(0)    收藏  举报