G2023--最后一题(保存)

最后一题 题解

看完题后可以发现显然是树形dp。

先考虑没有 \(k\) 限制的情况:

由于必须选 \(\text{LCA}\),考虑状态与 \(\text{LCA}\) 有关。

合法点的路径有两种情况:

  1. 路径端点之一是 \(\text{LCA}\);
  2. 路径端点都不是 \(\text{LCA}\)

再考虑到 \(\text{LCA}\) 的颜色可以不同,为了方便转移,再用 \(\text{LCA}\) 的颜色区分。

状态为:

\(f_{now,j,k}\) 表示 \(\text{LCA}\)\(now\),端点颜色是 \(j\)\(\text{LCA}\) 是否是端点。\((j\in\{0,1\} ,k\in\{0,1\})\)

\(son_{now}\)\(now\) 的儿子结点集合,\(tree_{now}\)\(now\) 的子树结点集合。

转移方程是:

\[f_{now,j,0}=1+\sum_{v\in son_{now}}\sum\limits_{u\in{tree_{v}}}f_{u,!j,0}\times (\text{dep(u)}-\text{dep(now)}) \]

\[f_{now,j,1}=\sum\limits_{v1\in{son_{now}}}\sum\limits_{v2\in son_{now} v1 \neq v2}\sum_{u1\in tree_{v1}}\sum_{u2\in tree_{v2}}f_{u1,!j,0}\times f_{u2,!j,0}\times (\text{dep(u1)}-\text{dep(now)})\times (\text{dep(u2)}-\text{dep(now)}) \]

但这样太慢了,是 \(n^2\) 的,考虑优化一下。

注意到 \(\sum f_{u,j,0}\times (\text{dep(u)}-\text{dep(now)})\) 出现多次,不妨把它分别设为 \(h_{now,j,0}=\sum f_{u,j,0}\times (\text{dep(u)})\)\(h_{now,j,1}=\sum f_{u,j,0}\)

\(\sum f_{u,j,0}\times (\text{dep(u)}-\text{dep(now)})\) 就可以写为:

\[h_{now,j,0}-\text{dep(now)}\times h_{now,j,1} \]

