ABC248G GCD cost on the tree

AT 洛谷

  • 给出一棵 \(n\) 个点的树,点 \(i\) 有点权 \(a_i\)。定义 \(u,v\) 间的路径的权值 \(C(u,v)\) 为,这条路径上的点数乘上路径上所有点的 \(\gcd\)

  • 形式化的,若点 \(u\) 到点 \(v\) 的路径为 \(p_1,p_2,\dots,p_k\,(p_1=u,p_k=v)\),则 \(C(u,v)=\gcd(a_{p_1},a_{p_2},\dots,a_{p_k})\)

  • \(\left(\sum\limits_{i=1}^{n-1}\sum\limits_{j=i+1}^n C(i,j)\right)\bmod \color{red}\boldsymbol{998244353}\)。某个【数据删除】把模数看成了 \(10^9+7\),是谁我不说。

  • \(n,\max\limits_{i=1}^n a_i\le 10^5\)

约定:记 \(d(n)\)\(n\) 的正因子个数。形式化的,\(d(n)=\sum\limits_{i=1}^n[i\,|\,n]\)

直接大力点分治好吧,类似题 CF1101DCF990G

考虑如何计算经过当前分治中心 \(rt\) 的路径的贡献。我们用从 \(rt\) 不同子树内的点到分治中心的链,拼接起来(\(rt\) 拼接的时候去重),得到一条路径。也可以单链成为一条路径。这都是套路。为了不重不漏地拼接,考虑用后面的子树的链匹配之前的子树的链。

\(\gcd\) 有一个众所周知的性质,即 \(\gcd(a,b,c)=\gcd(\gcd(a,b),\gcd(b,c))\)。我们用 \(val_u\) 表示 \(rt\)\(u\) 的链上的点的 \(\gcd\)\(dep_u\) 表示 \(rt\)\(u\) 的边数。则对于 \(rt\) 子树内的点 \(u,v\)\(C(u,v)=\gcd(val_u,val_v)\times (dep_u+dep_v\boldsymbol{+1})\)。考虑如何快速计算,用 std::unordered_map__gnu_pbds::gp_hash_table 维护桶 \(mp_{1_x}\) 表示之前子树内,\(val\) 值为 \(x\) 的链的 \(dep\) 值和;桶 \(mp_{2_x}\) 表示之前子树内,\(val\) 值为 \(x\) 的链的数量。我们枚举当前子树内的点 \(u\),再枚举之前的链的 \(\gcd\) 值,它们一定在同理。即遍历桶内的元素,设遍历到的下标为 \(x\),则该种下标产生的贡献为 \([(dep_u+1)\times mp_{2_x}+mp_{1_x}]\times \gcd(x,val_u)\)。统计完右链在当前子树内的路径后,更新桶。

单链的情况类似求,然后继续分治下去。看似很暴力,但是你注意到桶内的元素一定是 \(a_{rt}\) 的因数,我们一共分治 \(\mathcal{O}(\log n)\) 层,每层每个点被遍历 \(\mathcal{O}(1)\) 次,遍历的时间为 \(\mathcal{O}(d(a_{rt}))\)。这样一来时间复杂度为 \(\mathcal{O}(n\log n \cdot \max\limits_{i=1}^nd(a_i))\)

首先一个数 \(x\) 的正因子个数不超过 \(\mathcal{O}(\sqrt{x})\)。设值域为 \(V\),我们可以认为时间复杂度为 \(\mathcal{O}(n\log n\cdot \sqrt{|V|})\)。由于值域只有 \(10^5\),配合 \(8\) 秒的时限,可以通过。确实很暴力

进一步思考,我们发现在这个值域内,\(\max\limits_{i=1}^nd(a_i)\) 不超过 \(128\)。你可以自己验证一下:

#include <bits/stdc++.h>
using namespace std;
int main() {
    int ans = 0;
    for (int i = 1; i <= 100000; ++i) {
        int tot = 0;
        for (int j = 1; j * j <= i; ++j) {
            if (!(i % j)) tot += 2;
            if (j * j == i) --tot;
        }
        ans = max(ans, tot);
    }
    cout << ans;
}

