【BZOJ 2599】【IOI 2011】Race 点分治
裸的点分治。然而我在循环初始化数组 $s$ 时把 $i \le k$ 写成了 $i \le n$，WA 了好长时间——$s$ 的下标是路径长度（$0 \ldots k$），而不是点的编号，所以必须初始化到 $k$。
#include<cstdio>
#include<cstring>
#include<algorithm>
#define N 200100
#define inf 2147483647
#define max(a,b) (a)>(b)?(a):(b)
#define min(a,b) (a)<(b)?(a):(b)
#define read(x) x=getint()
using namespace std;
// Reads the next integer from stdin.
// Skips every non-digit character, remembering a '-' seen among them, then
// consumes the following run of digits. Returns 0 if the input ends before
// any digit is found, instead of spinning forever.
inline int getint() {
    int fh = 1, val = 0;
    // int, not char: getchar() reports end of input as EOF (a negative int),
    // which a char cannot reliably hold or be compared against.
    int c = getchar();
    // EOF is "not a digit" too, so it must be checked explicitly or this loop
    // never terminates on truncated input.
    for (; c != EOF && (c < '0' || c > '9'); c = getchar())
        if (c == '-') fh = -1;
    for (; c >= '0' && c <= '9'; c = getchar())
        val = val * 10 + c - '0';
    return val * fh;
}
// Forward-star adjacency list: E[i] is the i-th inserted directed edge
// (indices start at 1, so 0 terminates a list). Every undirected tree edge is
// inserted twice, hence N << 1 slots.
struct node {
    int nxt;  // next edge leaving the same vertex, 0 = end of list
    int to;   // head vertex
    int w;    // edge weight
} E[N << 1];
// vis[x] becomes true once x has been used as a centroid; walks treat it as removed.
bool vis[N];
// s[d]   : fewest edges on a known path of length d starting at the current centroid,
//          n + 1 meaning "none"; indexed by length 0..k, so k must stay below 1000100.
// rtm    : best (smallest) largest-piece size seen by the current centroid search; root: its vertex.
// sz     : subtree sizes from the latest centroid search.
// deep   : edge count from the current centroid.
// point  : head edge index of each vertex's adjacency list.
int cnt = 0, s[1000100], rtm = inf, root, sz[N], deep[N], n, k, ans, point[N];
// Path length from the current centroid. 64-bit because the weights along a path
// of up to n - 1 edges can sum past INT_MAX (e.g. 2e5 edges of weight 1e6); a
// wrapped, negative value would then be used as an index into s[].
long long dist[N];
// Appends the directed edge x -> y with weight z to x's adjacency list.
inline void ins(int x, int y, int z) {
    node &e = E[++cnt];
    e.nxt = point[x];
    e.to = y;
    e.w = z;
    point[x] = cnt;
}
// Finds a centroid of the component containing x.
// Walks every unvisited vertex reachable from x without stepping back into fa,
// filling sz[] with subtree sizes (rooted at the starting vertex). It keeps in
// root / rtm the vertex whose largest remaining piece after removal is smallest.
// sh is the size of the whole component; callers reset rtm to inf beforehand.
inline void fdrt(int x, int fa, int sh) {
    sz[x] = 1;
    int ma = 0;  // largest piece left by removing x; starts as the biggest child subtree
    for (int tmp = point[x]; tmp; tmp = E[tmp].nxt) {
        int v = E[tmp].to;
        if (vis[v] || v == fa)
            continue;
        fdrt(v, x, sh);
        sz[x] += sz[v];
        ma = max(ma, sz[v]);
    }
    // The remaining piece "above" x is everything in the component outside x's
    // subtree: sh - sz[x]. (sh - ma is not the size of any piece and can make
    // the search pick a vertex that is not a centroid.)
    ma = max(ma, sh - sz[x]);
    if (ma < rtm) {
        rtm = ma;
        root = x;
    }
}
// The three walks below each cover one branch hanging off the current centroid,
// never stepping back into fa or into vertices already used as centroids (vis).
// They assume non-negative edge weights (a negative dist would already index s[]
// out of range). Under that assumption dist only grows along a walk, so once it
// passes the budget k nothing deeper can matter and the walk stops there. This
// also keeps dist bounded, so it cannot overflow on long, heavy paths.

// work: for each vertex x of the branch with dist[x] <= k, tries to complete a
// path of total length exactly k through the centroid, pairing x with
// s[k - dist[x]] (fewest edges of a matching path from an earlier branch, or the
// centroid itself via s[0]). Fills dist[]/deep[] for the vertices it descends into.
inline void work(int x, int fa) {
    if (dist[x] > k)
        return;  // every descendant is at least as far away
    ans = min(ans, deep[x] + s[k - dist[x]]);
    for (int tmp = point[x]; tmp; tmp = E[tmp].nxt) {
        int v = E[tmp].to;
        if (vis[v] || v == fa)
            continue;
        dist[v] = dist[x] + E[tmp].w;
        deep[v] = deep[x] + 1;
        work(v, x);
    }
}

// sfill: publishes the branch into s[] (s[d] = fewest edges among paths of length
// d < k from the centroid), so later sibling branches can pair with it. Reads the
// dist[]/deep[] values that work() just computed for this branch.
inline void sfill(int x, int fa) {
    if (dist[x] >= k)
        return;  // lengths >= k are never stored; descendants are no shorter
    s[dist[x]] = min(s[dist[x]], deep[x]);
    for (int tmp = point[x]; tmp; tmp = E[tmp].nxt) {
        int v = E[tmp].to;
        if (vis[v] || v == fa)
            continue;
        sfill(v, x);
    }
}

// emp: undoes sfill for this branch by resetting every slot it could have written
// back to n + 1 ("no path"). Uses the same pruning rule as sfill so that it visits
// exactly the same vertices.
inline void emp(int x, int fa) {
    if (dist[x] >= k)
        return;
    s[dist[x]] = n + 1;
    for (int tmp = point[x]; tmp; tmp = E[tmp].nxt) {
        int v = E[tmp].to;
        if (vis[v] || v == fa)
            continue;
        emp(v, x);
    }
}
inline void dfs(int x, int sh) {
vis[x] = 1;
s[0] = 0; //不能落下这个点!!因为后面会更新不到,而且有可能会更改s[0]的值
for(int tmp = point[x]; tmp; tmp = E[tmp].nxt) {
int v = E[tmp].to;
if (vis[v])
continue;
dist[v] = E[tmp].w;
deep[v] = 1;
work(v, x);
sfill(v, x);
}
for(int tmp = point[x]; tmp; tmp = E[tmp].nxt) {
int v = E[tmp].to;
if (vis[v])
continue;
emp(v, x);
}
for(int tmp = point[x]; tmp; tmp = E[tmp].nxt) {
int v = E[tmp].to;
if (vis[v])
continue;
int ss = sz[v] < sz[x] ? sz[v]: sh - sz[x];
rtm = inf;
fdrt(v, x, ss);
dfs(root, ss);
}
}
int main() {
read(n); read(k);
int a,b,c;
for(int i = 1; i < n; ++i) {
read(a); read(b); read(c); ++a; ++b;
ins(a, b, c);
ins(b, a, c);
}
ans = n;
memset(vis, 0, sizeof(vis));
fdrt(1, -1, n);
for(int i = 0; i <= k; ++i)
s[i] = n + 1;
dfs(1, n);
printf("%d\n", ans == n ? -1 : ans);
return 0;
}
然后就可以了
NOI 2017 Bless All

浙公网安备 33010602011771号