luogu 3806 【模板】点分治
luogu 3806 【模板】点分治
给定一棵有n个点的树,有m个询问,每个询问树上距离为k的点对是否存在。树的权值最多不超过c。n<=10000,m<=100,c<=1000,K<=10000000。
关于树的路径的问题,点分治是一种最吼的工具。由于这道题的m比较小,枚举k,通过set保存每颗子树中点的路径值,在set中查询每个k值是否成立即可。似乎空间消耗很小,只用了2.3mb。
#include <set>
#include <cctype>
#include <cstdio>
#include <cstring>
using namespace std;
const int maxn=1e4+5;
struct Graph{
struct Edge{
int to, next, v; Graph *bel;
Edge& operator ++(){
return *this=bel->edge[next]; }
}edge[maxn*2];
int cnte, fir[maxn];
void addedge(int x, int y, int v){
Edge &e=edge[++cnte];
e.to=y; e.next=fir[x]; e.v=v;
fir[x]=cnte; e.bel=this;
}
Edge& getlink(int x){ return edge[fir[x]]; }
void RESET(){
cnte=0; memset(fir, 0, sizeof(fir)); }
}g;
int n, m, k[maxn], size[maxn], f[maxn];
int dep[maxn], tail;
set<int> s;
bool done[maxn], hasans[maxn];
//获取子树大小和最大子树大小
void predfs(int now, int par, int num){
Graph::Edge e=g.getlink(now);
size[now]=1; f[now]=0;
for (; e.to; ++e){
if (e.to==par||done[e.to]) continue;
predfs(e.to, now, num);
size[now]+=size[e.to];
f[now]=max(f[now], size[e.to]);
}
f[now]=max(f[now], num-size[now]);
}
//找到根
int getroot(int now, int par){
Graph::Edge e=g.getlink(now);
int core=now, t;
for (; e.to; ++e){
if (e.to==par||done[e.to]) continue;
t=getroot(e.to, now);
if (f[t]<f[core]) core=t;
}
return core;
}
//获取到所有点的深度
void getdep(int now, int par, int step){
Graph::Edge e=g.getlink(now);
for (; e.to; ++e) if (e.to!=par&&!done[e.to])
getdep(e.to, now, step+e.v);
dep[++tail]=step;
for (int i=1; i<=m; ++i)
if (s.find(k[i]-dep[tail])!=s.end())
hasans[i]=true;
}
void solve(int now, int par, int num){
predfs(now, 0, num); //预处理
if (size[now]==1) return;
now=getroot(now, 0); //找出重心
predfs(now, 0, num);
s.clear(); s.insert(0);
Graph::Edge e=g.getlink(now);
for (; e.to; ++e){
if (e.to==par||done[e.to]) continue;
tail=0;
getdep(e.to, now, e.v); //找出所有点的深度
for (int i=1; i<=tail; ++i)
s.insert(dep[i]);
}
//不能统计带有回头路的
e=g.getlink(now);
done[now]=true;
for (; e.to; ++e) if (e.to!=par&&!done[e.to])
solve(e.to, now, size[e.to]); //找子树
}
void get(int &x){
x=0; int flag=1; char c;
for (c=getchar(); !isdigit(c); c=getchar())
if (c=='-') flag=-flag;
for (x=c-48; c=getchar(), isdigit(c); )
x=(x<<3)+(x<<1)+c-48; x*=flag;
}
int main(){
get(n); get(m);
int t1, t2, t3;
for (int i=1; i<n; ++i){
get(t1); get(t2); get(t3);
g.addedge(t1, t2, t3);
g.addedge(t2, t1, t3);
}
for (int i=1; i<=m; ++i) get(k[i]);
solve(1, 0, n);
for (int i=1; i<=m; ++i) puts(hasans[i]?"AYE":"NAY");
return 0;
}