[LOJ#2386]. 「USACO 2018.01 Platinum」Cow at Large[点分治]

题意

题目链接

分析

  • 假设当前的根为 rt ,我们能够在奶牛到达 \(u\) 之时拦住它,当且仅当到叶子节点到 \(u\) 的最短距离 \(mn_u \le dis_u\) 。容易发现,合法的区域是许多棵子树,而我们要求的就是有多少棵子树。
  • 由于除了以 rt 为根的子树都可以用 \(\sum\limits_{x\in subtree} 2-deg(x)\) 的形式表示 (如果 rt 是叶子特判掉即可),于是可以将问题转化成有多少个点满足 \(mn_u\le dis_u​\)
  • 考虑点分治,先补集转化这样不用处理负权。每次求对于 \(a\) 有多少 \(b\) 满足 \(dis_a+dis_b<mn_b\) 。把所有路径按照 \(mn_b-dis_b​\) 的大小排序后在序列上二分差后缀和即可。最后要容斥减去子树内的方案数。
  • 时间复杂度 \(O(nlog^2n)​\)

代码

#include<bits/stdc++.h>
using namespace std;
typedef long long LL;
#define go(u) for(int i = head[u], v = e[i].to; i; i=e[i].lst, v=e[i].to)
#define rep(i, a, b) for(int i = a; i <= b; ++i)
#define pb push_back
#define re(x) memset(x, 0, sizeof x)
inline int gi() {
    int x = 0,f = 1;
    char ch = getchar();
    while(!isdigit(ch)) { if(ch == '-') f = -1; ch = getchar();}
    while(isdigit(ch)) { x = (x << 3) + (x << 1) + ch - 48; ch = getchar();}
    return x * f;
}
template <typename T> inline bool Max(T &a, T b){return a < b ? a = b, 1 : 0;}
template <typename T> inline bool Min(T &a, T b){return a > b ? a = b, 1 : 0;}
const int N = 7e4 + 7, inf = 0x3f3f3f3f;
int n, rt, edc, sn;
int mn[N], deg[N], ans[N], mxs[N], head[N], son[N];
bool vis[N];
struct edge {
	int lst, to;
	edge(){}edge(int lst, int to):lst(lst), to(to){}
}e[N << 1];
void Add(int a, int b){
	++deg[a], ++deg[b];
	e[++edc] = edge(head[a], b), head[a] = edc;
	e[++edc] = edge(head[b], a), head[b] = edc;
}
void bfs() {
	queue<int>Q;
	memset(mn, 0x3f, sizeof mn);
	rep(i, 1, n) if(deg[i] == 1) {
		mn[i] = 0, Q.push(i);
	}
	while(!Q.empty()) {
		int u = Q.front();Q.pop();
		go(u)if(mn[v] == inf) {
			mn[v] = mn[u] + 1;
			Q.push(v);
		}
	}
}
int tp;
typedef pair<int, int> pii;
#define mp make_pair
pii suf[N];
void getrt(int u, int fa) {
	mxs[u] = 0;son[u] = 1;
	go(u)if(!vis[v] && v ^ fa) {
		getrt(v, u);
		son[u] += son[v];
		Max(mxs[u], son[v]);
	}
	Max(mxs[u], sn - son[u]);
	if(mxs[u] < mxs[rt]) rt = u;
}
void getdep(int u, int fa, int dis) {
	if(mn[u] - dis > 0) suf[++tp] = mp(mn[u] - dis, 2 - deg[u]);
	go(u)if(!vis[v] && v ^ fa) {
		getdep(v, u, dis + 1);
	}
}
void getans(int u, int fa, int dis, int f) {
	int gg = upper_bound(suf + 1, suf + 1 + tp, mp(dis, inf)) - suf;
	if(gg != tp + 1)
	ans[u] += f * suf[gg].second;
	go(u)if(!vis[v] && v ^ fa) {
		getans(v, u, dis + 1, f);
	}
}
void solve(int u) {
	vis[u] = 1;
	tp = 0;
	getdep(u, 0, 0);
	sort(suf + 1, suf + 1 + tp);
	for(int j = tp - 1; j >= 1; --j) suf[j].second += suf[j + 1].second;
	getans(u, 0, 0, 1);
	
	go(u)if(!vis[v]) {
		tp = 0;
		getdep(v, u, 1);
		sort(suf + 1, suf + 1 + tp);
		for(int j = tp - 1; j >= 1; --j) suf[j].second += suf[j + 1].second;
		getans(v, u, 1, -1);
	}
	
	int old = sn;
	go(u)if(!vis[v]) {
		if(son[v] > son[u])
			sn = old - son[u];
		else
			sn = son[v];
		rt = 0, getrt(v, u), solve(rt);
	}
}
int main() {
	n = gi();
	rep(i, 1, n - 1) Add(gi(), gi());
	bfs();
	sn = n, mxs[rt = 0] = n + 1, getrt(1, 0), solve(rt);
	rep(i, 1, n) printf("%d\n", deg[i] == 1 ? 1 : 2 - ans[i]);
	return 0;
}
posted @ 2019-03-17 20:45  fwat  阅读(256)  评论(0编辑  收藏  举报