把博客园图标替换成自己的图标
把博客园图标替换成自己的图标end

射手座之日 题目分析

射手座之日 题目分析

题目概述LuoguU95602

给一个 \(1\)\(n\) 的排列 \(a\),并且给出点权 \(x_i\),并定义:

\[LCA\{b\}=lca(lca(\dots(b_1,b_2),\dots),b_m) \]

其中 \(lca(x,y)\) 表示 \(x\)\(y\) 的最近公共祖先。

并且给出一颗树。

求:

\[ans=\sum_{i=1}^n\sum_{j=i+1}^nx_{LCA_{i\in[i,j]}\{a_i\}} \]

分析

像这种比较经典的双 sigma 题目,最最最最暴力的解法是 \(\mathcal{O}(n^3)\)(先不考虑这里求 \(LCA\)\(\log\))。

那么很显然,我们可以固定最短点 \(i\)\(j\) 不断地向右扩展,这样就会得到 \(\mathcal{O}(n^2)\) 算法。

于是就很简单地拿到了此题的 \(40\) 分。

那么怎么优化到 \(\mathcal{O}(n\log n)\) 呢?

我一般的思路是直接上线段树。

我们这颗线段树(显然维护的是 \(dfs\) 序区间)维护两个值,一个是 \(cnt\) 代表这段区间内有点作为 \(lca\) 的总方案,\(sum\) 就是加和实际的数量 \(\times x_{lca}\),注意到这里只有在有可能作为 \(lca\) 的点上相乘,pushup 的时候都是加和(这里的思路比较巧妙)。

我们先假设有了一些区间的左端点,然后右端点往扩展(假设从 \(i-1\)\(i\))。

那么就是这样的:

|-----------|--->|
  |---------|--->|
  	|-------|--->|
  		|---|--->|
  		  i - 1->i

首先对于之前的所有区间,我都是得到了各自中的 \(lca\) 并存储到了线段树的结点上面,那么很显然我们每次扩展一次,就计算一次答案——\(tr[1].sum\)

然后我们考虑新的贡献:假设 \(p=lca(a_{i-1},a_i).\)

我们发现如果之前有些区间的 \(lca\)(此处假设为 \(p_2\))满足 \(p\)\(p_2\)\(1\) 的路径上面,是不是说明我这些贡献(按道理来说是方案 \(cnt\))是不是得删除并且挪到当前 \(lca\) 也就是 \(p\) 上面。

换个角度想,是不是满足 \(p_2\)\(p\)(包括 \(p\))子树以内的结点就可以挪。

最后,我们把单独一个点 \(a_i\) 的贡献加上就可以了。

代码

