CodeChef SADPAIRS:Chef and Sad Pairs

vjudge
首先显然要建立圆方树
对每一种(同组的)点分别建立虚树,单独考虑这一种点的贡献:对于本身就在虚树上的点,直接统计它的各个儿子子树之间(以及它自身与儿子子树之间)被断开的点对数
否则考虑虚树上的一条边 \((u, v)\),其中 \(u\) 为父亲:设 \(v\) 子树内该种点的个数为 \(y\),同一连通块内其余该种点的个数为 \(x\)
删除圆方树上 \(u\) 到 \(v\) 路径上的任意一个点(不包括 \(u\))都会产生 \(x\times y\) 的贡献,在圆方树上做树上差分,最后求子树和即可

# include <bits/stdc++.h>
using namespace std;
typedef long long ll;

// Buffered reader/writer over stdin/stdout.
// Input is pulled in chunks of `maxn` bytes; output is collected in `obuf`
// and written when the buffer fills up or when Flush() is called.
namespace IO {
	const int maxn(1 << 21 | 1);

	// ibuf/obuf: raw buffers; [iS, iT) is the unread part of ibuf,
	// oS is the write cursor in obuf and oT the position that triggers a flush.
	// c/f: last character read and sign of the number being parsed.
	// st/tp: digit stack used by Out.
	char ibuf[maxn], obuf[maxn], *iS, *iT, *oS = obuf, *oT = obuf + maxn - 1, c, st[66];
	int tp, f;
	// True when the most recent refill attempt returned no data (end of input).
	bool iEof = false;

	// Returns the next input byte, refilling the buffer when it is exhausted.
	// At end of input it returns EOF (converted to char) and sets iEof, so callers
	// can tell a real byte from the sentinel regardless of char signedness.
	inline char Getc() {
		if (iS == iT) {
			iT = (iS = ibuf) + fread(ibuf, 1, maxn, stdin);
			iEof = (iS == iT);
			if (iEof) return static_cast<char>(EOF);
		}
		return *iS++;
	}

	// Reads a decimal integer into x, skipping leading non-digit characters.
	// A '-' immediately before the digits makes the result negative.
	// If the input ends before any digit is found, x is set to 0 instead of
	// spinning forever on the EOF sentinel.
	template <class Int> inline void In(Int &x) {
		f = 1, x = 0;
		for (c = Getc(); !iEof && (c < '0' || c > '9'); c = Getc()) f = c == '-' ? -1 : 1;
		for (; !iEof && c >= '0' && c <= '9'; c = Getc()) x = x * 10 + (c ^ 48);
		x *= f;
	}

	// Writes the buffered output to stdout and rewinds the write cursor.
	inline void Flush() {
		fwrite(obuf, 1, oS - obuf, stdout);
		oS = obuf;
	}

	// Appends one character to the output buffer, flushing when it is full.
	inline void Putc(char c) {
		*oS++ = c;
		if (oS == oT) Flush();
	}

	// Writes x in decimal (with a leading '-' for negative values).
	// Digits are taken as |x % 10| without ever negating x, so the most negative
	// value of a signed type is printed correctly instead of overflowing.
	template <class Int> void Out(Int x) {
		if (!x) {
			Putc('0');
			return;
		}
		if (x < 0) Putc('-');
		while (x) {
			// For negative x the remainder is <= 0 (division truncates toward zero).
			int digit = static_cast<int>(x % 10);
			if (digit < 0) digit = -digit;
			st[++tp] = static_cast<char>('0' + digit);
			x /= 10;
		}
		// Digits were pushed least-significant first; pop to emit them in order.
		while (tp) Putc(st[tp--]);
	}
}

using IO :: In;
using IO :: Out;
using IO :: Putc;
using IO :: Flush;

// Capacity of the per-node arrays. Node indices cover the original vertices
// 1..n plus one extra node per biconnected component created in Tarjan
// (indices n+1..tot).
const int maxn(4e5 + 5);

