Jan. 9th 2026

礼物

Solution 1

View Code
#include <bits/stdc++.h>
using namespace std;

// Short type aliases. Only `ll` is used by the Solution 1 code shown below;
// the others are kept as declared.
using lf = double;
using ll = long long;
using ull = unsigned long long;
using pii = pair<int, int>;

// Capacity of every per-vertex array (1-based indexing, a few spare slots).
const int maxn = 1e5 + 5;

// NOTE(review): this global array is not referenced by the Solution 1 code
// shown here (wealth values are stored in gift::input instead).
int wealth[maxn];
// n, m: the two integers main() reads before each call to solve();
// root: vertex the tree is rooted at; p: not referenced in the code shown here.
int n, m, root = 1, p; // about the question

namespace gift {
	// A wealth value paired with the 1-based index it was read at, so the
	// owning vertex is still known after input[] has been sorted by wealth.
	struct edge {
		int wealth;
		int id;
	};
	edge input[maxn];

	// One threshold query over a tree path.
	//   start, end : the two path endpoints
	//   inf        : inclusive upper bound on wealth for this half-query
	//   id         : slot the record occupied before sorting
	//   ans        : value stored by solve() for this half-query
	// Member order matters: solve() fills these with aggregate initialisation.
	struct query {
		int start;
		int end;
		int inf;
		int id;
		ll ans;
	};
	query output[maxn << 1];

	// Orders wealth records by ascending wealth.
	bool cmp_e(const edge &a, const edge &b) {
		return b.wealth > a.wealth;
	}

	// Orders queries by ascending threshold.
	bool cmp_q_w(const query &a, const query &b) {
		return b.inf > a.inf;
	}

	// Restores the order in which the queries were created.
	bool cmp_q_id(const query &a, const query &b) {
		return b.id > a.id;
	}
}

namespace TREE_CHAIN {
	
	vector<int> adj[maxn];
	int depth[maxn], father[maxn], size[maxn];
	
	int dfn[maxn], out[maxn], dfs[maxn], cnt;
	
	int heavy_son[maxn], top[maxn]; // tree chain
	
	struct segment_tree{
		ll tree[maxn << 2];
		ll lazy_tag[maxn << 2];
		
		int ls(int id) {return id << 1;}
		int rs(int id) {return id << 1 | 1;}
		
		void maintain(int id) {
			return tree[id] = tree[ls(id)] + tree[rs(id)], void();
		}
		
		void build(int id, int left, int right) {
			lazy_tag[id] = 0;
			if (left == right) {
				tree[id] = gift::input[dfs[left]].wealth;
				return ;
			}
			int mid = (left + right) >> 1;
			build(ls(id), left, mid);
			build(rs(id), mid + 1, right);
			maintain(id);
		}
		
		void addtag(ll d, int id, int left, int right) {
			lazy_tag[id] += d;
			tree[id] = tree[id] + d * (right - left + 1);
		}
		
		void pushdown(int id, int left, int right) {
			if (lazy_tag[id]) {
				ll mid = (left + right) >> 1;
				addtag(lazy_tag[id], ls(id), left, mid);
				addtag(lazy_tag[id], rs(id), mid + 1, right);
				lazy_tag[id] = 0;
			}
		}
		
		void update(int L, int R, ll d, int id = 1, int left = 1, int right = n) {
			if (L <= left and right <= R) {
				addtag(d, id, left, right);
				return;
			}
			pushdown(id, left, right);
			ll mid = (left + right) >> 1;
			if (L <= mid) update(L, R, d, ls(id), left, mid);
			if (R > mid) update(L, R, d, rs(id), mid + 1, right);
			maintain(id);
		}
		
		ll query(int L, int R, int id = 1, int left = 1, int right = n) {
			if (L <= left and right <= R)
				return tree[id];
			pushdown(id, left, right);
			ll mid = (left + right) >> 1;
			ll res = 0;
			if (L <= mid) res = res + query(L, R, ls(id), left, mid);
			if (R > mid) res = res + query(L, R, rs(id), mid + 1, right);
			return res;
		}
		
	} st;
	
	void all_init() {
		st.build(1, 1, n);
	}
	
