射手座之日 题目分析
射手座之日 题目分析
题目概述LuoguU95602
给一个 \(1\) 到 \(n\) 的排列 \(a\),并且给出点权 \(x_i\),并定义:
其中 \(lca(x,y)\) 表示 \(x\) 和 \(y\) 的最近公共祖先。
并且给出一颗树。
求:
分析
像这种比较经典的双 sigma 题目,最最最最暴力的解法是 \(\mathcal{O}(n^3)\)(先不考虑这里求 \(LCA\) 的 \(\log\))。
那么很显然,我们可以固定最短点 \(i\),\(j\) 不断地向右扩展,这样就会得到 \(\mathcal{O}(n^2)\) 算法。
于是就很简单地拿到了此题的 \(40\) 分。
那么怎么优化到 \(\mathcal{O}(n\log n)\) 呢?
我一般的思路是直接上线段树。
我们这颗线段树(显然维护的是 \(dfs\) 序区间)维护两个值,一个是 \(cnt\) 代表这段区间内有点作为 \(lca\) 的总方案,\(sum\) 就是加和实际的数量 \(\times x_{lca}\),注意到这里只有在有可能作为 \(lca\) 的点上相乘,pushup 的时候都是加和(这里的思路比较巧妙)。
我们先假设有了一些区间的左端点,然后右端点往扩展(假设从 \(i-1\) 到 \(i\))。
那么就是这样的:
|-----------|--->|
|---------|--->|
|-------|--->|
|---|--->|
i - 1->i
首先对于之前的所有区间,我都是得到了各自中的 \(lca\) 并存储到了线段树的结点上面,那么很显然我们每次扩展一次,就计算一次答案——\(tr[1].sum\)。
然后我们考虑新的贡献:假设 \(p=lca(a_{i-1},a_i).\)
我们发现如果之前有些区间的 \(lca\)(此处假设为 \(p_2\))满足 \(p\) 在 \(p_2\) 到 \(1\) 的路径上面,是不是说明我这些贡献(按道理来说是方案 \(cnt\))是不是得删除并且挪到当前 \(lca\) 也就是 \(p\) 上面。
换个角度想,是不是满足 \(p_2\) 是 \(p\)(包括 \(p\))子树以内的结点就可以挪。
最后,我们把单独一个点 \(a_i\) 的贡献加上就可以了。
代码
#include <iostream>
#include <cstdio>
#include <algorithm>
#include <stdlib.h>
#include <cstring>
#include <vector>
#define N 200005
#define int long long
#define isdigit(ch) ('0' <= ch && ch <= '9')
using namespace std;
template<typename T>
void read(T &x) {
x = 0;
int f = 1;
char ch = getchar();
for (;!isdigit(ch);ch = getchar()) f = (ch == '-' ? -1 : f);
for (;isdigit(ch);ch = getchar()) x = (x << 3) + (x << 1) + (ch ^ 48);
x *= f;
}
template<typename T>
void write(T x) {
if (x < 0) x = -x,putchar('-');
if (x > 9) write(x / 10);
putchar(x % 10 + '0');
}
vector<int> g[N];
int n,fa[N][25],a[N],val[N],dep[N],dfn[N],cnt,st[N][25],sz[N],rid[N];
void dfs0(int cur) {
dfn[cur] = ++cnt;
rid[cnt] = cur;
sz[cur] = 1;
st[cnt][0] = fa[cur][0];
dep[cur] = dep[fa[cur][0]] + 1;
for (auto i : g[cur])
if (i != fa[cur][0])
dfs0(i),sz[cur] += sz[i];
}
// ----------------------------- 倍增求LCA ------------------------------------
int LCA(int x,int y) {
if (dep[x] < dep[y]) x ^= y ^= x ^= y;
for (int j = 20;j >= 0;j --)
if (dep[fa[x][j]] >= dep[y]) x = fa[x][j];
if (x == y) return x;
for (int j = 20;j >= 0;j --)
if (fa[x][j] != fa[y][j]) x = fa[x][j],y = fa[y][j];
return fa[x][0];
}
// ----------------------------- dfs序求LCA ----------------------------------
int GET(int x) {
int len = 0,p = 0;
for (;x;x >>= 1,len ++) p = (x & 1 ? len : p);
return p;
}
int get(int x,int y) {
return dfn[x] < dfn[y] ? x : y;
}
int getlca(int x,int y) {
if (x == y) return x;
if ((x = dfn[x]) > (y = dfn[y])) x ^= y ^= x ^= y;
int t = GET(y - x);
return get(st[x + 1][t],st[y - (1 << t) + 1][t]);
}
// ----------------------------- segment tree ---------------------------------
#define ls(x) (x << 1)
#define rs(x) (x << 1 | 1)
struct node{
int cnt;
int sum;
}tr[N << 2];
int lz[N << 2];
void pushup(int x) {
tr[x].cnt = tr[ls(x)].cnt + tr[rs(x)].cnt;
tr[x].sum = tr[ls(x)].sum + tr[rs(x)].sum;
}
void pushdown(int x) {
lz[ls(x)] = lz[rs(x)] = -1;
tr[ls(x)] = tr[rs(x)] = {0,0};
lz[x] = 0;
}
void update(int x,int l,int r,int pos,int value) {
if (l == r) {
tr[x].cnt += value;
tr[x].sum += value * val[rid[l]];
return;
}
if (lz[x] == -1) pushdown(x);
int mid = l + r >> 1;
if (pos <= mid) update(ls(x),l,mid,pos,value);
else update(rs(x),mid + 1,r,pos,value);
pushup(x);
}
int query(int x,int l,int r,int L,int R) {
if (l > R || r < L) return 0;
if (L <= l && r <= R) {
int p = tr[x].cnt;
tr[x].cnt = 0,tr[x].sum = 0;
lz[x] = -1;
return p;
}
if (lz[x] == -1) pushdown(x);
int mid = l + r >> 1,ans = query(ls(x),l,mid,L,R) + query(rs(x),mid + 1,r,L,R);
pushup(x);
return ans;
}
signed main(){
read(n);
for (int i = 2;i <= n;i ++) read(fa[i][0]),g[fa[i][0]].push_back(i);
for (int i = 1;i <= n;i ++) read(a[i]);
for (int i = 1;i <= n;i ++) read(val[i]);
dfs0(1);
for (int j = 1;j <= 20;j ++)
for (int i = 1;i <= n;i ++)
fa[i][j] = fa[fa[i][j - 1]][j - 1];
for (int j = 1;j <= 20;j ++)
for (int i = 1;i + (1 << j) - 1 <= n;i ++)
st[i][j] = get(st[i][j - 1],st[i + (1 << j - 1)][j - 1]);
int ans = 0;
for (int i = 1;i <= n;i ++) {
if (i > 1) {
int t = getlca(a[i],a[i - 1]);
int tot = query(1,1,n,dfn[t],dfn[t] + sz[t] - 1);
update(1,1,n,dfn[t],tot);
}
update(1,1,n,dfn[a[i]],1);
ans += tr[1].sum;
}
write(ans);
return 0;
}
扩展分析——常数太大?
数据范围其实给了我们提示:
对于另外20%的数据,排列 ai 是用如下的算法生成的:从一号点始对树做 dfs,到达一个节点的时候输出这个节点。
此时我们分析道:任意一段 \(a\) 相对于 \(dfs\) 序是一个连续的区间。
启发我们用区间合并的思路。
我们首先得到了一个比较显而易得的结论:
对于一些 \(lca\) 在当前结点 \(i\) 的子树中是由一段又一段的 \(a\) 组成的。
然后不难得出:
设 \(rk_{a_i}=i\),只要选的 \(a\) 数值是连续的,那么 \(rk\) 相对应的部分也是连续的。
中国有句古话:麻雀虽小,五脏俱全。
别看这个小小结论,却能引出这道题的另一个算法。
考虑 \(dfs\) 这整棵树,但时间太大,不可接受。
但既然必须要 \(dfs\) 了,那就用启发式合并。
期望时间复杂度为 \(\mathcal{O}(n\log n).\)
如何计算结点 \(i\) 的子树中一些点作为 \(lca\) 方案总和呢?
考虑这样的树:

我们设 \(len_i\) 表示现在 \(i\) 作为某颗子树的一段连续区间(指在 \(a\) 上)的左端点或者右端点所得到的长度。
如果对应现在的 \(p = i\) 为根的子树:\(p\) 必选,显然只有左边和右边的情况,即 \(len_{p-1}\) 和 \(len_{p+1}\)。
所得到了总长 \(length=len_{p-1}+len_{p+1}+1.\)
那么总方案是 \(length(length-1)\div 2.\)
考虑到不小心把左右边单独的方案也算进去了,所以减去,因此得到下面的代码:
void getans(int p) {
int lenl = len[p - 1],lenr = len[p + 1];
int length = lenl + lenr + 1;
len[p - len[p - 1]] = len[p + len[p + 1]] = length;
cnt += length * (length - 1) / 2 - lenl * (lenl - 1) / 2 - lenr * (lenr - 1) / 2;
}
最后算答案。
设 \(ans_x\) 表示以 \(x\) 为根的子树(中一些 \(lca\))所有方案。
那我们单独算 \(x\) 作为 \(lca\) 的方案就为 \(cnt-\sum_{j\in son_x}ans_j\),其中 \(cnt\) 为刚刚算出的方案(上面代码的)。
代码
#include <iostream>
#include <cstdio>
#include <stdlib.h>
#include <cstring>
#include <algorithm>
#include <vector>
#define int long long
#define N 200005
using namespace std;
int n,fa[N],a[N],val[N],rk[N],len[N],sz[N],son[N],dep[N],ans[N],res,Son,cnt,sum;
vector<int> g[N];
void dfs0(int cur) {
dep[cur] = dep[fa[cur]] + 1;
sz[cur] = 1;
for (auto i : g[cur])
if (i != fa[cur]) {
dfs0(i);
sz[cur] += sz[i];
if (sz[son[cur]] < sz[i]) son[cur] = i;
}
}
void getans(int p) {
int lenl = len[p - 1],lenr = len[p + 1];
int length = lenl + lenr + 1;
len[p - len[p - 1]] = len[p + len[p + 1]] = length;
cnt += length * (length - 1) / 2 - lenl * (lenl - 1) / 2 - lenr * (lenr - 1) / 2;
}
void gettree(int cur) {
getans(rk[cur]);
for (auto i : g[cur])
if (i != fa[cur] && i != Son)
gettree(i);
}
void clear(int cur) {
len[rk[cur]] = 0;
for (auto i : g[cur])
if (i != fa[cur]) clear(i);
}
void dfs1(int cur,bool opt) {
for (auto i : g[cur])
if (i != fa[cur] && i != son[cur])
dfs1(i,1);
if (son[cur]) dfs1(son[cur],0),Son = son[cur];
gettree(cur),Son = 0;
int now = (ans[cur] = cnt);
for (auto i : g[cur]) now -= ans[i];
res += now * val[cur];
if (opt) clear(cur),cnt = 0;
}
signed main() {
cin >> n;
for (int i = 2;i <= n;i ++) cin >> fa[i],g[fa[i]].push_back(i);
for (int i = 1;i <= n;i ++) cin >> a[i],rk[a[i]] = i;
for (int i = 1;i <= n;i ++) cin >> val[i],sum += val[i];
dfs0(1),dfs1(1,0);
printf("%lld\n",res + sum);
return 0;
}

浙公网安备 33010602011771号