所以我们把 \(\max\limits_{i=1}^nd(a_i)\) 看作 \(128\) 的话,会比 \(\mathcal{O}(\sqrt{|V|})\) 优秀一点。好好好,\({128}\) 倍大常数优秀 \(\text{polylog}\) 做法。

不过为了保证复杂度正确,切记 \(\boldsymbol{val}\) 数组指的是某个点到分治中心 \(\boldsymbol{rt}\) 的链上的点权值的最大公约数,若定义成到该子树根的链上的点权值的最大公约数,那么每个点在一次分治中枚举之前链的实际次数会变成,所有 \(\boldsymbol{rt}\) 的儿子的权值的约数集合的并集的大小,枚举次数就无法得到保证了

总结一下,时间复杂度为 \(\mathcal{O}\left(n\log n \cdot \max\limits_{i=1}^{|V|} d(i)\right)\),空间复杂度为 \(\mathcal{O}(n+|V|)\),可以接受。直接喜提你谷最劣解

提交记录

#pragma GCC optimize("Ofast")
#include <bits/stdc++.h>
#define fi first 
#define se second 
using namespace std; typedef long long ll; const int N = 1e5 + 5, M = 998244353;
int n, a[N], dep[N], val[N], siz[N], maxn[N], rt, tot, stk[N], top, p[N], cnt; 
vector<int> g[N]; bool vis[N]; ll ans; unordered_map<int, ll> mp1, mp2;
template<class T> void read(T &x) {
    x = 0; T f = 1; char c = getchar();
    for (; !isdigit(c); c = getchar()) if (c == '-') f = -1;
    for (; isdigit(c); c = getchar()) x = (x << 3) + (x << 1) + c - 48; x *= f;
}
template<class T> void write(T x) {
    if (x > 9) write(x / 10); putchar(x % 10 + 48);
}
template<class T> void print(T x, char ed = '\n') {
    if (x < 0) putchar('-'), x = -x; write(x), putchar(ed);
}
void gravity(int u, int fa) {
    siz[u] = 1; maxn[u] = 0;
    for (int v : g[u]) {
        if (v == fa || vis[v]) continue; gravity(v, u);
        siz[u] += siz[v]; maxn[u] = max(maxn[u], siz[v]);
    }
    maxn[u] = max(maxn[u], tot - siz[u]); if (maxn[u] < maxn[rt]) rt = u;
}
void get(int u, int fa) {
    p[++cnt] = u;
    for (int v : g[u]) {
        if (v == fa || vis[v]) continue;
        val[v] = __gcd(val[u], a[v]); dep[v] = dep[u] + 1; get(v, u);
    }
}
void divide(int u) {
    vis[u] = 1;
    for (int v : g[u]) {
        if (vis[v]) continue; dep[v] = 1, val[v] = __gcd(a[u], a[v]); cnt = 0; get(v, u);
        for (int i = 1; i <= cnt; ++i) {
            int x = p[i];
            for (auto j : mp1) {
                ll w = __gcd(val[x], j.fi);
                ans += (1ll * (dep[x] + 1) * mp2[j.fi] + j.se) * w; ans %= M;
            }
        }
        for (int i = 1; i <= cnt; ++i) {
            int x = p[i]; ++mp2[val[x]]; mp1[val[x]] += dep[x]; 
            stk[++top] = val[x];
        }
    }
    for (auto j : mp1) {
        ll w = j.fi; ans += (j.se + mp2[j.fi]) * w; ans %= M;
    }
    for (; top; --top) mp1[stk[top]] = mp2[stk[top]] = 0;
    for (int v : g[u]) {
        if (vis[v]) continue; 
        tot = siz[v]; rt = 0; gravity(v, 0); gravity(rt, 0); divide(rt);
    }
}
signed main() {
    read(n); for (int i = 1; i <= n; ++i) read(a[i]);
    for (int i = 1, u, v; i < n; ++i)
        read(u), read(v), g[u].emplace_back(v), g[v].emplace_back(u);
    tot = n; maxn[0] = M; gravity(1, 0); gravity(rt, 0);
    divide(rt); print(ans); return 0;
}
posted @ 2023-09-28 16:50  lzyqwq  阅读(27)  评论(0)    收藏  举报