Codeforces 1637E. Best Pair
题目大意
有一个长为 \(n(2\leq n\leq3\times10^5)\) 的序列 \(a(1\leq a_i\leq10^9)\) ,有 \(m(0\leq m\leq3\times10^5)\) 个坏的无序数对 \((x_i,y_i)\) 。设 \(f(x,y) = (x+y)(cnt_x+cnt_y)\) ,求序列中好的无数序对中最大的 \(f(x,y)\) , \(x\neq y\)。
思路
考虑到 \(cnt_x\) 的个数最多只有 \(\sqrt n\) 个,我们可以枚举所有的 \((cnt_x,cnt_y)\) 。枚举的时间复杂度为 \(O(n)\) 。我们可以先离散化,计算出 \(cnt_{a_i}\) ,之后预处理出每个 \(cnt_i\) 下有哪些数,并且将它们降序排序,之后对于每个 \((cnt_x,cnt_y)\) ,先枚举 \(cnt_x\) 中的数字 \(u\) ,对于每个 \(u\) 再去枚举 \(cnt_y\) 中的数字 \(v\) ,如果 \(u=v\) 或者是 \((u,v)\) 为坏的数对就继续枚举,否则可以更新 \(ans\) 并且退出到外层继续枚举,同时在判断是否为坏的数对之前可以提前与 \(ans\) 比较,如果无法更新可以提前退出到外层,我们用 \(set\) 来维护所有的坏的数对,这样枚举具体的数对时只有枚举到坏的数对才会继续枚举,所以枚举数对这一部分总的复杂度是 \(O(mlogm)\) 的,总的复杂度为 \(O(nlogn+mlogm)\) 。
代码
#include<bits/stdc++.h>
#include<unordered_map>
#include<unordered_set>
using namespace std;
using LL = long long;
using ULL = unsigned long long;
using PII = pair<int, int>;
using TP = tuple<int, int, int>;
#define all(x) x.begin(),x.end()
#define mk make_pair
#define int LL
//#define lc p*2
//#define rc p*2+1
#define endl '\n'
#define inf 0x3f3f3f3f
#define INF 0x3f3f3f3f3f3f3f3f
#pragma warning(disable : 4996)
#define IOS ios::sync_with_stdio(0),cin.tie(0),cout.tie(0)
const double eps = 1e-8;
const LL MOD = 1000000007;
const LL mod = 998244353;
const int maxn = 300010;
LL T, N, M, A[maxn], B[maxn], H[maxn], cnt[maxn];
vector<LL>cnts[maxn], S;
set<PII>bads;
bool cmp(const LL& a, const LL& b)
{
return a > b;
}
void init()
{
for (int i = 1; i <= N; i++)
cnt[i] = 0, cnts[i].clear();
S.clear(), bads.clear();
}
int compress()
{
vector<LL>xs;
for (int i = 1; i <= N; i++)
xs.push_back(A[i]);
sort(all(xs));
xs.erase(unique(all(xs)), xs.end());
for (int i = 1; i <= N; i++)
{
B[i] = upper_bound(all(xs), A[i]) - xs.begin();
H[B[i]] = A[i];
}
return xs.size();
}
void solve()
{
int x, y;
LL ans = -INF;
int K = compress();
for (int i = 1; i <= M; i++)
cin >> x >> y, bads.insert(PII(min(x, y), max(x, y)));
for (int i = 1; i <= N; i++)
cnt[B[i]]++;
for (int i = 1; i <= K; i++)
cnts[cnt[i]].push_back(i), S.push_back(cnt[i]);
sort(all(S));
S.erase(unique(all(S)), S.end());
for (int i = 0; i < S.size(); i++)
{
if (cnts[S[i]].size() > 1)
sort(all(cnts[S[i]]), cmp);
}
for (int i = 0; i < S.size(); i++)
{
for (int j = i; j < S.size(); j++)
{
for (auto& a : cnts[S[i]])
{
for (auto& b : cnts[S[j]])
{
if (a == b)
continue;
int x = H[a], y = H[b];
if ((x + y) * (S[i] + S[j]) <= ans)
break;
if (bads.count(PII(min(x, y), max(x, y))) > 0)
continue;
ans = (x + y) * (S[i] + S[j]);
break;
}
}
}
}
cout << ans << endl;
}
signed main()
{
IOS;
cin >> T;
while (T--)
{
cin >> N >> M;
init();
for (int i = 1; i <= N; i++)
cin >> A[i];
solve();
}
return 0;
}