	void init(int u, int fath, int dep) {
		size[u] = 1;
		father[u] = fath;
		depth[u] = dep;
		for (int v : adj[u]) {
			if (v == fath) continue;
			init(v, u, dep + 1);
			size[u] += size[v];
			if (size[v] > size[heavy_son[u]]) heavy_son[u] = v;
		}
	}
	
	void another_init(int u, int fath, int head) {
		top[u] = head;
		dfn[u] = ++cnt;
		dfs[cnt] = u;
		if (heavy_son[u]) another_init(heavy_son[u], u, head);
		for (int v : adj[u]) {
			if (v == fath or v == heavy_son[u]) continue;
			another_init(v, u, v);
		}
		out[u] = cnt;
	}
	
	void chain_update(int x, int y, ll delta) {
		while (top[x] != top[y]) {
			if (depth[top[x]] < depth[top[y]]) swap(x, y);
			st.update(dfn[top[x]], dfn[x], delta);// segment tree
			x = father[top[x]];
		}
		
		if (depth[x] > depth[y]) swap(x, y);
		st.update(dfn[x], dfn[y], delta);
	}
	
	ll chain_query(int x, int y) {
		ll res = 0;
		while (top[x] != top[y]) {
			if (depth[top[x]] < depth[top[y]]) swap(x, y);
			res = res + st.query(dfn[top[x]], dfn[x]);// segment tree
			x = father[top[x]];
		}
		
		if (depth[x] > depth[y]) swap(x, y);
		res = res + st.query(dfn[x], dfn[y]);
		return res;
	}
	
	void reset_all() {
		cnt = 0;
		memset(depth, 0, sizeof(depth));
		memset(father, 0, sizeof(father));
		memset(size, 0, sizeof(size));
		memset(dfn, 0, sizeof(dfn));
		memset(out, 0, sizeof(out));
		memset(dfs, 0, sizeof(dfs));
		memset(heavy_son, 0, sizeof(heavy_son));
		memset(top, 0, sizeof(top));
		for (int i = 0; i < maxn; i++) adj[i].clear();
	}
}

using namespace gift;
using namespace TREE_CHAIN;

// Handles one test case; n and m have already been read by main().
//
// Input: n wealth values (one per vertex), n - 1 tree edges, then m queries
// "s t l r". For every query the program prints the sum of wealth over the
// vertices on the path s..t whose wealth lies in [l, r].
//
// The queries are answered offline: each one is split into the two
// thresholds l - 1 and r, thresholds are processed in ascending order while
// vertices are added in ascending wealth, and the printed value is the
// difference of the two stored sums.
void solve() {
	reset_all();
	
	// Wealth of vertex i; id keeps the vertex once input[] gets sorted.
	for (int i = 1; i <= n; i++) {
		cin >> input[i].wealth;
		input[i].id = i;
	}
	
	for (int i = 1; i <= n - 1; i++) {
		int u, v;
		cin >> u >> v;
		adj[u].push_back(v);
		adj[v].push_back(u);
	}
	
	init(root, 0, 1);
	another_init(root, 0, root);
	// The tree is built with every vertex's wealth already in place, so each
	// stored sum below carries the full path total; it cancels in the
	// subtraction at the end.
	all_init();
	
	// Half-query 2i-1 uses threshold l-1, half-query 2i uses threshold r.
	for (int i = 1; i <= m; i++) {
		int s, t, l, r;
		cin >> s >> t >> l >> r;
		output[2 * i - 1] = {s, t, l - 1, 2 * i - 1, 0};
		output[2 * i] = {s, t, r, 2 * i, 0};
	}
	
	sort(input + 1, input + n + 1, cmp_e);
	sort(output + 1, output + 2 * m + 1, cmp_q_w);
	
	int j = 1;
	for (int i = 1; i <= 2 * m; i++) {
		// Check the bound first so input[j] is never read beyond the n
		// entries that belong to this test case.
		while (j <= n and input[j].wealth <= output[i].inf) {
			chain_update(input[j].id, input[j].id, input[j].wealth);
			j++;
		}
		output[i].ans = chain_query(output[i].start, output[i].end);
	}
	
	// Back to creation order so that slots 2i-1 and 2i belong to query i.
	sort(output + 1, output + 2 * m + 1, cmp_q_id);
	
	for (int i = 1; i <= m; i++) cout << output[2 * i].ans - output[2 * i - 1].ans << " ";
	cout << "\n";
	
}

