CF434E 题解

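设路径 $p_1 \to p_2$ 依次经过的点的点权为 $z_0, z_1, \ldots, z_{l-1}$($z_0$ 为 $p_1$ 的点权,$l$ 为路径上的点数),定义 $t(p_1 \to p_2) = \left(z_{0}\times k^{0}+z_{1}\times k^{1} + \cdots + z_{l-1} \times k^{l-1}\right) \bmod y$,$f(r) = [r = x]$。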
则答案(三条路径归属相同的有序三元组个数)为$$\sum_{p_1=1}^{n}\sum_{p_2=1}^{n} \sum_{p_3=1}^{n} [f(t(p_1\to p_2))=f(t(p_2 \to p_3))=f(t(p_1 \to p_3))]$$
记 $R(p_1,p_2) = f(t(p_1\to p_2))$。三个 $R$ 值相等当且仅当它们全为 $1$ 或全为 $0$,于是答案可以等价地写成:
$$\sum_{p_1=1}^{n}\sum_{p_2=1}^{n} \sum_{p_3=1}^{n} \Big[R(p_1, p_2)R(p_2,p_3)R(p_1,p_3)+(1-R(p_1, p_2))(1-R(p_2,p_3))(1-R(p_1,p_3))\Big]$$
将求和号内的式子展开(三次项 $R(p_1,p_2)R(p_2,p_3)R(p_1,p_3)$ 正负相消),得到:
$$1-R(p_1,p_2)-R(p_1,p_3)-R(p_2,p_3)+R(p_1,p_2)R(p_1,p_3)+R(p_1,p_2)R(p_2,p_3)+R(p_1,p_3)R(p_2,p_3)$$
记 $a_u =\sum_{v=1}^{n} R(u,v)$(以 $u$ 为起点、$t$ 值等于 $x$ 的路径条数),$b_u=\sum_{v=1}^{n} R(v,u)$(以 $u$ 为终点的此类路径条数)。
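对三重求和逐项计算:常数项贡献 $n^3$;三个一次项各贡献 $n\sum a_i = n\sum b_i$;$R(p_1,p_2)R(p_1,p_3)$ 贡献 $\sum a_i^2$,$R(p_1,p_2)R(p_2,p_3)$ 贡献 $\sum a_i b_i$,$R(p_1,p_3)R(p_2,p_3)$ 贡献 $\sum b_i^2$。因此上面的和式可以化为:$$n^3-\frac{3n}{2}\sum (a_i+b_i)+\sum (a_i^2+b_i^2+a_i \times b_i)$$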
事实上 $\sum a_i = \sum b_i$(两者都等于满足 $R(u,v)=1$ 的有序点对 $(u,v)$ 的个数),所以中间一项就是 $3n\sum a_i$。
处理 $a_u$ 和 $b_u$ 可以使用点分治。
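具体来说:对于一个分治中心 $u$,设 $v, w$ 是当前连通块中的点,且路径 $v \to w$ 经过 $u$。记 $c_v$ 为从 $v$ 出发、走到 $u$ 之前(不含 $u$)这一段路径的 $t$ 值,$len_v$ 为这一段的点数($v = u$ 时 $c_v = 0$,$len_v = 0$);记 $d_w = t(u \to w)$(含 $u$)。则 $t(v \to w) = c_v + d_w \times k^{len_v}$,于是 $t(v\to w)=x \iff d_w = \frac{x-c_v}{k^{len_v}}$(除法为模 $y$ 意义下乘逆元)。我们把所有 $d_w$ 和所有 $\frac{x-c_v}{k^{len_v}}$ 分别丢进两个哈希表里,枚举每个点作为起点或终点,$O(1)$ 查询即可。为了不统计同一棵子树内的点对,查询某棵子树时先把它自己的贡献从哈希表中删去,查询完再加回。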
清空 unordered_map 时建议用 while 循环配合 erase(begin()) 逐个删除,而不是调用 clear():clear() 的代价与桶数组大小有关,而桶数组只会变大不会缩小,在点分治中反复清空可能被卡。
建议预处理 $k$ 的各次幂及其逆元,否则每次查询都要现做一次快速幂,复杂度会多出一个 $\log y$ 因子。

#include<bits/stdc++.h>
#include<bits/extc++.h>
using namespace std;
// Short alias used when filling the cc[] / dd[] pair arrays.
#define mp std::make_pair
// 64-bit type for hash values, powers and pair counts.
using ll = long long;
// Array capacity (1e6 + 500); every node index 1..n must be below it.
constexpr int N = 1000500;
// Reads the next non-negative decimal integer from stdin, skipping any
// non-digit characters in front of it.
// Returns 0 if end of input is reached before any digit. The original looped
// forever there, because it stored getchar() in a char and never tested for EOF.
int read() {
    int x = 0;
    int c = getchar();  // int, not char: must be able to represent EOF
    while (c != EOF && !isdigit(c)) c = getchar();
    while (isdigit(c)) {  // isdigit(EOF) is false, so this also stops at EOF
        x = x * 10 + (c - '0');
        c = getchar();
    }
    return x;
}
// Adjacency list: edge i points to e[i].to, and e[i].next is the edge that was
// head[] of the same source before it (0 terminates the chain). add() is called
// in both directions for every input edge, hence 2x capacity.
struct edge{int next,to;}e[N << 1];
// n: node count; sum: component size assumed by get_core; rt: node picked by
// get_core; y: modulus (assumed prime: inverses are taken with km(., y - 2));
// k: base of the path hash; r: target value x (reduced mod y);
// nnv / ccv: number of filled entries in dd[] / cc[].
int n, sum, rt, y, k, r, nnv, ccv;
// cnt: edges stored so far; head: first edge of each node; maxp: largest piece
// left if the node were removed; si: subtree size from the last get_core DFS;
// val: node weights; vis: set to 1 once the node has been used as a centroid.
int cnt, head[N], maxp[N], si[N], val[N], vis[N];
// fk[i] = k^i mod y; ink[i] = modular inverse of k^i (filled for i = 1..n).
ll fk[N], ink[N];
// a[u]: number of v with t(u -> v) == x; b[u]: number of v with t(v -> u) == x;
// dl[u]: t(center -> u) for the centroid currently being processed.
ll a[N], b[N], dl[N]; 
// dd: (t(center -> u), u) pairs collected by get_dis (hash values are < y, so
// they fit the int field).
pair<int, int> dd[N];
// cc: (t(u -> child of center), (depth below that child, u)) from get_dis.
pair<ll, pair<int, int> > cc[N]; 
// Multiplicity tables keyed by hash value: Q for path starts (the value
// t(center -> w) must take), W for path ends (t(center -> w) itself).
unordered_map<ll, ll> Q, W; 
// Appends the directed edge f -> t to the front of f's adjacency chain.
void add(int f, int t) {
    cnt++;
    e[cnt].to = t;
    e[cnt].next = head[f];
    head[f] = cnt;
}
// Centroid search over the not-yet-removed component containing u, entered
// from f (f is not descended into). For every visited node it records:
//   - si[]: subtree size, rooted at the starting node;
//   - maxp[]: size of the largest piece left if that node were removed,
//     assuming the component has `sum` nodes.
// rt ends as the first node found with the smallest maxp; rt must be 0 before
// the top-level call.
void get_core(int u, int f) {
    si[u] = 1;
    int heaviest = 0;
    for (int id = head[u]; id != 0; id = e[id].next) {
        const int to = e[id].to;
        if (to != f && !vis[to]) {
            get_core(to, u);
            si[u] += si[to];
            heaviest = max(heaviest, si[to]);
        }
    }
    // The part of the component "above" u also becomes a piece.
    maxp[u] = max(heaviest, sum - si[u]);
    if (rt == 0 || maxp[u] < maxp[rt]) rt = u;
}
// Binary exponentiation: returns a^b mod y for b >= 0 (1 when b == 0).
// The base is first reduced into [0, y), so the products stay below y^2 and
// cannot overflow. The original only guaranteed this when a was already
// reduced, and it mishandled negative bases.
ll km(ll a, ll b){
    ll ans = 1, bs = (a % y + y) % y;
    while (b) {
        if (b & 1) ans = ans * bs % y;
        bs = bs * bs % y;
        b >>= 1;
    }
    return ans;
}
void get_dis(int u, int f, ll len, int dep, bool jd = false) {
    len = (len * k  + val[u]) % y;
    cc[++ccv] = mp(len, mp(dep, u)); // cc : v -> u
    dl[u] = (dl[f] + fk[dep + 1] * val[u] ) % y; // dd : u -> v
    dd[++nnv] = mp(dl[u], u); 
    for(int i = head[u];i; i = e[i].next) {
        int v = e[i].to;
        if(v == f || vis[v]) continue;
        get_dis(v ,u, len, dep + 1, jd);
    }
}
int qr(int vl, int kt) {
    int c = r - vl;
    return (c * ink[kt] % y + y) % y;
}
void tadd(int opt) {
    for (int dk = 1; dk <= nnv; dk++) W[dd[dk].first] += opt;
    for (int dk = 1; dk <= ccv; dk++) Q[qr(cc[dk].first, cc[dk].second.first + 1)] += opt;
}
void calc(int u) {
    while(!Q.empty()) Q.erase(Q.begin());
    while(!W.empty()) W.erase(W.begin());
    dl[u] = val[u];
    Q[r]++, W[val[u]]++;
    ccv = nnv = 0;
    for(int i = head[u];i ;i = e[i].next) {
        int v = e[i].to;
        if(vis[v]) continue;
        get_dis(v, u, 0, 0);
    }
    tadd(1);
    a[u] += W[r];
    b[u] += Q[val[u]];
    for(int i = head[u];i;i = e[i].next) {
        int v = e[i].to;
        if(vis[v]) continue;
        ccv = nnv = 0;
        get_dis(v, u, 0, 0);
        tadd(-1);
        for(int dk = 1;dk <= ccv;dk++) 
            a[cc[dk].second.second] += W[qr(cc[dk].first, cc[dk].second.first + 1)];
        for(int dk = 1;dk <= nnv;dk++) b[dd[dk].second] += Q[dd[dk].first];
        tadd(1);
    }
}
// Centroid decomposition driver. Marks u as removed, counts every path through
// u with calc(u), then recurses into each remaining piece via that piece's own
// centroid. On entry, `sum` must equal the size of u's component.
//
// Piece sizes come from the get_core DFS that located u:
//   - if neighbour v was below u in that DFS (si[v] < si[u]), its piece has
//     si[v] nodes;
//   - otherwise v lies on the DFS-root side, and the piece has comp - si[u]
//     nodes.
// Using si[v] for that second case would overestimate the piece, so get_core
// could pick a node that is not the true centroid.
void solve(int u) {
    vis[u] = 1;
    const int comp = sum;  // saved: recursive calls below overwrite sum
    calc(u);
    for(int i = head[u];i;i = e[i].next) {
        int v = e[i].to;
        if(vis[v]) continue;
        sum = si[v] < si[u] ? si[v] : comp - si[u];
        rt = 0;
        get_core(v, u);
        solve(rt);
    }
}
int main(){
    scanf("%d %d %d %d", &n, &y, &k, &r); r %= y;
    for(int i = 1;i <= n;i++) val[i] = read();
    for(int i = 1;i < n;i++) {
        int u, v;
        u = read(), v = read();
        add(u, v), add(v, u);
    } fk[0] = 1;
    for(int i = 1;i <= n;i++) fk[i] = (fk[i - 1] * k) % y;
    for(int i = 1;i <= n;i++) ink[i] = km(fk[i], y - 2);
    sum = n;
    get_core(1, 0);
    solve(rt);
    ll ans = (ll)n * n * n;
    for(int i = 1;i <= n;i++) ans -= 3 * n * a[i];
    for(int i = 1;i <= n;i++) 
        ans += (ll)a[i] * a[i] + (ll)b[i] * b[i] + (ll)a[i] * b[i];    
    printf("%lld", ans);
    return 0;
}
posted @ 2023-10-16 12:49  Saka_Noa  阅读(15)  评论(0)    收藏  举报  来源