[Non]区间平方和
[Non]区间平方和
大意
需要完成区间修改和区间查询平方和的功能。
思路
首先我们考虑用线段树维护,然后想想 pushup 和 pushdown 怎么写?
定义:
sum, sumq。
显然有 pushup:
t[u].sum = t[lc].sum + t[rc].sum;
t[u].sumq = t[lc].sumq + t[rc].sumq;
然后考虑 pushdown,若是对于一段区间全部加上 \(k\) 会发生什么?
不妨设区间长度为 \(2\),即 \(a, b\),原 \(\text{sumq} = a ^ 2 + b ^ 2\),现在为 \((a + k) ^ 2 + (b + k) ^ 2 = a ^ 2 + b ^ 2 + 2ak + 2bk + 2k^2 = a ^ 2 + b ^ 2 + 2k(a + b) + 2 k^2\)。
而 \(a + b = \text{sum}\),即 \(\text{sum} + k ^ 2 \times \text{len} + 2 \times \text{sum} \times k\) 故我们有 pushdown 如下:
if(t[u].add){
t[lc].sumq += (t[u].add * t[u].add * (t[lc].r - t[lc].l + 1) + 2 * t[lc].sum * t[u].add);
t[rc].sumq += (t[u].add * t[u].add * (t[rc].r - t[rc].l + 1) + 2 * t[rc].sum * t[u].add);
t[lc].sum += t[u].add * (t[lc].r - t[lc].l + 1);
t[rc].sum += t[u].add * (t[rc].r - t[rc].l + 1);
t[lc].add += t[u].add;
t[rc].add += t[u].add;
t[u].add = 0;
}
然后就没了。
代码
#include<iostream>
using namespace std;
#define lc u << 1
#define rc u << 1 | 1
const int MAXN = 1e5 + 5;
struct node{
int l, r;
long long add;
long long sum, sumq;
}t[MAXN * 4];
void pushup(int u){
t[u].sum = t[lc].sum + t[rc].sum;
t[u].sumq = t[lc].sumq + t[rc].sumq;
}
// a ^ 2 + b ^ 2 + c ^ 2
// (a + k) ^ 2 + (b + k) ^ 2 + (c + k) ^ 2
// += k ^ 2 * len + 2 * sum * k
void pushdown(int u){
if(t[u].add){
t[lc].sumq += (t[u].add * t[u].add * (t[lc].r - t[lc].l + 1) + 2 * t[lc].sum * t[u].add);
t[rc].sumq += (t[u].add * t[u].add * (t[rc].r - t[rc].l + 1) + 2 * t[rc].sum * t[u].add);
t[lc].sum += t[u].add * (t[lc].r - t[lc].l + 1);
t[rc].sum += t[u].add * (t[rc].r - t[rc].l + 1);
t[lc].add += t[u].add;
t[rc].add += t[u].add;
t[u].add = 0;
}
}
void build(int u, int l, int r){
t[u] = {l, r, 0, 0, 0};
if(l == r) return;
int mid = (l & r) + ((l ^ r) >> 1);
build(lc, l, mid);
build(rc, mid + 1, r);
pushup(u);
}
void update(int u, int l, int r, int k){
if(l <= t[u].l && t[u].r <= r){
t[u].add += k;
t[u].sumq += (2 * k * t[u].sum + k * k * (t[u].r - t[u].l + 1));
t[u].sum += k * (t[u].r - t[u].l + 1);
return;
}
int mid = (t[u].l & t[u].r) + ((t[u].l ^ t[u].r) >> 1);
pushdown(u);
if(l <= mid){
update(lc, l, r, k);
}
if(r > mid){
update(rc, l, r, k);
}
pushup(u);
}
long long query(int u, int l, int r){
if(l <= t[u].l && t[u].r <= r){
return t[u].sumq;
}
int mid = (t[u].l & t[u].r) + ((t[u].l ^ t[u].r) >> 1);
pushdown(u);
long long ans = 0;
if(l <= mid){
ans += query(lc, l, r);
}
if(r > mid){
ans += query(rc, l, r);
}
return ans;
}
int n, m;
int main(){
ios::sync_with_stdio(0);
cin.tie(0);
cin >> n >> m;
build(1, 1, n);
for(int i = 1;i <= m;i ++){
string op;
int l, r, k;
cin >> op;
if(op == "Update"){
cin >> l >> r >> k;
update(1, l, r, k);
}
else{
cin >> l >> r;
cout << query(1, l, r) << '\n';
}
}
return 0;
}
本文来自一名高中生,作者:To_Carpe_Diem

浙公网安备 33010602011771号