// n, m   - vertex and edge counts read in main
// g[v]   - group of vertex v as read from input (extra component nodes keep 0)
// size / son / fa / top - subtree size, heavy child, parent and chain head in
//          the e2 forest (filled by Dfs1 / Dfs2)
// tot    - index of the last node created so far (starts at n)
// d[v]   - number of incoming e3 edges of v; 0 marks a virtual-tree root
// f[v]   - number of vertices of the current group inside v's virtual subtree
// ng     - group currently being processed in main
// NOTE(review): a global named `size` can be ambiguous with std::size when the
// file is built as C++17 together with `using namespace std` - confirm the
// intended language standard.
int n, m, g[maxn], size[maxn], son[maxn], fa[maxn], top[maxn], tot, d[maxn], f[maxn], ng;
// dfn / low - Tarjan discovery time and low-link; dfn is later overwritten by
//             Dfs2 with the visiting order in the e2 forest
// st / tp   - stack and its height (used by Tarjan and by the virtual-tree build)
// idx       - running timestamp counter
// deep[v]   - depth of v in the e2 forest
// id / cnt  - 1-based work list of the nodes of the current virtual tree
// sz / nsz  - current-group vertices in already handled components / in the
//             component being handled
// dsu       - union-find parent array over the original vertices
int dfn[maxn], low[maxn], st[maxn], tp, idx, deep[maxn], id[maxn], cnt, sz, nsz, dsu[maxn];
// e1 - adjacency of the input graph, e2 - adjacency of the block-cut forest,
// e3 - child lists of the current virtual tree, team[x] - vertices of group x
vector <int> e1[maxn], e2[maxn], e3[maxn], team[maxn * 3];
// ans[v] - per-node result accumulated over all groups
// cv[v]  - difference values that Solve turns into subtree sums
// add    - amount added to every printed answer (pairs lying in different
//          connected components)
ll ans[maxn], cv[maxn], add;

// Inserts the undirected edge u - v into the input graph e1.
inline void Add1(int u, int v) {
	e1[u].push_back(v);
	e1[v].push_back(u);
}

// Inserts the undirected edge u - v into the block-cut forest e2.
inline void Add2(int u, int v) {
	e2[u].push_back(v);
	e2[v].push_back(u);
}

// Inserts the directed edge u -> v into the virtual tree e3 and counts it
// as an incoming edge of v.
inline void Add3(int u, int v) {
	e3[u].push_back(v);
	++d[v];
}

// DFS over e1 that finds biconnected components.
// Every time a child v satisfies low[v] >= dfn[u], a new node `tot` is created
// and linked (via Add2) to u and to every vertex popped from the stack down to
// and including v.
void Tarjan(int u) {
	dfn[u] = low[u] = ++idx;
	st[++tp] = u;
	for (auto v : e1[u]) {
		if (dfn[v]) {
			// Already visited: only the low-link can improve.
			low[u] = min(low[u], dfn[v]);
			continue;
		}
		Tarjan(v);
		low[u] = min(low[u], low[v]);
		if (low[v] < dfn[u]) continue;
		// v's subtree cannot reach above u: close one component here.
		++tot;
		for (;;) {
			int member = st[tp--];
			Add2(member, tot);
			if (member == v) break;
		}
		Add2(u, tot);
	}
}

// First pass over the e2 forest: records parent, depth and subtree size of
// every node below u and picks the child with the largest subtree as son[u].
// ff is the node we arrived from (0 for a root).
void Dfs1(int u, int ff) {
	size[u] = 1;
	for (auto v : e2[u]) {
		if (v == ff) continue;
		fa[v] = u;
		deep[v] = deep[u] + 1;
		Dfs1(v, u);
		size[u] += size[v];
		// son[u] == 0 initially, and size[0] stays 0, so the first child always wins.
		if (size[v] > size[son[u]]) son[u] = v;
	}
}

// Second pass over the e2 forest: assigns the chain head top[u] and the visit
// order dfn[u]. The heavy child continues the current chain and is visited
// first; every other unvisited neighbour starts a chain of its own.
// Note: the parameter tp (chain head) hides the global stack pointer tp here.
void Dfs2(int u, int tp) {
	dfn[u] = ++idx;
	top[u] = tp;
	if (son[u]) Dfs2(son[u], tp);
	for (auto v : e2[u]) {
		// top[v] != 0 means v was already visited (parent or heavy child).
		if (top[v]) continue;
		Dfs2(v, v);
	}
}