int main() {
	
	ios::sync_with_stdio(false);
	cin.tie(nullptr);
	cout.tie(nullptr);
	
//	freopen("t1.in","r",stdin);
//	freopen("t1.out","w",stdout);
	
	int T = 1;
//	cin >> T;
	while (cin >> n >> m) solve();
	
	return 0;
}

Solution 2

维护数组

View Code
#include<bits/stdc++.h>
using namespace std;

// 64-bit integer used for all sums.
using ll = long long;
// Capacity of every per-vertex array (1-based indexing, a few spare slots).
const int maxn = 1e5 + 5;

// wealth[i]: value read for vertex i in solve().
int wealth[maxn];
// n, m: the two integers main() reads before each call to solve();
// root: vertex the tree is rooted at.
int n, m, root = 1;

namespace TREE_CHAIN {
	// Undirected adjacency lists of the tree, indexed by vertex.
	vector<int> adj[maxn];
	// Filled by init(): depth, parent and subtree size of every vertex.
	int depth[maxn], father[maxn], size[maxn];
	// Filled by another_init(): dfn[u] is the position given to u, dfs[i] the
	// vertex at position i, out[u] the value of cnt after u's subtree is done.
	int dfn[maxn], out[maxn], dfs[maxn], cnt;
	// heavy_son[u]: child with the largest subtree (0 if none);
	// top[u]: first vertex of the chain containing u.
	int heavy_son[maxn], top[maxn];
	
	// Payload of one segment-tree node: the values of its interval in
	// ascending order plus their running sums.
	struct sgt_node {
		vector<int> array;      // sorted values
		vector<ll> prefixsum;   // prefixsum[i] = array[0] + ... + array[i]
		sgt_node() = default;
		// Arguments are taken by value and moved in to avoid a second copy.
		sgt_node(vector<int> arr, vector<ll> pre) : array(std::move(arr)), prefixsum(std::move(pre)) {}
		
		// Returns the sum of all stored values v with st <= v <= ed
		// (0 for an empty node or an empty range).
		ll ans(int st, int ed) const {
			if (array.empty() || ed < st) return 0;
			// left: index of the first value >= st; right: one past the last value <= ed.
			int left = lower_bound(array.begin(), array.end(), st) - array.begin();
			int right = upper_bound(array.begin(), array.end(), ed) - array.begin();
			ll sum_left = (left == 0) ? 0 : prefixsum[left - 1];
			ll sum_right = (right == 0) ? 0 : prefixsum[right - 1];
			return sum_right - sum_left;
		}
	};
	
	// Merges two nodes into one sorted node and rebuilds its running sums.
	sgt_node operator+(const sgt_node &a, const sgt_node &b) {
		sgt_node res;
		res.array.resize(a.array.size() + b.array.size());
		merge(a.array.begin(), a.array.end(), b.array.begin(), b.array.end(), res.array.begin());
		res.prefixsum.resize(res.array.size(), 0);
		if (!res.array.empty()) {
			res.prefixsum[0] = res.array[0];
			for (int i = 1; i < (int)res.prefixsum.size(); i++) {
				res.prefixsum[i] = res.prefixsum[i - 1] + res.array[i];
			}
		}
		return res;
	}
	
	// Segment tree over positions 1..n whose nodes keep their interval's
	// values sorted.
	struct segment_tree {
		sgt_node tree[maxn << 2];
		
		int ls(int id) { return id << 1; }
		int rs(int id) { return id << 1 | 1; }
		
		// Rebuilds a node from its two children.
		void maintain(int id) {
			tree[id] = tree[ls(id)] + tree[rs(id)];
		}
		
		// Builds node `id` for [left, right]; the leaf of position i holds
		// wealth[dfs[i]].
		void build(int id, int left, int right) {
			if (left == right) {
				int val = wealth[dfs[left]];
				tree[id].array = {val};
				tree[id].prefixsum = {val};
				return;
			}
			int mid = (left + right) >> 1;
			build(ls(id), left, mid);
			build(rs(id), mid + 1, right);
			maintain(id);
		}
		
