ABC 252 | D - Distinct Trio
题意
给定含有N个元素的数组A,输出满足下列条件的三元组\((i, j, k)\)的数量。
- \(1 \le i < j < k \le N\)
- \(A_i, A_j, A_k\) 各不相同
分析
对于数对计数问题,常用的方法是枚举其中某一个数,然后快速计算选定该数的情况下满足条件的数对个数。由于枚举某个数的复杂度为\(O(N)\),所以一般情况下,计算选定该数后满足条件的数组对数的复杂度为\(O(1)或O(log)\)。为达到快速计算的目的,通常需要预处理出某些值,预处理的手段一般有动态规划、前缀和等,此外,部分项应合并处理。
以下两种方法枚举不同的数值,对应不同的预处理方法。
方法一
根据题意,枚举\(k\)的大小,然后计算对于当前\(k\),满足条件的\((i, j, k)\)的组数\(N_k\),该步骤分析用到了容斥原理。
\[\begin{align*}
N_k &= C_{k - 1}^2 - N_{a_i = a_k \neq a_j} - N_{a_j = a_k \neq a_i} - N_{a_i = a_j \neq a_k} - N_{a_i = a_k = a_j} \\
&= C_{k - 1}^2 - (N_{a_i = a_k \neq a_j} + N_{a_j = a_k \neq a_i}) - N_{a_i = a_j} + N_{a_i = a_j = a_k} - N_{a_i = a_k = a_j} \\
&= C_{k - 1}^2 - (N_{a_i = a_k \neq a_j} + N_{a_j = a_k \neq a_i}) - N_{a_i = a_j}
\end{align*}
\]
下面分别计算上式中各项的值。
- \(C_{k - 1}^2 = (k - 1) * (k - 2) / 2\)
- \(N_{a_i = a_k \neq a_j} + N_{a_j = a_k \neq a_i} = mp[a_k] * (k - 1 - mp[a_k])\),可理解为在\(a_k\)之前选一个与之相等的,有\(mp[a_k]\)个,然后再选一个与\(a_k\)不等的,有\((k - 1 - mp[a_k])\)个,将两个数中下标小的自动赋值给\(a_i\),下标大的自动赋值给\(a_j\)
- \(N_{a_i = a_j}\)需要用\(dp\)来维护,以\(f_i\)表示前\(i\)个数中任选两个且这两个数相等的方案数,那么有\(f(i) = f(i - 1) + mp[a_i], mp[a_i]\)表示前\(i - 1\)中与\(a_i\)相等的数的个数。
- \(N_{a_i = a_j = a_k} = C_{mp[a_k]}^2, mp[a_k]\)表示\(a_k\)之前等于\(a_k\)的数的个数
方法二
首先转化题意,寻找三元组\((i, j, k)(1 \le i < j < k \le N)\)使得\(A_i, A_j, A_k\)各不相同,等价于寻找三元组\((A_i, A_j, A_k),(A_i < A_j < A_k)\)。对于后者,我们可以遍历中间大小的数\(A_j\),对于每一个确定的\(A_j\),对应的三元组个数为\(N_{A_j} = Cntx_{x < A_j} * Cntx_{x > A_j}\),\(Cntx\)可以通过预处理出前缀和从而\(O(1)\)得到。
对于二者的等价性可作如下理解:
- 首先证明每一个\((A_i, A_j, A_k)\)都对应一组\((i, j, k)\),由于所找\(A_i, A_j, A_k\)各不相同,所以可以得到三个互不相同的数组下标,按照大小顺序排好即得\((i, j, k)\)。
- 其次证明每一个\((i, j, k)\)对应一个\((A_i, A_j, A_k)\),不难发现可能存在多对一的情况,举例来说,\((i, j_1, k)\)与\((i, j_2, k)\)可能对应同一组\((A_i, A_j, A_k)\),这意味着我们在遍历\(A_j\)时,对于是对整个数组进行for循环,相同的数值会被重复计算,同理考虑不同的\(i\)对应同一\(A_i\)与不同的\(k\)对应同一\(A_k\)的情况,发现按照以上算法该情况被正确计算
故以上两种表示等价且算法正确。
注意: 计数类问题可能需要开long long。
代码
方法一
#include <iostream>
#include <cstdio>
#include <cstring>
#include <algorithm>
using namespace std;
typedef long long ll;
const int N = 2e5 + 10;
int n;
ll a[N], f[N];
ll mp[N];
int main()
{
scanf("%d", &n);
for(int i = 1; i <= n; i ++) scanf("%d", &a[i]);
ll ans = 0;
for(int i = 1; i <= n; i ++){
int x = a[i];
ans += (ll)(i - 1) * (i - 2) / 2;
ans -= mp[x] * (i - 1 - mp[x]);
ans -= f[i - 1];
f[i] = f[i - 1] + mp[x];
mp[x] ++;
}
printf("%lld\n", ans);
return 0;
}
方法二
#include <iostream>
#include <cstdio>
#include <cstring>
#include <algorithm>
#include <set>
using namespace std;
const int N = 2e5 + 10;
typedef long long ll;
int a[N];
int n;
ll cnt[N];
int main()
{
scanf("%d", &n);
for(int i = 1; i <= n; i ++){
int x;
scanf("%d", &x);
a[i] = x;
cnt[x] ++;
}
for(int i = 1; i < N; i ++) cnt[i] = cnt[i - 1] + cnt[i];
ll ans = 0;
for(int i = 1; i <= n; i ++) {
int t = a[i];
ans += cnt[t - 1] * (n - cnt[t]);
}
printf("%lld\n", ans);
return 0;
}

浙公网安备 33010602011771号