Jan. 9th 2026

礼物
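Solution 1 — offline: split every query [l, r] into two prefix queries (bounds l−1 and r), sort them by bound, and point-add node wealth into a heavy-light-decomposition segment tree in increasing order of wealth; the answer is prefix(r) − prefix(l−1).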
View Code
#include <bits/stdc++.h>
using namespace std;
// Short aliases for the numeric types used throughout Solution 1.
typedef double lf;
typedef long long ll;
typedef unsigned long long ull;
typedef pair<int, int> pii;
// Capacity of every per-node / per-query array (1e5 plus slack).
const int maxn = 100005;
// NOTE(review): not referenced by Solution 1; node wealth lives in gift::input.
int wealth[maxn];
// Problem parameters: n nodes, m queries; the tree is rooted at node 1.
int n, m;
int root = 1;
int p;  // NOTE(review): appears unused
namespace gift {
// Per-node record: the node's wealth plus its original 1-based index, kept so
// the node can still be located after the array is sorted by wealth.
struct edge {
    int wealth;
    int id;
} input[maxn];

// One half of an offline query: path endpoints, an inclusive wealth bound,
// the slot it occupied before sorting, and the path sum computed for it.
// Field order matters: solve() fills these with brace initialisers.
struct query {
    int start, end;
    int inf;
    int id;
    ll ans;
} output[maxn << 1];

// Ascending by node wealth.
bool cmp_e(const edge &lhs, const edge &rhs) { return lhs.wealth < rhs.wealth; }

// Ascending by wealth bound.
bool cmp_q_w(const query &lhs, const query &rhs) { return lhs.inf < rhs.inf; }

// Ascending by original slot, undoing the cmp_q_w sort.
bool cmp_q_id(const query &lhs, const query &rhs) { return lhs.id < rhs.id; }
}  // namespace gift
namespace TREE_CHAIN {
vector<int> adj[maxn];
int depth[maxn], father[maxn], size[maxn];
int dfn[maxn], out[maxn], dfs[maxn], cnt;
int heavy_son[maxn], top[maxn]; // tree chain
struct segment_tree{
ll tree[maxn << 2];
ll lazy_tag[maxn << 2];
int ls(int id) {return id << 1;}
int rs(int id) {return id << 1 | 1;}
void maintain(int id) {
return tree[id] = tree[ls(id)] + tree[rs(id)], void();
}
void build(int id, int left, int right) {
lazy_tag[id] = 0;
if (left == right) {
tree[id] = gift::input[dfs[left]].wealth;
return ;
}
int mid = (left + right) >> 1;
build(ls(id), left, mid);
build(rs(id), mid + 1, right);
maintain(id);
}
void addtag(ll d, int id, int left, int right) {
lazy_tag[id] += d;
tree[id] = tree[id] + d * (right - left + 1);
}
void pushdown(int id, int left, int right) {
if (lazy_tag[id]) {
ll mid = (left + right) >> 1;
addtag(lazy_tag[id], ls(id), left, mid);
addtag(lazy_tag[id], rs(id), mid + 1, right);
lazy_tag[id] = 0;
}
}
void update(int L, int R, ll d, int id = 1, int left = 1, int right = n) {
if (L <= left and right <= R) {
addtag(d, id, left, right);
return;
}
pushdown(id, left, right);
ll mid = (left + right) >> 1;
if (L <= mid) update(L, R, d, ls(id), left, mid);
if (R > mid) update(L, R, d, rs(id), mid + 1, right);
maintain(id);
}
ll query(int L, int R, int id = 1, int left = 1, int right = n) {
if (L <= left and right <= R)
return tree[id];
pushdown(id, left, right);
ll mid = (left + right) >> 1;
ll res = 0;
if (L <= mid) res = res + query(L, R, ls(id), left, mid);
if (R > mid) res = res + query(L, R, rs(id), mid + 1, right);
return res;
}
} st;
void all_init() {
st.build(1, 1, n);
}
void init(int u, int fath, int dep) {
size[u] = 1;
father[u] = fath;
depth[u] = dep;
for (int v : adj[u]) {
if (v == fath) continue;
init(v, u, dep + 1);
size[u] += size[v];
if (size[v] > size[heavy_son[u]]) heavy_son[u] = v;
}
}
void another_init(int u, int fath, int head) {
top[u] = head;
dfn[u] = ++cnt;
dfs[cnt] = u;
if (heavy_son[u]) another_init(heavy_son[u], u, head);
for (int v : adj[u]) {
if (v == fath or v == heavy_son[u]) continue;
another_init(v, u, v);
}
out[u] = cnt;
}
void chain_update(int x, int y, ll delta) {
while (top[x] != top[y]) {
if (depth[top[x]] < depth[top[y]]) swap(x, y);
st.update(dfn[top[x]], dfn[x], delta);// segment tree
x = father[top[x]];
}
if (depth[x] > depth[y]) swap(x, y);
st.update(dfn[x], dfn[y], delta);
}
ll chain_query(int x, int y) {
ll res = 0;
while (top[x] != top[y]) {
if (depth[top[x]] < depth[top[y]]) swap(x, y);
res = res + st.query(dfn[top[x]], dfn[x]);// segment tree
x = father[top[x]];
}
if (depth[x] > depth[y]) swap(x, y);
res = res + st.query(dfn[x], dfn[y]);
return res;
}
void reset_all() {
cnt = 0;
memset(depth, 0, sizeof(depth));
memset(father, 0, sizeof(father));
memset(size, 0, sizeof(size));
memset(dfn, 0, sizeof(dfn));
memset(out, 0, sizeof(out));
memset(dfs, 0, sizeof(dfs));
memset(heavy_son, 0, sizeof(heavy_son));
memset(top, 0, sizeof(top));
for (int i = 0; i < maxn; i++) adj[i].clear();
}
}
using namespace gift;
using namespace TREE_CHAIN;
void solve() {
// cin >> n >> m;
reset_all();
for (int i = 1; i <= n; i++) {
cin >> input[i].wealth;
input[i].id = i;
}
for (int i = 1; i <= n - 1; i++) {
int u, v;
cin >> u >> v;
adj[u].push_back(v);
adj[v].push_back(u);
}
init(root, 0, 1);
another_init(root, 0, root);
all_init();
for (int i = 1; i <= m; i++) {
int s, t, l, r;
cin >> s >> t >> l >> r;
output[2 * i - 1] = {s, t, l - 1, 2 * i - 1, 0};
output[2 * i] = {s, t, r, 2 * i, 0};
}
sort(input + 1, input + n + 1, cmp_e);
sort(output + 1, output + 2 * m + 1, cmp_q_w);
int j = 1;
for (int i = 1; i <= 2 * m; i++) {
while(input[j].wealth <= output[i].inf and j <= n){
chain_update(input[j].id, input[j].id, input[j].wealth);
j++;
}
output[i].ans = chain_query(output[i].start, output[i].end);
}
sort(output + 1, output + 2 * m + 1, cmp_q_id);
for (int i = 1; i <= m; i++) cout << output[2 * i].ans - output[2 * i - 1].ans << " ";
cout << "\n";
}
int main() {
ios::sync_with_stdio(false);
cin.tie(nullptr);
cout.tie(nullptr);
// freopen("t1.in","r",stdin);
// freopen("t1.out","w",stdout);
int T = 1;
// cin >> T;
while (cin >> n >> m) solve();
return 0;
}
Solution 2
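维护数组 — each segment-tree node keeps the wealth values of its range in sorted order together with their prefix sums (a merge-sort tree), so a value-range sum is two binary searches.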
View Code
#include<bits/stdc++.h>
using namespace std;
typedef long long ll;     // 64-bit type for path sums
const int maxn = 100005;  // capacity of every per-node array (1e5 plus slack)
int wealth[maxn];         // wealth[u]: value on node u (1-based)
int n, m;                 // node and query counts of the current test case
int root = 1;             // the tree is rooted at node 1
namespace TREE_CHAIN {
vector<int> adj[maxn];
int depth[maxn], father[maxn], size[maxn];
int dfn[maxn], out[maxn], dfs[maxn], cnt;
int heavy_son[maxn], top[maxn];
struct sgt_node {
vector<int> array;
vector<ll> prefixsum;
sgt_node() = default;
sgt_node(vector<int> arr, vector<ll> pre) : array(arr), prefixsum(pre) {}
ll ans(int st, int ed) {
if (array.empty() || ed < st) return 0;
int left = lower_bound(array.begin(), array.end(), st) - array.begin();
int right = upper_bound(array.begin(), array.end(), ed) - array.begin();
ll sum_left = (left == 0) ? 0 : prefixsum[left - 1];
ll sum_right = (right == 0) ? 0 : prefixsum[right - 1];
return sum_right - sum_left;
}
};
sgt_node operator+(const sgt_node &a, const sgt_node &b) {
sgt_node res;
res.array.resize(a.array.size() + b.array.size());
merge(a.array.begin(), a.array.end(), b.array.begin(), b.array.end(), res.array.begin());
res.prefixsum.resize(res.array.size(), 0);
if (!res.array.empty()) {
res.prefixsum[0] = res.array[0];
for (int i = 1; i < (int)res.prefixsum.size(); i++) {
res.prefixsum[i] = res.prefixsum[i - 1] + res.array[i];
}
}
return res;
}
struct segment_tree {
sgt_node tree[maxn << 2];
int ls(int id) { return id << 1; }
int rs(int id) { return id << 1 | 1; }
void maintain(int id) {
tree[id] = tree[ls(id)] + tree[rs(id)];
}
void build(int id, int left, int right) {
if (left == right) {
int val = wealth[dfs[left]];
tree[id].array = {val};
tree[id].prefixsum = {val};
return;
}
int mid = (left + right) >> 1;
build(ls(id), left, mid);
build(rs(id), mid + 1, right);
maintain(id);
}
sgt_node query(int L, int R, int id = 1, int left = 1, int right = n) {
if (L <= left && right <= R)
return tree[id];
int mid = (left + right) >> 1;
sgt_node res;
if (L <= mid) res = res + query(L, R, ls(id), left, mid);
if (R > mid) res = res + query(L, R, rs(id), mid + 1, right);
return res;
}
} st;
void all_init() {
st.build(1, 1, n);
}
void init(int u, int fath, int dep) {
size[u] = 1;
father[u] = fath;
depth[u] = dep;
heavy_son[u] = 0;
for (int v : adj[u]) {
if (v == fath) continue;
init(v, u, dep + 1);
size[u] += size[v];
if (size[v] > size[heavy_son[u]])
heavy_son[u] = v;
}
}
void another_init(int u, int fath, int head) {
top[u] = head;
cnt++;
dfn[u] = cnt;
dfs[cnt] = u;
if (heavy_son[u]) another_init(heavy_son[u], u, head);
for (int v : adj[u]) {
if (v == fath || v == heavy_son[u]) continue;
another_init(v, u, v);
}
out[u] = cnt;
}
ll chain_query(int x, int y, int start, int end) {
ll res = 0;
while (top[x] != top[y]) {
if (depth[top[x]] < depth[top[y]]) swap(x, y);
res += st.query(dfn[top[x]], dfn[x]).ans(start, end);
x = father[top[x]];
}
if (depth[x] > depth[y]) swap(x, y);
res += st.query(dfn[x], dfn[y]).ans(start, end);
return res;
}
}
using namespace TREE_CHAIN;
// Handles one test case (main has already read n and m). For each query
// (s, t, a, b) prints the total wealth of nodes on the path s..t whose wealth
// lies in [a, b], space-separated on one line.
void solve() {
    for (int i = 1; i <= n; i++)
        cin >> wealth[i];
    // main() calls solve() once per test case, so the adjacency lists still hold
    // the previous case's edges; clear them before reading this tree. Assuming
    // node labels are 1..n, only those nodes are reachable from the root, so
    // clearing that range is sufficient.
    for (int i = 1; i <= n; i++)
        adj[i].clear();
    for (int i = 1; i < n; i++) {
        int u, v;
        cin >> u >> v;
        adj[u].push_back(v);
        adj[v].push_back(u);
    }
    cnt = 0;
    init(root, 0, 1);
    another_init(root, 0, root);
    all_init();
    for (int i = 1; i <= m; i++) {
        int s, t, a, b;
        cin >> s >> t >> a >> b;
        cout << chain_query(s, t, a, b) << " ";
    }
    cout << "\n";
}
int main() {
ios::sync_with_stdio(false);
cin.tie(nullptr);
cout.tie(nullptr);
int T = 1;
while (cin >> n >> m) solve();
return 0;
}

浙公网安备 33010602011771号