[NOI Online #2 提高组]子序列问题
嘟嘟嘟
这个NOI网络赛我也不知道是个什么东西,不过据说题目的质量还是挺高的,于是教练就让我做了一下。
讲真我就不会这种硬是要你把\(O(n^2)\)暴力化简的题。
这种题,一种很常见的思路就是依次考虑以\(i\)结尾的所有子区间是怎么算的。
首先肯定要离散化。
接着分两种情况,第一个是\(a_i\)在前\(i-1\)个数中从没出现过,那么对于所有的\(j \in [1,i-1],f(j,i) = f(j,i-1)+1\)。
另一种情况是\(a_i\)在前\(i-1\)个数中出现过,即\(pre_i\)为上一次出现的位置,那么对于所有的\(j \in [pre_i+1,i-1],f(j,i) = f(j,i-1)+1\),而对于\(j \in [1,pre_i], f(j,i) = f(j,i-1)\),即在\(pre_i\)之前没有出现新的数。
于是我们就有一个\(O(n^2)\)的递推\(f(l,r)\)的算法:
首先有\(f(i,i)=1\)。
当\(j \in [1,pre_i]\)时,有\(f(j,i)=f(j, i-1)\);
当\(j \in [pre_i + 1, i-1]\)时,有\(f(j,i) = f(j, i - 1)+1\)。
然后我们可以用线段树进行优化。
对于\(f(j,i)\)来说,是区间每次加1,但是我们要求的是\(\sum (f(j,i))^2\),所以我们要简单的推一下:
对于\(m\)个数\(a_1, a_2, \ldots a_m\),如果每一个数都加\(n\),那么他们的平方和就变成了\((a_1+n)^2 + (a_2+n) ^ 2 \ldots (a_m+n)^2\)。
展开后,和原来\(\sum_{i=1}^{m}a_i ^ 2\)的差值就是\(m * n ^ 2 + 2 * n * \sum_{i=1}^{m} a_i\),这个就可以用线段树区间和维护了。
时间复杂度\(O(nlogn)\),因为这种写法只用一次修改,所以常数能稍稍小一点。
#include<cstdio>
#include<iostream>
#include<cmath>
#include<algorithm>
#include<cstring>
#include<cstdlib>
#include<cctype>
#include<vector>
#include<queue>
#include<assert.h>
#include<ctime>
using namespace std;
#define enter puts("")
#define space putchar(' ')
#define Mem(a, x) memset(a, x, sizeof(a))
#define In inline
#define forE(i, x, y) for(int i = head[x], y; ~i && (y = e[i].to); i = e[i].nxt)
typedef long long ll;
typedef double db;
const int INF = 0x3f3f3f3f;
const db eps = 1e-8;
const int maxn = 1e6 + 5;
const ll mod = 1e9 + 7;
In ll read()
{
ll ans = 0;
char ch = getchar(), las = ' ';
while(!isdigit(ch)) las = ch, ch = getchar();
while(isdigit(ch)) ans = (ans << 1) + (ans << 3) + ch - '0', ch = getchar();
if(las == '-') ans = -ans;
return ans;
}
In void write(ll x)
{
if(x < 0) x = -x, putchar('-');
if(x >= 10) write(x / 10);
putchar(x % 10 + '0');
}
In void MYFILE()
{
#ifndef mrclr
freopen(".in", "r", stdin);
freopen(".out", "w", stdout);
#endif
}
int n, a[maxn];
int li[maxn], _n;
int pos[maxn], pre[maxn];
In void init()
{
sort(li + 1, li + n + 1);
_n = unique(li + 1, li + n + 1) - li - 1;
for(int i = 1; i <= n; ++i) a[i] = lower_bound(li + 1, li + _n + 1, a[i]) - li;
for(int i = 1; i <= n; ++i) pre[i] = pos[a[i]], pos[a[i]] = i;
}
In ll ADD(ll a, ll b) {return a + b < mod ? a + b : a + b - mod;}
int l[maxn << 2], r[maxn << 2];
ll sum[maxn << 2], dat[maxn << 2], lzy[maxn << 2];
In void build(int L, int R, int now)
{
l[now] = L, r[now] = R;
if(L == R) return;
int mid = (L + R) >> 1;
build(L, mid, now << 1), build(mid + 1 , R, now << 1 | 1);
}
In void change(int now, ll d)
{
int len = r[now] - l[now] + 1;
lzy[now] += d;
dat[now] = ADD(dat[now], ADD(1LL * len * d % mod * d % mod, (d * sum[now] % mod << 1) % mod));
sum[now] = ADD(sum[now], d * len % mod);
}
In void pushdown(int now)
{
if(lzy[now])
{
change(now << 1, lzy[now]);
change(now << 1 | 1, lzy[now]);
lzy[now] = 0;
}
}
In void update(int L, int R, int now)
{
if(l[now] == L && r[now] == R) {change(now, 1); return;}
pushdown(now);
int mid = (l[now] + r[now]) >> 1;
if(R <= mid) update(L, R, now << 1);
else if(L > mid) update(L, R, now << 1 | 1);
else update(L, mid, now << 1), update(mid + 1, R, now << 1 | 1);
sum[now] = ADD(sum[now << 1], sum[now << 1 | 1]);
dat[now] = ADD(dat[now << 1], dat[now << 1 | 1]);
}
int main()
{
// MYFILE();
n = read();
for(int i = 1; i <= n; ++i) li[i] = a[i] = read();
init();
build(1, n, 1);
ll ans = 0;
for(int i = 1; i <= n; ++i) update(pre[i] + 1, i, 1), ans = ADD(ans, dat[1]);
write(ans), enter;
return 0;
}