BZOJ 2752 [HAOI2012]高速公路(road):线段树【维护区间内子串和】

题目链接:http://www.lydsy.com/JudgeOnline/problem.php?id=2752

题意:

  有一个初始全为0的,长度为n的序列a。

  有两种操作:

    (1)C l r v: 将[l,r)内的数全部加v。

    (2)Q l r: 在[l,r)内随机选两个数x,y(x < y),问你∑(a[x to y])的期望,用最简分数形式输出。

 

题解:

  首先,题中要求的期望 = 区间内所有子串之和 / 区间内子串个数。

  如果一个区间的长度为len,显然区间内的子串个数为len*(len+1)/2。

  所以题目就变成了怎样维护区间内所有子串之和。

 

  dat表示某个区间的子串和。

  假设有两个相邻区间l,r,合并起来的区间叫x。

  那么dat[x] = dat[x] + dat[y] + 跨两个区间的子串和

 

  所以接下来考虑如何求跨区间的子串和。

  sum表示某个区间的所有元素之和。

  ln表示区间l的长度,rn表示区间r的长度。

  ls表示某个区间的所有所有前缀之和,rs表示某个区间的所有后缀之和。

  则跨区间的子串之和 = rs[l]*rn + ls[r]*ln

  即dat[x] = dat[x] + dat[y] + rs[l]*rn + ls[r]*ln

 

  ls,rs和sum的合并就很好求了:

  ls[x] = ls[l] + rn*sum[l] + ls[r]

  rs[x] = rs[r] + ln*sum[r] + rs[l]

  sum[x] = sum[l] + sum[r]

 

  这样线段树的pushup函数就写完了。

  然后考虑如何pushdown传标记。

 

  tag表示某个区间被同时加了多少。

  现在只考虑当前节点x的某一个儿子y,儿子y的区间长度为len。

 

  首先考虑tag[x]对dat[y]的贡献。

  贡献 = 枚举子串的长度 * 这种长度的子串个数 * tag[x]

  即:dat[y] += ∑ i*(len-i+1)*tag[x],其中i∈[1,len]。

  化简得:dat[y] += ( len*(len+1)/2*(len+1) + ∑(i^2) ) * tag[x]

  对于其中的∑(i^2),事先O(n)预处理出来一个平方前缀和数组sqr即可。

 

  然后易得tag[x]对ls,rs,sum的贡献:

  ls[y] += len*(len+1)/2*tag[x]

  rs[y] += len*(len+1)/2*tag[x]

  sum[y] += len*tag[x]

 

  这样pushdown也就写好了。

  然后大力线段树即可QAQ……

 

AC Code:

  1 #include <iostream>
  2 #include <stdio.h>
  3 #include <string.h>
  4 #include <algorithm>
  5 #define MAX_N 100005
  6 #define MAX_T 400005
  7 #define int ll
  8 
  9 using namespace std;
 10 
 11 typedef long long ll;
 12 
 13 struct Node
 14 {
 15     int dt,ls,rs,s,ln;
 16     Node(int _dt,int _ls,int _rs,int _s,int _ln)
 17     {
 18         dt=_dt; ls=_ls; rs=_rs; s=_s; ln=_ln;
 19     }
 20     Node(){}
 21     friend Node mix(const Node &a,const Node &b)
 22     {
 23         int _dt=a.dt+b.dt+a.rs*b.ln+b.ls*a.ln;
 24         int _ls=a.ls+b.ln*a.s+b.ls;
 25         int _rs=b.rs+a.ln*b.s+a.rs;
 26         int _s=a.s+b.s;
 27         int _ln=a.ln+b.ln;
 28         return Node(_dt,_ls,_rs,_s,_ln);
 29     }
 30 };
 31 
 32 int n,m;
 33 int ls[MAX_T];
 34 int rs[MAX_T];
 35 int dat[MAX_T];
 36 int sum[MAX_T];
 37 int tag[MAX_T];
 38 int sqr[MAX_N];
 39 
 40 void cal_sqr()
 41 {
 42     for(int i=1;i<=n;i++) sqr[i]=sqr[i-1]+i*i;
 43 }
 44 
 45 void push_up(int x,int len)
 46 {
 47     int l=x*2+1,r=x*2+2;
 48     Node L(dat[l],ls[l],rs[l],sum[l],len-(len>>1));
 49     Node R(dat[r],ls[r],rs[r],sum[r],(len>>1));
 50     Node tmp=mix(L,R);
 51     dat[x]=tmp.dt;
 52     ls[x]=tmp.ls;
 53     rs[x]=tmp.rs;
 54     sum[x]=tmp.s;
 55 }
 56 
 57 void push_down(int x,int len)
 58 {
 59     if(tag[x])
 60     {
 61         int l=x*2+1,r=x*2+2;
 62         int ln=(len-(len>>1)),rn=(len>>1);
 63         dat[l]+=(ln*(ln+1)/2*(ln+1)-sqr[ln])*tag[x];
 64         dat[r]+=(rn*(rn+1)/2*(rn+1)-sqr[rn])*tag[x];
 65         ls[l]+=ln*(ln+1)/2*tag[x];
 66         ls[r]+=rn*(rn+1)/2*tag[x];
 67         rs[l]+=ln*(ln+1)/2*tag[x];
 68         rs[r]+=rn*(rn+1)/2*tag[x];
 69         sum[l]+=ln*tag[x];
 70         sum[r]+=rn*tag[x];
 71         tag[l]+=tag[x];
 72         tag[r]+=tag[x];
 73         tag[x]=0;
 74     }
 75 }
 76 
 77 void update(int a,int b,int k,int l,int r,int x)
 78 {
 79     if(a<=l && r<=b)
 80     {
 81         int len=r-l+1;
 82         tag[k]+=x;
 83         sum[k]+=len*x;
 84         ls[k]+=len*(len+1)/2*x;
 85         rs[k]+=len*(len+1)/2*x;
 86         dat[k]+=(len*(len+1)/2*(len+1)-sqr[len])*x;
 87         return;
 88     }
 89     if(r<a || b<l) return;
 90     push_down(k,r-l+1);
 91     int mid=(l+r)>>1;
 92     update(a,b,k*2+1,l,mid,x);
 93     update(a,b,k*2+2,mid+1,r,x);
 94     push_up(k,r-l+1);
 95 }
 96 
 97 Node query(int a,int b,int k,int l,int r)
 98 {
 99     if(a<=l && r<=b) return Node(dat[k],ls[k],rs[k],sum[k],r-l+1);
100     if(r<a || b<l) return Node(0,0,0,0,0);
101     push_down(k,r-l+1);
102     int mid=(l+r)>>1;
103     Node v1=query(a,b,k*2+1,l,mid);
104     Node v2=query(a,b,k*2+2,mid+1,r);
105     return mix(v1,v2);
106 }
107 
108 signed main()
109 {
110     scanf("%lld%lld",&n,&m);
111     n--;
112     cal_sqr();
113     char opt[16];
114     int l,r,v;
115     while(m--)
116     {
117         scanf("%s%lld%lld",opt,&l,&r);
118         if(opt[0]=='C')
119         {
120             scanf("%lld",&v);
121             update(l,r-1,0,1,n,v);
122         }
123         else
124         {
125             int dt=query(l,r-1,0,1,n).dt;
126             int len=r-l;
127             int tot=len*(len+1)/2;
128             int g=__gcd(dt,tot);
129             printf("%lld/%lld\n",dt/g,tot/g);
130         }
131     }
132 }

 

posted @ 2018-03-12 23:26  Leohh  阅读(217)  评论(0编辑  收藏  举报