QOJ 1288 Tokens on the Tree 题解

Link

很一言难尽的一个题。

毛营的题是不是总是有了题解之后考虑怎么拼出一个题啊???

我们来考虑比较 人类 的思路,注意到当 \(w = 1, b = 1\)\(f(w, b) = 1\),那什么时候 \(f(w, b) = 1\) 呢?手模一下是很好发现的,当且仅当存在一个三度点时我们可以置换某一个节点之后交换颜色位置。

\(w = b\),则 \(f(w, w) = 1\) 的情况为有一个节点同时存在 \(3\) 个大小 \(\geq w\) 的子树。一般地,令 \(w \geq b\),如果存在一个节点 \(u\) 满足一个子树大小 \(\geq b\),两个子树大小 \(\geq w\)\(f(w, b) = 1\)。一个关键的结论是对于一个固定的 \(w\),有一个 \(f(w, b) = 1\)\(b\) 的一个前缀。

尝试处理出 \(g(w)\) 表示最大的 \(b\) 使得 \(f(w, b) = 1, w \geq b\)\(O(n)\) 统计出所有 \(f(w, b) = 1\) 的答案和。

对于 \(f(w, b) \neq 1\) 的情况,我们容易想到去枚举 \((g_w, w]\),由于 \(f(w, b) \neq 1\),意味着两个颜色的棋子之间不能交换,此时有边 \((u, v), v \in son_u\) 使得 \(siz_u \geq w, siz_v \geq b\),同时 \(son_u\) 中除了 \(v\) 之外的子树的 \(siz\) 都要 \(\lt w\),此时 \(w\) 不能完全丢进一个子树里,意味着 \(u\) 上一定是 \(w\),而 \(b\) 无法交换过来。

对于每一个合法的 \(w, u\) 合法的 \(b\) 的范围都会对答案加上一个等比数列。显然枚举 \(u\) 时合法的 \(w\) 都有自己的范围,维护一个树状数组实现区间加单点查可以从 \(O(n^3) \to O(n^2) \to O(n \log n)\)

代码超级难写,有点爱上这种 moddadd 的写法了。具体看看怎么算贡献

对于 \(f(w, b) = 1\) 的贡献,由于:

\[\sum_{w, b} f(w, b)wb = \sum_{w \geq b} f(w, b)wb + \sum_{b \geq w} f(w, b)wb - \sum_{w = b} f(w, w)w^2 \]

我们定义 \(g_w = \max \{ b, f(w, b) = 1 \}\)

\(f(w, b) = f(b, w)\),故:

\[2 \sum_{w = 1}^{n} \sum_{b = 1}^{\min(w, g_w)} wb - \sum_{w = 1}^{n} [g_w \geq w] w^2 \]

分类讨论,当 \(g_w \geq w\),此时 \(b\) 可以取到 \(w\)

\[2 \sum_{b = 1}^{w} wb - w^2 = w^3 \]

\(g_w \lt w\),此时:

\[2 \sum_{b = 1}^{g_w} wb = w g_w(g_w + 1) \]

再来看 \(f(w, b) \neq 1\) 的贡献,我们分类处理两种边 \(son \to fa, fa \to son\),拆贡献用差分数组维护满足条件的边数,这里要维护一下通过差分数组来累加统计答案。这部分好麻烦的。

#include <bits/stdc++.h>

using i64 = long long;

constexpr int N = 2e5 + 7;
constexpr int P = 1e9 + 7;

int mul(int x, int y) { return 1ll * x * y % P; }

void modadd(int& x, int y) { x = (x + y >= P ? x + y - P : x + y); }

void modsub(int& x, int y) { x = (x - y < 0 ? x - y + P : x - y); }

void modmul(int& x, int y) { x = mul(x, y); }

int t, n, ans;
int s[N], fa[N], siz[N], mxy[N];

std::pair<int, int> msx[N][3];
std::vector<int> adj[N];
std::vector<std::pair<int, int>> g[N];

// #define DEBUG 1  // 调试开关
struct IO {
#define MAXSIZE (1 << 20)
#define isdigit(x) (x >= '0' && x <= '9')
    char buf[MAXSIZE], *p1, *p2;
    char pbuf[MAXSIZE], *pp;
#if DEBUG
#else
    IO() : p1(buf), p2(buf), pp(pbuf) {}

    ~IO() { fwrite(pbuf, 1, pp - pbuf, stdout); }
#endif
    char gc() {
#if DEBUG // 调试,可显示字符
        return getchar();
#endif
        if (p1 == p2)
            p2 = (p1 = buf) + fread(buf, 1, MAXSIZE, stdin);
        return p1 == p2 ? ' ' : *p1++;
    }

    bool blank(char ch) {
        return ch == ' ' || ch == '\n' || ch == '\r' || ch == '\t';
    }

    template <class T>
    void read(T &x) {
        double tmp = 1;
        bool sign = false;
        x = 0;
        char ch = gc();
        for (; !isdigit(ch); ch = gc())
            if (ch == '-')
                sign = 1;
        for (; isdigit(ch); ch = gc())
            x = x * 10 + (ch - '0');
        if (ch == '.')
            for (ch = gc(); isdigit(ch); ch = gc())
                tmp /= 10.0, x += tmp * (ch - '0');
        if (sign)
            x = -x;
    }

    void read(char *s) {
        char ch = gc();
        for (; blank(ch); ch = gc())
            ;
        for (; !blank(ch); ch = gc())
            *s++ = ch;
        *s = 0;
    }

    void read(char &c) {
        for (c = gc(); blank(c); c = gc())
            ;
    }