(设 \(g_{now,j}=h_{now,j,0}-\text{dep(now)}\times h_{now,j,1}\)

这都很好转移。

\[f_{now,j,0}=1+\sum\limits_{u\in{son_{now}}} g_{u,j} \]

现在 \(f_{now,j,0}\) 是可以 \(O(n)\) 做的了。

再是 \(f_{now,j,1}\)。这东西可以用前缀和优化。

\(sum_j=\sum\limits_{u\in{son_{now}}} g_{u,j}\)
\(f_{now,j,1}=\sum\limits_{u\in son_{now}} (sum_{!j}-g_{u,!j})\times sum_{!j}\)

但这样还是不对,因为一条路径在两棵子树都记了一次,所以要除以二。

\[f_{now,j,1}=\frac{\sum\limits_{u\in son_{now}} (sum_{!j}-g_{u,!j})\times sum_{!j}}{2} \]

这样当 \(k=n\) 时,就有一个 \(O(n)\) 做法了。

但是考虑到 \(k\) 的限制,沿用刚才思路,需要修改的其实只有一个状态的转移 \(h\) 。而其他状态什么都不变。

由于与深度有关系,可以使用线段树合并维护 \(h\) 的修改,用四棵线段树,表示 \(h_{now,i,j}\)

\[h_{now,j,0}=\sum f_{u,j,0}\times (\text{dep(u)})\ \ (\text(\text{dep(u)-dep(now)}\le k) \]

\[h_{now,j,1}=\sum f_{u,j,0}\ \ (\text(\text{dep(u)-dep(now)}\le k) \]

\(f\) 动态插入到线段树中,更新 \(h\) 的值即可。

时间复杂度 \(O(n\log n)\)

注意取模。

代码

#include <bits/stdc++.h>
using namespace std;
#define int long long
const int mod = 1000000007;
const int inv2 = 500000004;
int n, col[100010], k;
struct edge
{
    int from, to;
} e[100010 << 1];
int head[100010], S;
void addedge(int x, int y)
{
    e[++S].to = y;
    e[S].from = head[x], head[x] = S;
}
int f[100010][2][2];
int g[100010][2];
int h[100010][2][2];
int dep[100010];
struct sgt
{
    struct node
    {
        int sum1, sum2;
    } tree[100010 << 3];
    int ls[100010 << 3], rs[100010 << 3], root[100010 << 3];
    int cnt;
#define mid ((l + r) >> 1)
    node pushup(node L, node R)
    {
        node ret;
        ret.sum1 = (L.sum1 + R.sum1) % mod;
        ret.sum2 = (L.sum2 + R.sum2) % mod;
        return ret;
    }
    void add(int &now, int l, int r, int x, int k)
    {
        if (!now)
            now = ++cnt;
        if (l == r)
        {
            tree[now].sum1 += k * x % mod;
            tree[now].sum2 += k;
            tree[now].sum1 %= mod;
            tree[now].sum2 %= mod;
            return;
        }
        if (mid >= x)
            add(ls[now], l, mid, x, k);
        else
            add(rs[now], mid + 1, r, x, k);
        tree[now] = pushup(tree[ls[now]], tree[rs[now]]);
    }
    int merge(int a, int b, int l, int r)
    {
        if (!a || !b)
            return a ^ b;
        if (l == r)
        {
            tree[a].sum1 += tree[b].sum1;
            tree[a].sum2 += tree[b].sum2;
            tree[a].sum1 %= mod;
            tree[a].sum2 %= mod;
            return a;
        }
        ls[a] = merge(ls[a], ls[b], l, mid);
        rs[a] = merge(rs[a], rs[b], mid + 1, r);
        tree[a] = pushup(tree[ls[a]], tree[rs[a]]);
        return a;
    }
    int ask1(int now, int l, int r, int x, int y)
    {
        if (!now)
            return 0;
        if (l >= x && r <= y)
        {
            return tree[now].sum1;
        }
        int ret = 0;
        if (mid >= x)
            ret += ask1(ls[now], l, mid, x, y);
        if (mid < y)
            ret += ask1(rs[now], mid + 1, r, x, y);
        return ret % mod;
    }
    int ask2(int now, int l, int r, int x, int y)
    {
        if (!now)
            return 0;
        if (l >= x && r <= y)
        {
            return tree[now].sum2;
        }
        int ret = 0;
        if (mid >= x)
            ret += ask2(ls[now], l, mid, x, y);
        if (mid < y)
            ret += ask2(rs[now], mid + 1, r, x, y);
        return ret % mod;
    }
} T0, T1;
void dfs(int now, int fa)
{
    dep[now] = dep[fa] + 1;
    int sum0 = 0, sum1 = 0;
    for (int i = head[now]; i; i = e[i].from)
    {
        int u = e[i].to;
        if (u == fa)
            continue;
        dfs(u, now);
        g[u][0] = T0.ask1(T0.root[u], 1, n, dep[u], min(dep[u] + k - 1, n)) - 
        dep[now] * T0.ask2(T0.root[u], 1, n, dep[u], min(dep[u] + k - 1, n));
        g[u][1] = T1.ask1(T1.root[u], 1, n, dep[u], min(dep[u] + k - 1, n)) - 
        dep[now] * T1.ask2(T1.root[u], 1, n, dep[u], min(dep[u] + k - 1, n));
        g[u][0] = (g[u][0] + mod) % mod;
        g[u][1] = (g[u][0] + mod) % mod;
        sum0 += g[u][0];
        sum0 %= mod;
        sum1 += g[u][1];
        sum1 %= mod;
        h[now][0][0] += h[u][0][0];
        h[now][0][0] %= mod;
        h[now][0][1] += h[u][0][1];
        h[now][0][1] %= mod;
        h[now][1][0] += h[u][1][0];
        h[now][1][0] %= mod;
        h[now][1][1] += h[u][1][1];
        h[now][1][1] %= mod;
    }
    if (col[now] == 1 || col[now] == -1)
    {
        f[now][1][0] = sum0 + 1;
        T1.add(T1.root[now], 1, n, dep[now], f[now][1][0]);
        for (int i = head[now]; i; i = e[i].from)
        {
            int u = e[i].to;
            if (u == fa)
                continue;
            f[now][1][1] += (sum0 - g[u][0]) * g[u][0] % mod;
            f[now][1][1] %= mod;
            T1.root[now] = T1.merge(T1.root[now], T1.root[u], 1, n);
        }
        f[now][1][1] = f[now][1][1] * inv2 % mod;
    }
    if (col[now] == 0 || col[now] == -1)
    {
        f[now][0][0] = sum1 + 1;
        T0.add(T0.root[now], 1, n, dep[now], f[now][0][0]);
        for (int i = head[now]; i; i = e[i].from)
        {
            int u = e[i].to;
            if (u == fa)
                continue;
            f[now][0][1] += (sum1 - g[u][1]) * g[u][1] % mod;
            f[now][0][1] %= mod;
            T0.root[now] = T0.merge(T0.root[now], T0.root[u], 1, n);
        }
        f[now][0][1] = f[now][0][1] * inv2 % mod;
    }
}
signed main()
{
    cin >> n >> k;
    for (int i = 1; i <= n; i++)
        cin >> col[i];
    for (int i = 1; i < n; i++)
    {
        int x, y;
        cin >> x >> y;
        addedge(x, y);
        addedge(y, x);
    }
    dfs(1, 1);
    int ans = 0;
    for (int i = 1; i <= n; i++)
        ans = (ans + f[i][0][0] + f[i][0][1] + f[i][1][0] + f[i][1][1]) % mod;
    cout << (ans + mod) % mod;
}
posted @ 2023-07-27 14:35  星河倒注  阅读(30)  评论(0)    收藏  举报