// Returns the representative of x's set in dsu and makes every node on the
// walked path point directly at that representative.
inline int Find(int x) {
	int root = x;
	while (dsu[root] != root) root = dsu[root];
	// Second walk: re-attach the nodes on the path to the root.
	while (dsu[x] != root) {
		int next = dsu[x];
		dsu[x] = root;
		x = next;
	}
	return root;
}

// Ordering predicate: nonzero when x comes before y in dfn order.
inline int Cmp(int x, int y) {
	return dfn[y] > dfn[x];
}

// Lowest common ancestor of u and v in the e2 forest, found by jumping chain
// heads. Callers are expected to pass nodes of the same tree (main checks this
// with Find before calling).
inline int LCA(int u, int v) {
	while (top[u] != top[v]) {
		// Lift whichever node sits on the chain whose head is deeper.
		if (deep[top[u]] > deep[top[v]]) u = fa[top[u]];
		else v = fa[top[v]];
	}
	// Same chain now: the shallower node is the ancestor.
	if (deep[u] < deep[v]) return u;
	return v;
}

// Adds to nsz the number of nodes in u's virtual subtree (e3) whose group
// equals the current group ng.
void Getsz(int u) {
	nsz += (g[u] == ng) ? 1 : 0;
	for (auto child : e3[u]) {
		Getsz(child);
	}
}

void Calc(int u) {
	if (g[u] == ng) f[u] = 1;
	for (auto v : e3[u]) Calc(v), ans[u] += (ll)f[u] * f[v], f[u] += f[v];
	for (auto v : e3[u]) cv[v] += (ll)(nsz - f[v]) * f[v], cv[u] -= (ll)(nsz - f[v]) * f[v];
}

// Turns the difference values in cv into subtree sums over the e2 forest and
// adds each node's sum to its answer. ff is the node we arrived from.
void Solve(int u, int ff) {
	for (auto v : e2[u]) {
		if (v == ff) continue;
		Solve(v, u);
		cv[u] += cv[v];
	}
	ans[u] += cv[u];
}

int main() {
	int i, j, u, v, l;
	In(n), In(m), tot = n;
	for (i = 1; i <= n; ++i) In(g[i]), team[g[i]].push_back(i), dsu[i] = i;
	for (i = 1; i <= m; ++i) {
		In(u), In(v), Add1(u, v);
		if (Find(u) ^ Find(v)) dsu[Find(u)] = Find(v);
	}
	for (i = 1; i <= n; ++i) if (!dfn[i]) Tarjan(i);
	for (idx = 0, i = 1; i <= tot; ++i) if (!size[i]) Dfs1(i, 0), Dfs2(i, i);
	for (i = 1; i <= 1000000; ++i)
		if (l = team[i].size()) {
			for (cnt = j = 0; j < l; ++j) id[++cnt] = team[i][j];
			sort(id + 1, id + cnt + 1, Cmp);
			for (j = 1; j < l; ++j) if (Find(id[j]) == Find(id[j + 1]))id[++cnt] = LCA(id[j], id[j + 1]);
			sort(id + 1, id + cnt + 1, Cmp), cnt = unique(id + 1, id + cnt + 1) - id - 1;
			for (j = 1; j <= cnt; ++j) e3[id[j]].clear(), f[id[j]] = d[id[j]] = 0;
			sz = tp = 0, ng = i;
			for (j = 1; j <= cnt; ++j) {
				while (tp && dfn[st[tp]] + size[st[tp]] <= dfn[id[j]]) --tp;
				if (tp) Add3(st[tp], id[j]);
				st[++tp] = id[j];
			}
			for (j = 1; j <= cnt; ++j)
				if (!d[id[j]]) nsz = 0, Getsz(id[j]), Calc(id[j]), add += (ll)sz * nsz, sz += nsz;
		}
	for (i = 1; i <= tot; ++i) if (!fa[i]) Solve(i, 0);
	for (i = 1; i <= n; ++i) Out(ans[i] + add), Putc('\n');
	return Flush(), 0;
}
posted @ 2019-01-16 14:51  Cyhlnj  阅读(239)  评论(0)    收藏  举报