做题随笔:P8981

Solution

写在前面

本文结论没有证明,想要详细证明的请看_Vix_大佬的文章。本文的基本思路与大佬相同,但偏向于思路分析,还有对分类简化的尝试。

题意

原题链接

给定一棵树,每个点的点权为经过该点的直径数量,求树的点权和或点权平方和。

分析&Code 1

先放一张图,明确一下叫法:

深度为 2 的节点,本文中称为 top 节点,top 节点的子树简称为 top 子树,文中提到的“最长链”“次长链”的一段为根节点。

题目很直接,我们也有个直接的想法:我只要求出所有的直径,显示地加,一定可以做(废话,\(O(n^2)\),滚吧)。所以说问题就转变为:怎么确定直径和怎么方便地加。

第二问我会!树上差分!只要找到直径的两个端点,他们之间的链必然是直径,拿下!

所以其实只有一个问题:怎么确定直径。

我们可以想到直径的一个性质:直径必然过中点(如果有两个中点,则都经过)。所以我们考虑换根,以(其中一个)中心为根建树。这样有两个好处:直径必然过根,不用写 LCA 了;直径必然由从根出发的两条最长链(一个中心时)或一条最长链加上一条次长链(两个中心时)组成。于是我们也得到了确定直径端点的方法:换根后,深度等于直径一半或直径一半减一(分别对应最长链端点、次长链端点)的叶节点

那么现在我们有了一个朴素的方法:换根,dfs 算一个深度,\(O(n^2)\) 扫,差分算贡献。在此基础上,我们发现:一个中心时,过某个直径端点,可以与所有不在和同一 top 子树下的其他直径端点形成一条直径。其实很直观,这其实就是两条最长链;两个中心时,有且仅有一个 top 子树下会有最长链(否则直径由两条最长链组成,应只有一个中心),直径由该 top 子树中的最长链端点和所有其他 top 子树的次长链端点组成。

于是,我们可以统计下所有 top 子树中的最长链端点数 \(cnt_1(u)\) 和次长链端点数 \(cnt_2(u)\)(下文中 \(cnt_1(rt)\) 表示 \(\sum cnt_1(u)\)\(cnt_2(rt)\) 同理),进行如下操作:

  • 若只有一个中心,在所有最长链端点 \(u\) 对应的差分数组上加上 \(cnt_1(rt)-cnt_1(u)\)
  • 若有两个中心,在所有最长链端点 \(u\) 对应的差分数组上加上 \(cnt_2(rt)\),次长链加 \(cnt_1(rt)\)

然后差分统计答案即可。

最后,由于根节点在每条直径都被加了两次,最后记得乘上 2 的乘法逆元(致直接除调了一下午的我)。

Code

(以下是合并 \(cnt_1(u),cnt_2(u)\) 的方法)

但是但是,还要判断叶子是在是太麻烦了,一堆数组给自己都整晕了!!!有没有什么简单又强势的办法?有的兄弟有的。

我们发现,有两个中点时,由于两个中点都必然经过,所以它们就是一个点!用兄弟儿子法存树,直接删一个点,全部按只有一个中心做!

但是,我写的前向星怎么办(比如本蒟蒻)?好办!

把两个中心之间的边断了(不让走就行了,并且只有一个中心时就不存在这条边),整成两棵树,这时候两树的最长链长度相等,遍历所有点,在最长链端点对应的差分数组上加上另一棵树的 \(cnt(rt)\) 即可。需要注意:这么写只有在有一个中心时根会重复加,需要判断一下。

复杂度 \(O(n)\)

Code 2(对应上面的第二种方法)

#include <iostream>
#include <cctype>
#include <cstdio>
#include <climits>

typedef long long ll;

int fr() {
    int x=0,f=1;
    char c=getchar();
    while(!isdigit(c)) {
        if(c=='-') f=-1;
        c=getchar();
    } 
    while(isdigit(c)) {
        x=(x<<3)+(x<<1)+(c^48);
        c=getchar();
    }
    return x*f;
}

const int maxn=5e6+100;
const int M=998244353;

int head[maxn],tot;

struct edge{
    int v,nxt;
}e[maxn*2];

void ade(int u,int v) {
    e[++tot]={v,head[u]};
    head[u]=tot;
};

int n,k;

int d1[maxn],d2[maxn];

