线段树维护方差
题面
方差
方差2
一道适合巩固带 tag 标记的线段树写法的题。
题解
此题解为 t1 题解,t2更为简单一些,不过可以练一下除法取模。
平均数维护很简单吧,只需要维护区间和即可。
方差公式为:
\[\frac{1}{n} * \sum_{i = l} ^ {r} (a_i - \bar a) ^ 2
\]
这样我们肯定很难维护,于是我们尝试把平方打开,就会发现这道题迎刃而解了:
\[\sum_{i = l} ^ {r} (a_i - \bar a) ^ 2\ = \ (a_{l} ^ 2 + a_{l + 1} ^ 2 + … + a_{r} ^ 2) - 2 * \bar a *(a_l + a_{l + 1} + … + a_r) - (r - l + 1) * \bar a ^ 2
\]
不难发现中间是我们维护的区间和,后面可以直接计算,所以我们就只需要再维护一个区间平方和即可。
维护区间平方和的方法和维护区间和的方法类似,向下传递 tag 标记和更新时,和上述的展开一样,所以每次要更新的时候,一定要先更新平方和的值,不然中间部分的区间和就变了,答案自然也就不对。
最后这道题还有一个小细节,如果想偷偷懒只写一个返回值为结构体的 query 的话,记得结构体里或者声明它的时候要初始化,不然它会被随机赋值,就会全 WA 掉。
代码
#include<cstdio>
const int N = 1e5 + 5;
int n,m;
double a[N];
struct SegmentTree {
#define M N << 2
int l[M],r[M];
double sum[M],pow[M],tag[M];
inline void pushup(int p) {
sum[p] = sum[p << 1] + sum[p << 1 | 1];
pow[p] = pow[p << 1] + pow[p << 1 | 1];
}
void build(int p,int lf,int rg) {
l[p] = lf,r[p] = rg;
if(lf == rg) {
sum[p] = a[lf]; pow[p] = a[lf] * a[lf];
return ;
}
int mid = (lf + rg) >> 1;
build(p << 1,lf,mid);
build(p << 1 | 1,mid + 1,rg);
pushup(p);
}
inline void pushdown(int p) {
if(!tag[p]) return ;
pow[p << 1] += 2.0 * tag[p] * sum[p << 1] + (r[p << 1] - l[p << 1] + 1.0) * tag[p] * tag[p];
pow[p << 1 | 1] += 2.0 * tag[p] * sum[p << 1 | 1] + (r[p << 1 | 1] - l[p << 1 | 1] + 1.0) * tag[p] * tag[p];
sum[p << 1] += (r[p << 1] - l[p << 1] + 1.0) * tag[p];
sum[p << 1 | 1] += (r[p << 1 | 1] - l[p << 1 | 1] + 1.0) * tag[p];
tag[p << 1] += tag[p]; tag[p << 1 | 1] += tag[p]; tag[p] = 0.0;
}
inline void update(int p,int L,int R,double k) {
if(L <= l[p] && r[p] <= R) {
pow[p] += 2.0 * k * sum[p] + (r[p] - l[p] + 1.0) * k * k;
sum[p] += k * (r[p] - l[p] + 1.0); tag[p] += k;
return ;
}
pushdown(p);
int mid = (l[p] + r[p]) >> 1;
if(L <= mid) update(p << 1,L,R,k);
if(R > mid) update(p << 1 | 1,L,R,k);
pushup(p);
}
inline double query_sum(int p,int L,int R) {
if(L <= l[p] && r[p] <= R) return sum[p];
pushdown(p);
int mid = (l[p] + r[p]) >> 1; double res = 0.0;
if(L <= mid) res = query_sum(p << 1,L,R);
if(R > mid) res += query_sum(p << 1 | 1,L,R);
return res;
}
inline double query_pow(int p,int L,int R) {
if(L <= l[p] && r[p] <= R) return pow[p];
pushdown(p);
int mid = (l[p] + r[p]) >> 1; double res = 0.0;
if(L <= mid) res = query_pow(p << 1,L,R);
if(R > mid) res += query_pow(p << 1 | 1,L,R);
return res;
}
}tr;
#undef M
int main() {
scanf("%d%d",&n,&m);
for(int i = 1; i <= n; i++) scanf("%lf",&a[i]);
tr.build(1,1,n);
for(int i = 1; i <= m; i++) {
int opr,x,y; scanf("%d%d%d",&opr,&x,&y);
if(opr == 1) {
double k; scanf("%lf",&k);
tr.update(1,x,y,k);
}
else {
double Sum = tr.query_sum(1,x,y);
double ave = Sum / (y - x + 1.0);
if(opr == 2) printf("%.4lf\n",ave);
else {
double Pow = tr.query_pow(1,x,y) / (y - x + 1.0);
printf("%.4lf\n",(Pow - ave * ave));
}
}
}
return 0;
}
t2代码
把 build 里写错了,调了好久才发现,千万不要犯这种错误。
#include<cstdio>
const int N = 1e5 + 5,mod = 1e9 + 7;
int a[N],n,m;
struct SegmentTree {
#define M N << 2
int l[M],r[M],sum[M],pow[M];
inline void pushup(int p) {
pow[p] = (1ll * pow[p << 1] + pow[p << 1 | 1]) % mod;
sum[p] = (1ll * sum[p << 1] + sum[p << 1 | 1]) % mod;
}
void build(int p,int lf,int rg) {
l[p] = lf; r[p] = rg;
if(lf == rg) {
sum[p] = a[lf] % mod; pow[p] = 1ll * a[lf] * a[lf] % mod;
return ;
}
int mid = (lf + rg) >> 1;
build(p << 1,lf,mid);
build(p << 1 | 1,mid + 1,rg);
pushup(p);
}
void update(int p,int pos,int k) {
if(l[p] == pos && r[p] == pos) {
sum[p] = k; pow[p] = 1ll * k * k % mod;
return ;
}
int mid = (l[p] + r[p]) >> 1;
if(pos <= mid) update(p << 1,pos,k);
else update(p << 1 | 1,pos,k);
pushup(p);
}
int query_sum(int p,int L,int R) {
if(L <= l[p] && r[p] <= R) return sum[p];
int mid = (l[p] + r[p]) >> 1; int res = 0;
if(L <= mid) res = query_sum(p << 1,L,R) % mod;
if(R > mid) res = (1ll * res + query_sum(p << 1 | 1,L,R)) % mod;
return res;
}
int query_pow(int p,int L,int R) {
if(L <= l[p] && r[p] <= R) return pow[p];
int mid = (l[p] + r[p]) >> 1; int res = 0;
if(L <= mid) res = query_pow(p << 1,L,R) % mod;
if(R > mid) res = (1ll * res + query_pow(p << 1 | 1,L,R)) % mod;
return res;
}
}tr;
#undef M
int power(int a,int b = mod - 2,int ans = 1) {
for(; b; b >>= 1,a = 1ll * a * a % mod)
if(b & 1) ans = 1ll * ans * a % mod;
return ans;
}
inline int read() {
int x = 0,flag = 1;
char ch = getchar();
while(ch < '0' || ch > '9'){if(ch == '-')flag = -1;ch = getchar();}
while(ch >='0' && ch <='9'){x = (x << 3) + (x << 1) + ch - 48;ch = getchar();}
return x * flag;
}
int main() {
n = read(),m = read();
for(int i = 1; i <= n; i++) a[i] = read();
tr.build(1,1,n);
for(int i = 1; i <= m; i++) {
int opr = read(),x = read(),y = read();
if(opr == 1) tr.update(1,x,y % mod);
else {
int inv = power(y - x + 1);
int ave = 1ll * tr.query_sum(1,x,y) * inv % mod;
int Pow = 1ll * tr.query_pow(1,x,y) * inv % mod;
printf("%d\n",((1ll * Pow - 1ll * ave * ave % mod) % mod + mod) % mod);
}
}
return 0;
}

浙公网安备 33010602011771号