    void push(const char &c) {
#if DEBUG // 调试,可显示字符
        putchar(c);
#else
        if (pp - pbuf == MAXSIZE)
            fwrite(pbuf, 1, MAXSIZE, stdout), pp = pbuf;
        *pp++ = c;
#endif
    }

    template <class T>
    void write(T x) {
        if (x < 0)
            x = -x, push('-'); // 负数输出
        static T sta[35];
        T top = 0;
        do {
            sta[top++] = x % 10, x /= 10;
        } while (x);
        while (top)
            push(sta[--top] + '0');
    }

    template <class T>
    void write(T x, char lastChar) {
        write(x), push(lastChar);
    }
} io;

struct Fenwick {
    int tr[N << 1];

    int lowbit(int x) { return x & -x; }

    void clear() { memset(tr, 0, sizeof(tr)); }

    void upd(int i, int x) {
        for (; i <= n; i += lowbit(i))
            tr[i] += x;
    }

    int qry(int i) {
        int res = 0;
        for (; i; i -= lowbit(i))
            res += tr[i];
        return res;
    }
} bit;

void clear() {
    for (int i = 0; i <= n + 1; i++) {
        siz[i] = 1; s[i] = mxy[i] = 0;
        g[i].clear(); adj[i].clear(); bit.tr[i] = 0;
    }
}

void solve() {
    io.read(n); clear();
    for (int i = 2; i <= n; i++) {
        io.read(fa[i]);
        adj[fa[i]].push_back(i);
    }
    for (int i = n; i >= 1; i--) {
        siz[fa[i]] += siz[i];
    }
    for (int u = 1; u <= n; u++) {
        std::vector<std::pair<int, int>> vec;
        if (u != 1)
            vec.push_back({n - siz[u], fa[u]});
        for (int v : adj[u])
            vec.push_back({siz[v], v});
        std::sort(vec.begin(), vec.end(), std::greater<std::pair<int, int>>());
        if (vec.size() >= 3) {
            int x = std::min(vec[0].first, vec[1].first);
            int y = vec[2].first;
            mxy[x] = std::max(mxy[x], y);
        }
        for (int i = 0; i < 3; i++)
            msx[u][i] = {0, 0};
        for (int i = 0; i < vec.size(); i++)
            msx[u][i] = vec[i];
    }
    for (int x = n; x >= 1; x--) {
        mxy[x - 1] = std::max(mxy[x - 1], mxy[x]);
    }
    ans = 0;
    for (int x = 1; x <= n; x++) {
        if (mxy[x] >= x) {
            modadd(ans, mul(2, mul(x, 1ll * (1 + x) * x / 2 % P)));
            modsub(ans, mul(x, x));
        } else {
            modadd(ans, mul(2, mul(x, 1ll * (1 + mxy[x]) * mxy[x] / 2 % P)));
        }
    }
    for (int v = 2; v <= n; v++) {
        int mn = 1, mx = n - siz[v];
        if (msx[fa[v]][0].second == v) {
            mn = std::max(mn, msx[fa[v]][1].first + 1);
        } else {
            mn = std::max(mn, msx[fa[v]][0].first + 1);
        }
        mn = std::max(mn, n - siz[fa[v]] + 1);
        int L = std::max(siz[v] + 1, mn);
        int R = mx;
        int val = (siz[v] + mul(siz[v], siz[v])) % P;
        if (L <= R) {
            modadd(ans, mul(1ll * (L + R) * (R - L + 1) / 2 % P, val));
        }
        if (mn <= std::min(siz[v], mx)) {
            modadd(s[mn], 1);
            modsub(s[std::min(siz[v], mx) + 1], 1);
        }
        if (std::max(siz[v] + 1, mn) <= mx) {
            g[std::max(siz[v] + 1, mn)].push_back({siz[v] - 1, 1});
            g[mx + 1].push_back({siz[v] - 1, P - 1});
        }
    }
    for (int u = 2; u <= n; u++) {
        int mn = 1, mx = siz[u];
        if (msx[u][0].second == fa[u]) {
            mn = std::max(mn, msx[u][1].first + 1);
        } else {
            mn = std::max(mn, msx[u][0].first + 1);
        }
        int L = std::max(n - siz[u] + 1, mn);
        int R = mx;
        int val = (n - siz[u] + mul(n - siz[u], n - siz[u])) % P;
        if (L <= R) {
            modadd(ans, mul(1ll * (L + R) * (R - L + 1) / 2 % P, val));
        }
        if (mn <= std::min(n - siz[u], mx)) {
            modadd(s[mn], 1);
            modsub(s[std::min(n - siz[u], mx) + 1], 1);
        }
        if (std::max(n - siz[u] + 1, mn) <= mx) {
            g[std::max(n - siz[u] + 1, mn)].push_back({n - siz[u] - 1, 1});
            g[mx + 1].push_back({n - siz[u] - 1, P - 1});
        }
    }
    for (int x = 1; x <= n; x++) {
        modadd(s[x], s[x - 1]);
        if (mxy[x] + 1 <= x) {
            modadd(ans, mul(s[x], mul(x, 1ll * (mxy[x] + 1 + x) * (x - mxy[x]) % P)));
            modsub(ans, mul(s[x], mul(x, x)));
        }
    }
    for (int x = 1; x <= n; x++) {
        for (auto [i, v] : g[x]) {
            bit.upd(1, v);
            bit.upd(i + 1, (0 - v + P) % P);
        }
        modsub(ans, mul(bit.qry(mxy[x]), mul(x, (mul(mxy[x], mxy[x]) + mxy[x]) % P)));
    }
    io.write(ans, '\n');
}

int main() {
    io.read(t);
    while (t--) {
        solve();
    }
    return 0;
}

略微卡常。

posted @ 2025-11-12 10:08  起汐Moe  阅读(4)  评论(0)    收藏  举报