#include <iostream>
#include <cstdio>
#include <algorithm>
#include <stdlib.h>
#include <cstring>
#include <vector>
#define N 200005
#define int long long
#define isdigit(ch) ('0' <= ch && ch <= '9')
using namespace std;
template<typename T>
void read(T &x) {
	x = 0;
	int f = 1;
	char ch = getchar();
	for (;!isdigit(ch);ch = getchar()) f = (ch == '-' ? -1 : f);
	for (;isdigit(ch);ch = getchar()) x = (x << 3) + (x << 1) + (ch ^ 48);
	x *= f;
}
template<typename T>
void write(T x) {
	if (x < 0) x = -x,putchar('-');
	if (x > 9) write(x / 10);
	putchar(x % 10 + '0');
}
vector<int> g[N];
int n,fa[N][25],a[N],val[N],dep[N],dfn[N],cnt,st[N][25],sz[N],rid[N];
void dfs0(int cur) {
	dfn[cur] = ++cnt;
	rid[cnt] = cur;
	sz[cur] = 1;
	st[cnt][0] = fa[cur][0];
	dep[cur] = dep[fa[cur][0]] + 1;
	for (auto i : g[cur])
		if (i != fa[cur][0])
			dfs0(i),sz[cur] += sz[i];
}
// ----------------------------- 倍增求LCA ------------------------------------ 
int LCA(int x,int y) {
	if (dep[x] < dep[y]) x ^= y ^= x ^= y;
	for (int j = 20;j >= 0;j --)
		if (dep[fa[x][j]] >= dep[y]) x = fa[x][j];
	if (x == y) return x;
	for (int j = 20;j >= 0;j --)
		if (fa[x][j] != fa[y][j]) x = fa[x][j],y = fa[y][j];
	return fa[x][0]; 
}
// ----------------------------- dfs序求LCA ---------------------------------- 
int GET(int x) {
	int len = 0,p = 0;
	for (;x;x >>= 1,len ++) p = (x & 1 ? len : p);
	return p;
}
int get(int x,int y) {
	return dfn[x] < dfn[y] ? x : y;
}
int getlca(int x,int y) {
	if (x == y) return x;
	if ((x = dfn[x]) > (y = dfn[y])) x ^= y ^= x ^= y;
	int t = GET(y - x);
	return get(st[x + 1][t],st[y - (1 << t) + 1][t]);
} 
// ----------------------------- segment tree ---------------------------------
#define ls(x) (x << 1)
#define rs(x) (x << 1 | 1)
struct node{
	int cnt;
	int sum;
}tr[N << 2];
int lz[N << 2];
void pushup(int x) {
	tr[x].cnt = tr[ls(x)].cnt + tr[rs(x)].cnt;
	tr[x].sum = tr[ls(x)].sum + tr[rs(x)].sum;
}
void pushdown(int x) {
	lz[ls(x)] = lz[rs(x)] = -1;
	tr[ls(x)] = tr[rs(x)] = {0,0};
	lz[x] = 0;
}
void update(int x,int l,int r,int pos,int value) {
	if (l == r) {
		tr[x].cnt += value;
		tr[x].sum += value * val[rid[l]];
		return;
	}
	if (lz[x] == -1) pushdown(x);
	int mid = l + r >> 1;
	if (pos <= mid) update(ls(x),l,mid,pos,value);
	else update(rs(x),mid + 1,r,pos,value);
	pushup(x);
}
int query(int x,int l,int r,int L,int R) {
	if (l > R || r < L) return 0;
	if (L <= l && r <= R) {
		int p = tr[x].cnt;
		tr[x].cnt = 0,tr[x].sum = 0;
		lz[x] = -1;
		return p;
	}
	if (lz[x] == -1) pushdown(x);
	int mid = l + r >> 1,ans = query(ls(x),l,mid,L,R) + query(rs(x),mid + 1,r,L,R);
	pushup(x);
	return ans;
}
signed main(){
	read(n);
	for (int i = 2;i <= n;i ++) read(fa[i][0]),g[fa[i][0]].push_back(i);
	for (int i = 1;i <= n;i ++) read(a[i]);
	for (int i = 1;i <= n;i ++) read(val[i]);
	dfs0(1);
	for (int j = 1;j <= 20;j ++)
		for (int i = 1;i <= n;i ++)
			fa[i][j] = fa[fa[i][j - 1]][j - 1];
	for (int j = 1;j <= 20;j ++)
		for (int i = 1;i + (1 << j) - 1 <= n;i ++)
			st[i][j] = get(st[i][j - 1],st[i + (1 << j - 1)][j - 1]);
	int ans = 0;
	for (int i = 1;i <= n;i ++) {
		if (i > 1) {
			int t = getlca(a[i],a[i - 1]);
			int tot = query(1,1,n,dfn[t],dfn[t] + sz[t] - 1);
			update(1,1,n,dfn[t],tot);
		}
		update(1,1,n,dfn[a[i]],1);
		ans += tr[1].sum;
	}
	write(ans); 
	return 0;
} 

扩展分析——常数太大?

数据范围其实给了我们提示:

对于另外20%的数据,排列 ai 是用如下的算法生成的:从一号点始对树做 dfs,到达一个节点的时候输出这个节点。

此时我们分析道:任意一段 \(a\) 相对于 \(dfs\) 序是一个连续的区间。

启发我们用区间合并的思路。

我们首先得到了一个比较显而易得的结论:

对于一些 \(lca\) 在当前结点 \(i\) 的子树中是由一段又一段的 \(a\) 组成的。

然后不难得出:

\(rk_{a_i}=i\),只要选的 \(a\) 数值是连续的,那么 \(rk\) 相对应的部分也是连续的。

