点分治和点分树
其实还是没怎么懂点分治和点分树,随便写点自己的理解。
求两个点的距离公式很明了 \(dis(u,v)=d_u+d_v-2\times d_{lca(u,v)}\),其中 \(d_x\) 表示根到 \(x\) 的距离。距离公式中最不好处理的其实是后面的 \(-2\times d_{lca(u,v)}\)。
转换下思路,如果 \(lca(u,v)\) 是树根是不是就会好求许多了。这时我们就需要一个算法实现一下几个操作:
- 找到一个新根
- 统计经过根的答案
- 删掉根,递归处理没经过根的路径
- 重复以上操作
显然每次操作的复杂度为根所在联通块的大小,那只要使根每次取重心,复杂度就会下降到 \(O(n\log n)\)。
因此点分治过程就是每次取出重心,处理出包含重心的连通块或路径的答案的过程。
回到这题中,考虑记录下重心所在联通块中所有点到重心的链,判断这些链是否能拼成一条长度为要求长度的路径。
点击查看代码
#include<bits/stdc++.h>
using namespace std;
const int maxn=1e7+10;
int root,sum,n,m,q[205],tot;
int head[maxn],mx[maxn],siz[maxn],cur[maxn],d[maxn],tmp[maxn];
bool vis[maxn],ans[maxn],ju[maxn];
struct node{
int u,v,w,next;
}e[maxn<<1];
int read(){
int s=0,f=1;
char ch=getchar();
while(ch>'9'||ch<'0'){
if(ch=='-'){
f=-1;
}
ch=getchar();
}
while(ch>='0'&&ch<='9'){
s=s*10+(ch-'0');
ch=getchar();
}
return s*f;
}
void add(int u,int v,int w){
e[++tot].u=u;
e[tot].v=v;
e[tot].w=w;
e[tot].next=head[u];
head[u]=tot;
}
void dfs(int u,int fa){
siz[u]=1;
mx[u]=0;
for(int i=head[u];i;i=e[i].next){
int v=e[i].v;
if(v==fa||vis[v])continue;
dfs(v,u);
siz[u]+=siz[v];
mx[u]=max(mx[u],siz[v]);
}
mx[u]=max(mx[u],sum-siz[u]);
if(mx[u]<mx[root]){
root=u;
}
}
void calc(int u,int fa){
cur[++cur[0]]=d[u];
for(int i=head[u];i;i=e[i].next){
int v=e[i].v;
if(v==fa||vis[v]){
continue;
}
d[v]=d[u]+e[i].w;
calc(v,u);
}
}
void work(int u){
int cnt=0;
for(int i=head[u];i;i=e[i].next){
int v=e[i].v;
if(vis[v]){
continue;
}
cur[0]=0;
d[v]=e[i].w;
calc(v,u);
for(int j=cur[0];j;j--){
for(int k=1;k<=m;k++){
if(q[k]>=cur[j]){
ans[k]|=ju[q[k]-cur[j]];
}
}
}
for(int j=cur[0];j;j--){
tmp[++cnt]=cur[j];
ju[cur[j]]=1;
}
}
for(int i=1;i<=cnt;i++){
ju[tmp[i]]=0;
}
}
void divid(int u){
vis[u]=ju[0]=1;
work(u);
for(int i=head[u];i;i=e[i].next){
int v=e[i].v;
if(vis[v]){
continue;
}
sum=siz[v];
mx[root=0]=1e8;
dfs(v,v);
divid(root);
}
}
int main(){
n=read(),m=read();
int u,v,w;
for(int i=1;i<n;i++){
u=read(),v=read(),w=read();
add(u,v,w);
add(v,u,w);
}
for(int i=1;i<=m;i++){
q[i]=read();
}
mx[root]=sum=n;
dfs(1,1);
divid(root);
for(int i=1;i<=m;i++){
if(ans[i]){
puts("AYE");
}
else{
puts("NAY");
}
}
return 0;
}
也是一道比较板的题,对于每个分治中心处理出联通块中每个点到分治中心的距离,得到 \(s_0,s_1,s_2\),\(s_i\) 表示距离分治中心距离模 \(3\) 为 \(i\) 的点的数量。贡献为 \(s_2\times s_1\times 2+s_0\times s_0\),再容斥掉没有经过分治重心的贡献即可。
点击查看代码
#include<bits/stdc++.h>
#define ull unsigned long long
#define ll long long
#define pdi pair<double,int>
#define pii pair<int,int>
#define pb push_back
#define mp make_pair
#define eps 1e-9
using namespace std;
namespace IO{
template<typename T>
inline void read(T &x){
x=0;
int f=1;
char ch=getchar();
while(ch>'9'||ch<'0'){
if(ch=='-'){
f=-1;
}
ch=getchar();
}
while(ch>='0'&&ch<='9'){
x=x*10+(ch-'0');
ch=getchar();
}
x=(f==1?x:-x);
}
template<typename T>
inline void write(T x){
if(x<0){
putchar('-');
x=-x;
}
if(x>=10){
write(x/10);
}
putchar(x%10+'0');
}
template<typename T>
inline void write_endl(T x){
write(x);
putchar('\n');
}
template<typename T>
inline void write_space(T x){
write(x);
putchar(' ');
}
}
using namespace IO;
const int N=2e4+10,inf=1e9;
int n,rt,siz[N],mx[N],del[N],tot,ans;
int d[N],s[10];
vector<pii>e[N];
void get_rt(int u,int fa){
siz[u]=1;
mx[u]=0;
for(auto x:e[u]){
int v=x.first;
if(del[v]||v==fa){
continue;
}
get_rt(v,u);
siz[u]+=siz[v];
mx[u]=max(mx[u],siz[v]);
}
mx[u]=max(mx[u],tot-siz[u]);
if(mx[u]<mx[rt]){
rt=u;
}
}
void dfs(int u,int fa){
s[d[u]%3]++;
for(auto x:e[u]){
int v=x.first,w=x.second;
if(del[v]||v==fa){
continue;
}
d[v]=d[u]+w;
d[v]%=3;
dfs(v,u);
}
}
int calc(int st,int w){
memset(s,0,sizeof(s));
d[st]=w;
dfs(st,0);
return s[1]*s[2]*2+s[0]*s[0];
}
void solve(int u){
get_rt(u,0);
ans=ans+calc(u,0);
del[u]=1;
for(auto x:e[u]){
int v=x.first,w=x.second;
if(del[v]){
continue;
}
ans=ans-calc(v,w);
rt=0;
tot=siz[v];
get_rt(v,0);
solve(rt);
}
}
signed main(){
#ifndef ONLINE_JUDGE
freopen("1.in","r",stdin);
freopen("1.out","w",stdout);
#endif
read(n);
for(int i=1;i<n;i++){
int u,v,w;
read(u),read(v),read(w);
e[u].pb(mp(v,w));
e[v].pb(mp(u,w));
}
tot=n;
rt=0;
mx[0]=inf;
get_rt(1,0);
solve(rt);
write(ans/__gcd(ans,n*n)),putchar('/'),write_endl(n*n/__gcd(ans,n*n));
return 0;
}
点分治好题,难点在于处理贡献。
先看题目要求求什么,树上路径数颜色,很容易想到树分块和树上莫队,但要求对树上所有路径都要求,果断放弃。
考虑点分治,对经过分治中心的路径一起算贡献。因为颜色数和树的大小是同阶的,所以不能一个颜色一个颜色来算贡献,否则复杂度会退化,只能所有颜色一起算贡献。
先处理路径对分治中心 \(u\) 的贡献。容易发现一个性质,在 \(u\) 的一个子树中所有的颜色相同的点中,只有深度最小的点 \(x\) 能给 \(u\) 造成贡献,贡献为 \(siz_x\)。记 \(cnt_c\) 为颜色 \(c\) 对 \(u\) 造成的贡献,则该连通块中经过分治中心的路径对答案造成的贡献为 \(sum=\sum\limits_{c=1}^{max_c}cnt_c\),需要注意的是 \(u\) 属于的颜色 \(col_u\) 对答案贡献为 \(siz_u\),因为任何一个点到分治中心必然经过分治中心。
接下来计算对非分治中心的贡献,因为一定经过分治中心,所以在计算前先去掉所在子树的贡献。对于子树 \(v\) 内一个点 \(x\),新增的贡献为 \(siz_u-siz_v-cnt_{col_x}\),其中 \(cnt_{col_x}\) 为去掉 \(v\) 子树内的贡献后,颜色 \(col_x\) 对 \(u\) 产生的贡献。需要注意的是,统计完子树 \(v\) 的贡献后要将子树 \(v\) 的贡献加回,方便后续计算。
点击查看代码
#include<bits/stdc++.h>
#define ull unsigned long long
#define int long long
#define pdi pair<double,int>
#define pii pair<int,int>
#define pb push_back
#define mp make_pair
#define eps 1e-9
using namespace std;
namespace IO{
template<typename T>
inline void read(T &x){
x=0;
int f=1;
char ch=getchar();
while(ch>'9'||ch<'0'){
if(ch=='-'){
f=-1;
}
ch=getchar();
}
while(ch>='0'&&ch<='9'){
x=x*10+(ch-'0');
ch=getchar();
}
x=(f==1?x:-x);
}
template<typename T>
inline void write(T x){
if(x<0){
putchar('-');
x=-x;
}
if(x>=10){
write(x/10);
}
putchar(x%10+'0');
}
template<typename T>
inline void write_endl(T x){
write(x);
putchar('\n');
}
template<typename T>
inline void write_space(T x){
write(x);
putchar(' ');
}
}
using namespace IO;
const int N=1e5+10,inf=1e9;
int col[N],n,del[N],ans[N];
int mx[N],rt,siz[N],tot;
int sum,cnt[N],Cnt[N];
vector<int>e[N];
void Get_rt(int u,int fa){
siz[u]=1;
mx[u]=0;
for(auto v:e[u]){
if(del[v]||v==fa){
continue;
}
Get_rt(v,u);
siz[u]+=siz[v];
mx[u]=max(mx[u],siz[v]);
}
mx[u]=max(mx[u],tot-siz[u]);
if(mx[u]<mx[rt]){
rt=u;
}
}
void Add(int u,int fa){
cnt[col[u]]++;
if(cnt[col[u]]==1&&col[u]!=col[rt]){
Cnt[col[u]]+=siz[u];
sum+=siz[u];
}
for(auto v:e[u]){
if(del[v]||v==fa){
continue;
}
Add(v,u);
}
cnt[col[u]]--;
}
void Del(int u,int fa){
cnt[col[u]]++;
if(cnt[col[u]]==1&&col[u]!=col[rt]){
Cnt[col[u]]-=siz[u];
sum-=siz[u];
}
for(auto v:e[u]){
if(del[v]||v==fa){
continue;
}
Del(v,u);
}
cnt[col[u]]--;
}
void Update(int u,int fa,int belong){
cnt[col[u]]++;
if(cnt[col[u]]==1&&col[u]!=col[rt]){
sum+=tot-siz[belong]-Cnt[col[u]];
}
ans[u]+=sum;
for(auto v:e[u]){
if(del[v]||v==fa){
continue;
}
Update(v,u,belong);
}
if(cnt[col[u]]==1&&col[u]!=col[rt]){
sum-=tot-siz[belong]-Cnt[col[u]];
}
cnt[col[u]]--;
}
void calc(int u){
sum=siz[u];
for(auto v:e[u]){
if(del[v]){
continue;
}
Add(v,u);
}
ans[u]+=sum;
for(auto v:e[u]){
if(del[v]){
continue;
}
sum-=siz[v];
Del(v,u);
Update(v,u,v);
Add(v,u);
sum+=siz[v];
}
for(auto v:e[u]){
if(del[v]){
continue;
}
Del(v,u);
}
}
void Divid(int u){
del[u]=1;
Get_rt(u,0);
calc(u);
for(auto v:e[u]){
if(del[v]){
continue;
}
tot=siz[v];
rt=0;
Get_rt(v,0);
Divid(rt);
}
}
signed main(){
#ifndef ONLINE_JUDGE
freopen("1.in","r",stdin);
freopen("1.out","w",stdout);
#endif
read(n);
for(int i=1;i<=n;i++){
read(col[i]);
}
for(int i=1,u,v;i<n;i++){
read(u),read(v);
e[u].pb(v);
e[v].pb(u);
}
mx[0]=inf;
rt=0;
tot=n;
Get_rt(1,0);
Divid(rt);
for(int i=1;i<=n;i++){
write_endl(ans[i]);
}
return 0;
}