		// Returns a node holding all values of positions [L, R], merged.
		// This copies and merges whole vectors, so it is expensive; prefer
		// query_sum() when only a value-range sum is needed.
		sgt_node query(int L, int R, int id = 1, int left = 1, int right = n) {
			if (L <= left && right <= R)
				return tree[id];
			int mid = (left + right) >> 1;
			sgt_node res;
			if (L <= mid) res = res + query(L, R, ls(id), left, mid);
			if (R > mid)  res = res + query(L, R, rs(id), mid + 1, right);
			return res;
		}
		
		// Returns the sum of the values v with lo <= v <= hi among positions
		// [L, R]. Gives the same result as query(L, R).ans(lo, hi), but every
		// fully covered node is answered in place by binary search, without
		// copying or merging any vector.
		ll query_sum(int L, int R, int lo, int hi, int id = 1, int left = 1, int right = n) {
			if (L <= left && right <= R)
				return tree[id].ans(lo, hi);
			int mid = (left + right) >> 1;
			ll res = 0;
			if (L <= mid) res += query_sum(L, R, lo, hi, ls(id), left, mid);
			if (R > mid)  res += query_sum(L, R, lo, hi, rs(id), mid + 1, right);
			return res;
		}
	} st;
	
	// Builds the segment tree over positions 1..n.
	void all_init() {
		st.build(1, 1, n);
	}
	
	// First pass: records parent, depth and subtree size of every vertex below
	// u and picks each vertex's heavy child.
	void init(int u, int fath, int dep) {
		size[u] = 1;
		father[u] = fath;
		depth[u] = dep;
		heavy_son[u] = 0; // reset here so a previous test case cannot leak in
		for (int v : adj[u]) {
			if (v == fath) continue;
			init(v, u, dep + 1);
			size[u] += size[v];
			if (size[v] > size[heavy_son[u]])
				heavy_son[u] = v;
		}
	}
	
	// Second pass: numbers the vertices so that each chain occupies
	// consecutive positions; the heavy child continues the current chain,
	// every other child starts a new one.
	void another_init(int u, int fath, int head) {
		top[u] = head;
		cnt++;
		dfn[u] = cnt;
		dfs[cnt] = u;
		if (heavy_son[u]) another_init(heavy_son[u], u, head);
		for (int v : adj[u]) {
			if (v == fath || v == heavy_son[u]) continue;
			another_init(v, u, v);
		}
		out[u] = cnt;
	}
	
	// Returns the sum of wealth over the vertices on the path x..y whose
	// wealth lies in [start, end].
	ll chain_query(int x, int y, int start, int end) {
		ll res = 0;
		// Lift the endpoint whose chain head is deeper, one chain per step.
		while (top[x] != top[y]) {
			if (depth[top[x]] < depth[top[y]]) swap(x, y);
			res += st.query_sum(dfn[top[x]], dfn[x], start, end);
			x = father[top[x]];
		}
		// Both endpoints now share a chain; the shallower one comes first.
		if (depth[x] > depth[y]) swap(x, y);
		res += st.query_sum(dfn[x], dfn[y], start, end);
		return res;
	}
}

using namespace TREE_CHAIN;

// Handles one test case; n and m have already been read by main().
//
// Input: n wealth values (one per vertex), n - 1 tree edges, then m queries
// "s t a b". For every query the program prints the sum of wealth over the
// vertices on the path s..t whose wealth lies in [a, b].
void solve() {
	// main() may call solve() several times: drop the edges left over from
	// an earlier test case before reading the new tree.
	for (int i = 1; i <= n; i++)
		adj[i].clear();
	
	for (int i = 1; i <= n; i++)
		cin >> wealth[i];
	
	for (int i = 1; i < n; i++) {
		int u, v;
		cin >> u >> v;
		adj[u].push_back(v);
		adj[v].push_back(u);
	}
	
	// Position counter used by another_init(); restart it for this tree.
	cnt = 0;
	
	init(root, 0, 1);
	another_init(root, 0, root);
	all_init();
	
	for (int i = 1; i <= m; i++) {
		int s, t, a, b;
		cin >> s >> t >> a >> b;
		cout << chain_query(s, t, a, b) << " ";
	}
	cout << "\n";
}

int main() {
	ios::sync_with_stdio(false);
	cin.tie(nullptr);
	cout.tie(nullptr);
	
	int T = 1;
	while (cin >> n >> m) solve();
	return 0;
}
posted @ 2026-01-09 21:22  Yangyihao  阅读(2)  评论(0)    收藏  举报