void dfs1(int u,int f) {
    for(int i = head[u]; i; i=e[i].nxt) {
        int v=e[i].v;
        if(v==f) continue;
        dfs1(v,u);
        if(d1[v]+1>d1[u]) {
            d2[u]=d1[u];
            d1[u]=d1[v]+1;
        }
        else if(d1[v]+1>d2[u]) d2[u]=d1[v]+1;
    }
}

int up[maxn];

void dfs2(int u,int f) {
    for(int i = head[u]; i; i=e[i].nxt) {
        int v=e[i].v;
        if(f==v) continue;
        up[v]=up[u]+1;
        if(d1[v]+1!=d1[u]) up[v]=std::max(up[v],d1[u]+1);
        else up[v]=std::max(up[v],d2[u]+1);
        dfs2(v,u);
    }
}

int rt[2],min_l=INT_MAX;

void getr() {
    dfs1(1,0);
    dfs2(1,0);
    for(int i = 1; i <= n; i++) {
        if(std::max(d1[i],up[i])<min_l) {
            min_l=std::max(d1[i],up[i]);
            rt[0]=rt[1]=i;
        }
        else if(std::max(d1[i],up[i])==min_l) rt[1]=i;
    }
}
//就是和求重心类似的方式,但是不知道为什么机房佬都不认可

int cnt[maxn],top[maxn];
int dep[maxn],fa[maxn];
//fa:每个top节点对应的树的编号

void getd(int u,int f) {
    dep[u]=dep[f]+1;
    for(int i = head[u]; i; i=e[i].nxt) {
        int v=e[i].v;
        if(v==f||v==rt[0]||v==rt[1]) continue;
        //断边就是这个意思
        getd(v,u);
    }
}

void count(int u,int f,int tp,int root) {
    top[u]=tp;
    if(dep[u]==min_l) cnt[tp]++,cnt[root]++;
    for(int i = head[u]; i; i=e[i].nxt) {
        int v=e[i].v;
        if(v==f||v==rt[0]||v==rt[1]) continue;
        count(v,u,tp,root);
    }
}

void init() {
    for(int i=head[rt[0]]; i ; i=e[i].nxt) {
        int v=e[i].v;
        if(v==rt[1]) continue;
        count(v,rt[0],v,rt[0]);
        fa[v]=0;
    }
    if(rt[0]==rt[1]) return;
    for(int i=head[rt[1]]; i ; i=e[i].nxt) {
        int v=e[i].v;
        if(v==rt[0]) continue;
        count(v,rt[1],v,rt[1]);
        fa[v]=1;
    }
}

ll d[maxn];
ll ans;

void sol(int u,int f) {
    for(int i = head[u]; i; i=e[i].nxt) {
        int v=e[i].v;
        if(v==f||v==rt[0]||v==rt[1]) continue;
        sol(v,u);
        d[u]+=d[v];
        if(d[u]>M||d[u]<0) d[u]%=M;
        if(d[u]<0) d[u]+=M;
    }
    if(rt[0]==rt[1]&&u==rt[1]) d[u]*=(M+1)>>1,d[u]%=M;
    //2在模99824353意义下的乘法逆元比较特殊,其他的不知道可不可以
    ans+=d[u]*(k==1?1:d[u]);
    if(ans>M||ans<0) ans%=M;
    if(ans<0) ans+=M;
}

int main() {
    n=fr(),k=fr();
    if(n<=2) {printf("%d\n",n);return 0;}
    //注意:这样写小于等于二会输出0,虽然没有这个数据,但是写一下
    for(int i = 1; i <= n-1; i++) {
        int u=fr(),v=fr();
        ade(u,v);
        ade(v,u);
    }
    getr();
    if(rt[0]==rt[1]) min_l++;
    getd(rt[0],0);getd(rt[1],0);
    init();
    for(int i = 1; i <= n; i++) {
        if(dep[i]==min_l) {
            d[i]+=cnt[rt[fa[top[i]]^1]];
            if(rt[0]==rt[1]) d[i]-=cnt[top[i]];
        }
    }
    sol(rt[0],0);if(rt[1]!=rt[0]) sol(rt[1],0);
    printf("%lld\n",ans);
    return 0;
}

闲话

蒟蒻经历了漫长的鏖战终于胜利,本篇算是写一些自己的理解了。

如果觉得有用,点个赞吧!

posted @ 2025-08-15 19:17  Tenil  阅读(9)  评论(0)    收藏  举报