ABC248G GCD cost on the tree
给出一棵 \(n\) 个点的树,点 \(i\) 有点权 \(a_i\)。定义 \(u,v\) 间的路径的权值 \(C(u,v)\) 为,这条路径上的点数乘上路径上所有点的 \(\gcd\)。
形式化的,若点 \(u\) 到点 \(v\) 的路径为 \(p_1,p_2,\dots,p_k\,(p_1=u,p_k=v)\),则 \(C(u,v)=\gcd(a_{p_1},a_{p_2},\dots,a_{p_k})\)。
求 \(\left(\sum\limits_{i=1}^{n-1}\sum\limits_{j=i+1}^n C(i,j)\right)\bmod \color{red}\boldsymbol{998244353}\)。某个【数据删除】把模数看成了 \(10^9+7\),是谁我不说。
\(n,\max\limits_{i=1}^n a_i\le 10^5\)。
约定:记 \(d(n)\) 为 \(n\) 的正因子个数。形式化的,\(d(n)=\sum\limits_{i=1}^n[i\,|\,n]\)。
考虑如何计算经过当前分治中心 \(rt\) 的路径的贡献。我们用从 \(rt\) 不同子树内的点到分治中心的链,拼接起来(\(rt\) 拼接的时候去重),得到一条路径。也可以单链成为一条路径。这都是套路。为了不重不漏地拼接,考虑用后面的子树的链匹配之前的子树的链。
\(\gcd\) 有一个众所周知的性质,即 \(\gcd(a,b,c)=\gcd(\gcd(a,b),\gcd(b,c))\)。我们用 \(val_u\) 表示 \(rt\) 到 \(u\) 的链上的点的 \(\gcd\),\(dep_u\) 表示 \(rt\) 到 \(u\) 的边数。则对于 \(rt\) 子树内的点 \(u,v\),\(C(u,v)=\gcd(val_u,val_v)\times (dep_u+dep_v\boldsymbol{+1})\)。考虑如何快速计算,用 std::unordered_map 或 __gnu_pbds::gp_hash_table 维护桶 \(mp_{1_x}\) 表示之前子树内,\(val\) 值为 \(x\) 的链的 \(dep\) 值和;桶 \(mp_{2_x}\) 表示之前子树内,\(val\) 值为 \(x\) 的链的数量。我们枚举当前子树内的点 \(u\),再枚举之前的链的 \(\gcd\) 值,它们一定在同理。即遍历桶内的元素,设遍历到的下标为 \(x\),则该种下标产生的贡献为 \([(dep_u+1)\times mp_{2_x}+mp_{1_x}]\times \gcd(x,val_u)\)。统计完右链在当前子树内的路径后,更新桶。
单链的情况类似求,然后继续分治下去。看似很暴力,但是你注意到桶内的元素一定是 \(a_{rt}\) 的因数,我们一共分治 \(\mathcal{O}(\log n)\) 层,每层每个点被遍历 \(\mathcal{O}(1)\) 次,遍历的时间为 \(\mathcal{O}(d(a_{rt}))\)。这样一来时间复杂度为 \(\mathcal{O}(n\log n \cdot \max\limits_{i=1}^nd(a_i))\)。
首先一个数 \(x\) 的正因子个数不超过 \(\mathcal{O}(\sqrt{x})\)。设值域为 \(V\),我们可以认为时间复杂度为 \(\mathcal{O}(n\log n\cdot \sqrt{|V|})\)。由于值域只有 \(10^5\),配合 \(8\) 秒的时限,可以通过。确实很暴力。
进一步思考,我们发现在这个值域内,\(\max\limits_{i=1}^nd(a_i)\) 不超过 \(128\)。你可以自己验证一下:
#include <bits/stdc++.h>
using namespace std;
int main() {
int ans = 0;
for (int i = 1; i <= 100000; ++i) {
int tot = 0;
for (int j = 1; j * j <= i; ++j) {
if (!(i % j)) tot += 2;
if (j * j == i) --tot;
}
ans = max(ans, tot);
}
cout << ans;
}
所以我们把 \(\max\limits_{i=1}^nd(a_i)\) 看作 \(128\) 的话,会比 \(\mathcal{O}(\sqrt{|V|})\) 优秀一点。好好好,\({128}\) 倍大常数优秀 \(\text{polylog}\) 做法。
不过为了保证复杂度正确,切记 \(\boldsymbol{val}\) 数组指的是某个点到分治中心 \(\boldsymbol{rt}\) 的链上的点权值的最大公约数,若定义成到该子树根的链上的点权值的最大公约数,那么每个点在一次分治中枚举之前链的实际次数会变成,所有 \(\boldsymbol{rt}\) 的儿子的权值的约数集合的并集的大小,枚举次数就无法得到保证了。
总结一下,时间复杂度为 \(\mathcal{O}\left(n\log n \cdot \max\limits_{i=1}^{|V|} d(i)\right)\),空间复杂度为 \(\mathcal{O}(n+|V|)\),可以接受。直接喜提你谷最劣解。
#pragma GCC optimize("Ofast")
#include <bits/stdc++.h>
#define fi first
#define se second
using namespace std; typedef long long ll; const int N = 1e5 + 5, M = 998244353;
int n, a[N], dep[N], val[N], siz[N], maxn[N], rt, tot, stk[N], top, p[N], cnt;
vector<int> g[N]; bool vis[N]; ll ans; unordered_map<int, ll> mp1, mp2;
template<class T> void read(T &x) {
x = 0; T f = 1; char c = getchar();
for (; !isdigit(c); c = getchar()) if (c == '-') f = -1;
for (; isdigit(c); c = getchar()) x = (x << 3) + (x << 1) + c - 48; x *= f;
}
template<class T> void write(T x) {
if (x > 9) write(x / 10); putchar(x % 10 + 48);
}
template<class T> void print(T x, char ed = '\n') {
if (x < 0) putchar('-'), x = -x; write(x), putchar(ed);
}
void gravity(int u, int fa) {
siz[u] = 1; maxn[u] = 0;
for (int v : g[u]) {
if (v == fa || vis[v]) continue; gravity(v, u);
siz[u] += siz[v]; maxn[u] = max(maxn[u], siz[v]);
}
maxn[u] = max(maxn[u], tot - siz[u]); if (maxn[u] < maxn[rt]) rt = u;
}
void get(int u, int fa) {
p[++cnt] = u;
for (int v : g[u]) {
if (v == fa || vis[v]) continue;
val[v] = __gcd(val[u], a[v]); dep[v] = dep[u] + 1; get(v, u);
}
}
void divide(int u) {
vis[u] = 1;
for (int v : g[u]) {
if (vis[v]) continue; dep[v] = 1, val[v] = __gcd(a[u], a[v]); cnt = 0; get(v, u);
for (int i = 1; i <= cnt; ++i) {
int x = p[i];
for (auto j : mp1) {
ll w = __gcd(val[x], j.fi);
ans += (1ll * (dep[x] + 1) * mp2[j.fi] + j.se) * w; ans %= M;
}
}
for (int i = 1; i <= cnt; ++i) {
int x = p[i]; ++mp2[val[x]]; mp1[val[x]] += dep[x];
stk[++top] = val[x];
}
}
for (auto j : mp1) {
ll w = j.fi; ans += (j.se + mp2[j.fi]) * w; ans %= M;
}
for (; top; --top) mp1[stk[top]] = mp2[stk[top]] = 0;
for (int v : g[u]) {
if (vis[v]) continue;
tot = siz[v]; rt = 0; gravity(v, 0); gravity(rt, 0); divide(rt);
}
}
signed main() {
read(n); for (int i = 1; i <= n; ++i) read(a[i]);
for (int i = 1, u, v; i < n; ++i)
read(u), read(v), g[u].emplace_back(v), g[v].emplace_back(u);
tot = n; maxn[0] = M; gravity(1, 0); gravity(rt, 0);
divide(rt); print(ans); return 0;
}

浙公网安备 33010602011771号