51nod 1712 区间求和

题意

给你一个长度为 \(n\) 的序列,定义这个序列的权值为 $ \sum_{1 \leq i<j \leq n} a_j - a_i $。
现在给你一个长度为 \(n\) 的序列,当$ a_i=a_j $时,将 $ a_i, a_{i+1}, ... , a_j $ 提出当做一个序列,计算它的权值。统计所有这样的区间的权值和。答案模 \(2^{32}\)

想法

先考虑对于原题的做法,即对于区间 \([i, j]\) ,每一位都要乘一个系数,即算每一位的系数大小即可。对于 \(i \leq k \leq j\) ,系数 \(= -(j-k) + (k-i) = 2k-(i+j)\) 。回到这道题,对于 \(a_k\),有多少个区间 \([i, j](a_i=a_j, i \leq j)\) 包含了 \(a_k\) 就要加多少个 \(2k\),且要减去这些区间的左端点和右端点的值。
由于减去右端点的值相当于将整个序列翻转然后去减左端点的值,故这里我们只考虑左端点。减右端点的情况我们倒着跑一边即可。
首先我们可以预处理出每个数从左往右(\(lrk\))/从右往左(\(rrk\))数是第几次出现,它的前驱(\(pre\))/后继(\(next\))分别在哪。
方便起见,我们不妨考虑 \([i, i]\) 这样的区间也合法。考虑序列3 3 3 3\(f_i\) 表示要减的左端点的值的和。\(f_1 = 1*4 = 4, f_2 = f_1 + 2*3 - 1 = 9, f_3 = f_2 + 3*2 - (1 + 2)=12, f_4 = f_3 + 4*1 - (1 + 2 + 3)=10\) 。观察可知每次加上了一个 \(i*rrk[i]\) ,又减掉了 \(i\) 之前出现的值为 \(a_i\) 的下标和。为什么呢?因为我们一开始假想 \(i\) 到序列末尾都覆盖上了 \(i\) ,中间遇到一个 \(a_i\) 的时候就要让其中的一个区间停止更新。那么我们定义 \(decA[i]\) 表示 \(i\) 之前出现的值为 \(a_i\) 的下标和,转移方程为 \(decA[i] = decA[pre[i]] + pre[i]\)。一般的,33之间还有很多其他的数,那么我们减掉 \(decA[i]\) 的时候不是下一次 \(a[i]\) 出现的时候,而是 \(i+1\) 就必须减掉,不然覆盖的区间就是 \([i, next[i])\) 了,实际上应该是 \([i, i]\)
故可以得出转移方程 $$f[i] = f[i-1] + i*rrk[i] - decA[next[i-1]]$$
这里有一个问题,就是如果 \(rrk[i] = 1\) 的时候 \(next[i]\) 不存在, 减 \(decA[next[i]]\) 的时候会出错。那我们自己造一个 \(next[i]\) 嘛...特判一下就好了。
最后是 \(2k\) 的个数,由于以上我们统计的是左端点的值的总和,那么我们把这个值改成 \(1\) 不就是区间的个数了吗。 \(decA[i]\) 的转移方程改成 \(decA[i] = decA[pre[i]] + 1\) 即可。
怕gg所以写了个快速乘...不过好像不用也没问题...
记得用unsigned int存答案,最好写个读入优化。

Code

#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <iostream>
#define ll long long 
#define db double
#define uint unsigned int
#define N 3000100
#define max(a, b) ((a) > (b) ? (a) : (b))
#define min(a, b) ((a) < (b) ? (a) : (b))
#define swap(T, a, b) ({T ttt = a; a = b; b = ttt;})
int n, lrk[N], rrk[N], pre[N], next[N], pos[N];
uint a[N], decA[N], decB[N], f[N], g[N], Ans = 0;
void G(uint &w) {
   w = 0; char c = getchar();
   while (c > '9' || c < '0') c = getchar();
   while (c >= '0' && c <= '9') { w = w * 10 + c - '0'; c = getchar(); }
}
uint Mult(uint a, uint b)
{
   uint s = 0; 
   while (b) {
      if (b & 1) s += a;
      a += a; b >>= 1;
   }
   return s;
}
int main()
{
   scanf("%d", &n);
   for (int i = 1; i <= n; i++)
      G(a[i]);
   memset(pos, 0, sizeof(pos));
   for (int i = 1; i <= n; i++)
   {
      pre[i] = pos[a[i]]; pos[a[i]] = i;
      lrk[i] = lrk[pre[i]] + 1;
      decA[i] = decA[pre[i]] + pre[i];
   }
   memset(pos, 0, sizeof(pos)); 
   for (int i = n; i >= 1; i--)
   {
      next[i] = pos[a[i]]; pos[a[i]] = i;
      rrk[i] = rrk[next[i]] + 1;
      decB[i] = decB[next[i]] + next[i]; 
   }
   for (int i = 1; i <= n; i++)
   {
      if (rrk[i] == 1)
      {
    	 next[i] = i+n; pre[i+n] = i;
    	 lrk[next[i]] = lrk[i] + 1; 
    	 decA[i+n] = decA[i] + i;
      }
      if (lrk[i] == 1)
      {
    	 pre[i] = i+n*2; next[i+n*2] = i;
    	 rrk[pre[i]] = rrk[i] + 1; 
    	 decB[i+n*2] = decB[i] + i;
      }
   }
   for (int i = 1; i <= n; i++)
      f[i] = f[i-1] + i*rrk[i] - decA[next[i-1]];
   for (int i = n; i >= 1; i--)
      g[i] = g[i+1] + i*lrk[i] - decB[pre[i+1]];
   for (int i = 1; i <= n; i++)
      f[i] = f[i] + g[i] - i*2; 
   memset(g, 0, sizeof(g));
   memset(decA, 0, sizeof(decA));
   for (int i = 1; i <= n; i++)
   {
      decA[i] = decA[pre[i]] + (pre[i] >= 1 && pre[i] <= n);
      if (rrk[i] == 1) decA[i+n] = decA[i] + 1;
   }
   for (int i = 1; i <= n; i++)
   {
      g[i] = g[i-1] + rrk[i] - decA[next[i-1]];
      f[i] = Mult(g[i] - 1, i) * 2 - f[i];
      Ans += Mult(f[i], a[i]); 
   }
   std::cout << Ans << std::endl; 
   return 0; 
}
posted @ 2016-11-28 09:24  zkGaia  阅读(384)  评论(0编辑  收藏  举报