G2023--最后一题(保存)
最后一题 题解
看完题后可以发现显然是树形dp。
先考虑没有 \(k\) 限制的情况:
由于必须选 \(\text{LCA}\),考虑状态与 \(\text{LCA}\) 有关。
合法点的路径有两种情况:
- 路径端点之一是 \(\text{LCA}\);
- 路径端点都不是 \(\text{LCA}\)。
再考虑到 \(\text{LCA}\) 的颜色可以不同,为了方便转移,再用 \(\text{LCA}\) 的颜色区分。
状态为:
\(f_{now,j,k}\) 表示 \(\text{LCA}\) 是 \(now\),端点颜色是 \(j\),\(\text{LCA}\) 是否是端点。\((j\in\{0,1\} ,k\in\{0,1\})\)。
设 \(son_{now}\) 为 \(now\) 的儿子结点集合,\(tree_{now}\) 为 \(now\) 的子树结点集合。
转移方程是:
但这样太慢了,是 \(n^2\) 的,考虑优化一下。
注意到 \(\sum f_{u,j,0}\times (\text{dep(u)}-\text{dep(now)})\) 出现多次,不妨把它分别设为 \(h_{now,j,0}=\sum f_{u,j,0}\times (\text{dep(u)})\) 和 \(h_{now,j,1}=\sum f_{u,j,0}\)。
\(\sum f_{u,j,0}\times (\text{dep(u)}-\text{dep(now)})\) 就可以写为:
(设 \(g_{now,j}=h_{now,j,0}-\text{dep(now)}\times h_{now,j,1}\))
这都很好转移。
现在 \(f_{now,j,0}\) 是可以 \(O(n)\) 做的了。
再是 \(f_{now,j,1}\)。这东西可以用前缀和优化。
设 \(sum_j=\sum\limits_{u\in{son_{now}}} g_{u,j}\),
则 \(f_{now,j,1}=\sum\limits_{u\in son_{now}} (sum_{!j}-g_{u,!j})\times sum_{!j}\)。
但这样还是不对,因为一条路径在两棵子树都记了一次,所以要除以二。
这样当 \(k=n\) 时,就有一个 \(O(n)\) 做法了。
但是考虑到 \(k\) 的限制,沿用刚才思路,需要修改的其实只有一个状态的转移 \(h\) 。而其他状态什么都不变。
由于与深度有关系,可以使用线段树合并维护 \(h\) 的修改,用四棵线段树,表示 \(h_{now,i,j}\) 。
将 \(f\) 动态插入到线段树中,更新 \(h\) 的值即可。
时间复杂度 \(O(n\log n)\)。
注意取模。
代码
#include <bits/stdc++.h>
using namespace std;
#define int long long
const int mod = 1000000007;
const int inv2 = 500000004;
int n, col[100010], k;
struct edge
{
int from, to;
} e[100010 << 1];
int head[100010], S;
void addedge(int x, int y)
{
e[++S].to = y;
e[S].from = head[x], head[x] = S;
}
int f[100010][2][2];
int g[100010][2];
int h[100010][2][2];
int dep[100010];
struct sgt
{
struct node
{
int sum1, sum2;
} tree[100010 << 3];
int ls[100010 << 3], rs[100010 << 3], root[100010 << 3];
int cnt;
#define mid ((l + r) >> 1)
node pushup(node L, node R)
{
node ret;
ret.sum1 = (L.sum1 + R.sum1) % mod;
ret.sum2 = (L.sum2 + R.sum2) % mod;
return ret;
}
void add(int &now, int l, int r, int x, int k)
{
if (!now)
now = ++cnt;
if (l == r)
{
tree[now].sum1 += k * x % mod;
tree[now].sum2 += k;
tree[now].sum1 %= mod;
tree[now].sum2 %= mod;
return;
}
if (mid >= x)
add(ls[now], l, mid, x, k);
else
add(rs[now], mid + 1, r, x, k);
tree[now] = pushup(tree[ls[now]], tree[rs[now]]);
}
int merge(int a, int b, int l, int r)
{
if (!a || !b)
return a ^ b;
if (l == r)
{
tree[a].sum1 += tree[b].sum1;
tree[a].sum2 += tree[b].sum2;
tree[a].sum1 %= mod;
tree[a].sum2 %= mod;
return a;
}
ls[a] = merge(ls[a], ls[b], l, mid);
rs[a] = merge(rs[a], rs[b], mid + 1, r);
tree[a] = pushup(tree[ls[a]], tree[rs[a]]);
return a;
}
int ask1(int now, int l, int r, int x, int y)
{
if (!now)
return 0;
if (l >= x && r <= y)
{
return tree[now].sum1;
}
int ret = 0;
if (mid >= x)
ret += ask1(ls[now], l, mid, x, y);
if (mid < y)
ret += ask1(rs[now], mid + 1, r, x, y);
return ret % mod;
}
int ask2(int now, int l, int r, int x, int y)
{
if (!now)
return 0;
if (l >= x && r <= y)
{
return tree[now].sum2;
}
int ret = 0;
if (mid >= x)
ret += ask2(ls[now], l, mid, x, y);
if (mid < y)
ret += ask2(rs[now], mid + 1, r, x, y);
return ret % mod;
}
} T0, T1;
void dfs(int now, int fa)
{
dep[now] = dep[fa] + 1;
int sum0 = 0, sum1 = 0;
for (int i = head[now]; i; i = e[i].from)
{
int u = e[i].to;
if (u == fa)
continue;
dfs(u, now);
g[u][0] = T0.ask1(T0.root[u], 1, n, dep[u], min(dep[u] + k - 1, n)) -
dep[now] * T0.ask2(T0.root[u], 1, n, dep[u], min(dep[u] + k - 1, n));
g[u][1] = T1.ask1(T1.root[u], 1, n, dep[u], min(dep[u] + k - 1, n)) -
dep[now] * T1.ask2(T1.root[u], 1, n, dep[u], min(dep[u] + k - 1, n));
g[u][0] = (g[u][0] + mod) % mod;
g[u][1] = (g[u][0] + mod) % mod;
sum0 += g[u][0];
sum0 %= mod;
sum1 += g[u][1];
sum1 %= mod;
h[now][0][0] += h[u][0][0];
h[now][0][0] %= mod;
h[now][0][1] += h[u][0][1];
h[now][0][1] %= mod;
h[now][1][0] += h[u][1][0];
h[now][1][0] %= mod;
h[now][1][1] += h[u][1][1];
h[now][1][1] %= mod;
}
if (col[now] == 1 || col[now] == -1)
{
f[now][1][0] = sum0 + 1;
T1.add(T1.root[now], 1, n, dep[now], f[now][1][0]);
for (int i = head[now]; i; i = e[i].from)
{
int u = e[i].to;
if (u == fa)
continue;
f[now][1][1] += (sum0 - g[u][0]) * g[u][0] % mod;
f[now][1][1] %= mod;
T1.root[now] = T1.merge(T1.root[now], T1.root[u], 1, n);
}
f[now][1][1] = f[now][1][1] * inv2 % mod;
}
if (col[now] == 0 || col[now] == -1)
{
f[now][0][0] = sum1 + 1;
T0.add(T0.root[now], 1, n, dep[now], f[now][0][0]);
for (int i = head[now]; i; i = e[i].from)
{
int u = e[i].to;
if (u == fa)
continue;
f[now][0][1] += (sum1 - g[u][1]) * g[u][1] % mod;
f[now][0][1] %= mod;
T0.root[now] = T0.merge(T0.root[now], T0.root[u], 1, n);
}
f[now][0][1] = f[now][0][1] * inv2 % mod;
}
}
signed main()
{
cin >> n >> k;
for (int i = 1; i <= n; i++)
cin >> col[i];
for (int i = 1; i < n; i++)
{
int x, y;
cin >> x >> y;
addedge(x, y);
addedge(y, x);
}
dfs(1, 1);
int ans = 0;
for (int i = 1; i <= n; i++)
ans = (ans + f[i][0][0] + f[i][0][1] + f[i][1][0] + f[i][1][1]) % mod;
cout << (ans + mod) % mod;
}

浙公网安备 33010602011771号