[均摊复杂度] Codeforces 1637E Best Pair

题目大意

给定一个长度为 \(n\) 的数组 \(a\)。给出如下定义:

  • 定义 \(cnt_x\)​ 为 \(x\) 在数组 \(a\) 中出现的次数
  • 定义 \(f(x,y)=(cnt_x+cnt_y)\cdot(x+y)\)

同时,给定由 \(m\)无序数对 \((x_1,y_1),(x_2,y_2),\cdots,(x_m,y_m)\) 组成的集合 \(S\)

你需要计算出 \(\max ⁡f(u,v)\),要求 \(u,v\) 都出现在数组中,\(u\neq v\),且 \((u,v)\) 不在集合 \(S\) 中。

数据范围:

  • \(t\) 组数据,\(1\leq t\leq 10^4\)
  • \(2\leq n,\sum n\leq 3×10^5,0\leq m,\sum m\leq 3×10^5\)
  • \(1\leq a_i\leq 10^9\)

题解

学到了一个小trick,本题可以使用均摊复杂度的方法,使得看似暴力的方法通过。

我们可以固定 \(x\),枚举 \(cnt_y\leq cnt_x\)\(cnt_y\)。对于一个 \(x\),枚举到的 \(cnt_y\) 至多有 \(cnt_x\) 个,因此这一步枚举复杂度为 \(O(\sum cnt_x)=O(n)\)

当固定了 \(x\)\(cnt_y\) 时,我们去按降序枚举出现次数等于 \(cnt_y\) 的所有的 \(y\),若 \(x=y\)\((x,y)\) 被禁止,则跳过这个 \(y\),继续去枚举 \(y\),直到找到第一个合法的 \(y\),更新答案,并退出循环。遇到 \(x=y\)\((x,y)\) 被禁止的情况是 \(O(n+m)\)的。判断 \((x,y)\) 是否被禁止可以使用 set。不妨假设 \(n,m\) 同阶,于是时间复杂度为 \(O(n\log n)\)

Code

#include <bits/stdc++.h>
using namespace std;

#define LL long long

template<typename elemType>
inline void Read(elemType& T) {
    elemType X = 0, w = 0; char ch = 0;
    while (!isdigit(ch)) { w |= ch == '-';ch = getchar(); }
    while (isdigit(ch)) X = (X << 3) + (X << 1) + (ch ^ 48), ch = getchar();
    T = (w ? -X : X);
}

map<int, int> cnt;
vector<int> vec[300010];
set<pair<int, int>> s;
int a[300010];
int T, n, m;

LL solve() {
    LL ans = 0;
    for (auto v : cnt) {
        int x = v.first;
        for (int i = 1;i <= v.second;++i) {
            for (auto y : vec[i]) {
                if (x == y || s.count(make_pair(min(x, y), max(x, y)))) continue;
                ans = max(ans, 1LL * (x + y) * (v.second + i));
                break;
            }
        }
    }
    return ans;
}

int main() {
    Read(T);
    while (T--) {
        Read(n); Read(m);
        cnt.clear(); s.clear();
        for (int i = 1;i <= n;++i) { Read(a[i]); ++cnt[a[i]]; vec[i].clear(); }
        for (auto v : cnt)
            vec[v.second].push_back(v.first);
        for (int i = 1;i <= n;++i)
            if (!vec[i].empty()) sort(vec[i].begin(), vec[i].end(), greater<int>());
        for (int i = 1;i <= m;++i) {
            int x, y; Read(x); Read(y);
            s.insert(make_pair(min(x, y), max(x, y)));
        }
        printf("%I64d\n", solve());
    }
    return 0;
}
posted @ 2022-02-26 15:30  AE酱  阅读(153)  评论(0编辑  收藏  举报