中国有句古话:麻雀虽小,五脏俱全

别看这个小小结论,却能引出这道题的另一个算法。

考虑 \(dfs\) 这整棵树,但时间太大,不可接受。

但既然必须要 \(dfs\) 了,那就用启发式合并。

期望时间复杂度为 \(\mathcal{O}(n\log n).\)

如何计算结点 \(i\) 的子树中一些点作为 \(lca\) 方案总和呢?

考虑这样的树:

就是一个i和一大堆子树

我们设 \(len_i\) 表示现在 \(i\) 作为某颗子树的一段连续区间(指在 \(a\) 上)的左端点或者右端点所得到的长度。

如果对应现在的 \(p = i\) 为根的子树:\(p\) 必选,显然只有左边和右边的情况,即 \(len_{p-1}\)\(len_{p+1}\)

所得到了总长 \(length=len_{p-1}+len_{p+1}+1.\)

那么总方案是 \(length(length-1)\div 2.\)

考虑到不小心把左右边单独的方案也算进去了,所以减去,因此得到下面的代码:

void getans(int p) {
	int lenl = len[p - 1],lenr = len[p + 1];
	int length = lenl + lenr + 1;
	len[p - len[p - 1]] = len[p + len[p + 1]] = length;
	cnt += length * (length - 1) / 2 - lenl * (lenl - 1) / 2 - lenr * (lenr - 1) / 2;
}

最后算答案。

\(ans_x\) 表示以 \(x\) 为根的子树(中一些 \(lca\))所有方案。

那我们单独算 \(x\) 作为 \(lca\) 的方案就为 \(cnt-\sum_{j\in son_x}ans_j\),其中 \(cnt\) 为刚刚算出的方案(上面代码的)。

代码

#include <iostream>
#include <cstdio>
#include <stdlib.h>
#include <cstring>
#include <algorithm>
#include <vector>
#define int long long
#define N 200005
using namespace std;
int n,fa[N],a[N],val[N],rk[N],len[N],sz[N],son[N],dep[N],ans[N],res,Son,cnt,sum;
vector<int> g[N];
void dfs0(int cur) {
	dep[cur] = dep[fa[cur]] + 1;
	sz[cur] = 1;
	for (auto i : g[cur])
		if (i != fa[cur]) {
			dfs0(i);
			sz[cur] += sz[i];
			if (sz[son[cur]] < sz[i]) son[cur] = i;
		}
}
void getans(int p) {
	int lenl = len[p - 1],lenr = len[p + 1];
	int length = lenl + lenr + 1;
	len[p - len[p - 1]] = len[p + len[p + 1]] = length;
	cnt += length * (length - 1) / 2 - lenl * (lenl - 1) / 2 - lenr * (lenr - 1) / 2;
}
void gettree(int cur) {
	getans(rk[cur]);
	for (auto i : g[cur])
		if (i != fa[cur] && i != Son)
			gettree(i); 
}
void clear(int cur) {
	len[rk[cur]] = 0;
	for (auto i : g[cur])
		if (i != fa[cur]) clear(i);
}
void dfs1(int cur,bool opt) {
	for (auto i : g[cur])
		if (i != fa[cur] && i != son[cur])
			dfs1(i,1);
	if (son[cur]) dfs1(son[cur],0),Son = son[cur];
	gettree(cur),Son = 0;
	int now = (ans[cur] = cnt);
	for (auto i : g[cur]) now -= ans[i];
	res += now * val[cur];
	if (opt) clear(cur),cnt = 0;
}
signed main() {
	cin >> n;
	for (int i = 2;i <= n;i ++) cin >> fa[i],g[fa[i]].push_back(i);
	for (int i = 1;i <= n;i ++) cin >> a[i],rk[a[i]] = i;
	for (int i = 1;i <= n;i ++) cin >> val[i],sum += val[i];
	dfs0(1),dfs1(1,0);
	printf("%lld\n",res + sum);
	return 0;
}
posted @ 2025-02-11 11:18  high_skyy  阅读(13)  评论(0)    收藏  举报
浏览器标题切换
浏览器标题切换end