cf1637 E. Best Pair

题意:

给定数组。定义函数 \(f(x,y)=(cnt _x + cnt_y)\cdot (x+y)\),其中 \(cnt_x\) 表示值 \(x\) 在原数组中的出现次数

给定一些 bad_pairs \((x',y')\),表示不能选 \((x',y')\)\((y',x')\) 。除此之外,所有 \(x'=y'\) 也不能选

\(f\) 的最大值

思路:

突破口是 \(cnt_x\) 的不同取值在 \(\sqrt n\) 数量级

遍历 $ cnt_x$ 和 \(cnt_y\) ,对于确定的 $ cnt_x$ 和 \(cnt_y\) ,找 \((x+y)\) 最大的非 bad 的 pair。

可以用 优先队列+bfs 找,一提交发现要 1200ms

int n, m;
void sol() {
    map<int, int> c;
    cin >> n >> m; for(int i = 1, x; i <= n; i++) cin >> x, c[x]++;
    map<PII, bool> bad;
    while(m--) {
        int x, y; cin >> x >> y;
        bad[{x,y}] = bad[{y,x}] = 1;
    }

    map<int, vector<int>> mp;
    for(auto &[val,cnt]: c) mp[cnt].pb(val); //val在c里面就是有序的

    ll ans = 0;
    for(auto &[cntx,vex]: mp)
    for(auto &[cnty,vey]: mp) {
        if(cntx < cnty) continue; //稍微加速一下

        auto cmp = [&](PII &a, PII &b) { //帅气的重载
            return vex[a.fi] + vey[a.se] < vex[b.fi] + vey[b.se];
        };
        priority_queue<PII, vector<PII>, decltype(cmp)> q(cmp);

        q.push({vex.size()-1,vey.size()-1});
        map<PII, bool> vis;

        while(q.size()) {
            auto [i,j] = q.top(); q.pop();
            int x = vex[i], y = vey[j];
            if(vis[{x,y}]) continue; vis[{x,y}] = true;
            if(x != y && !bad[{x,y}]) {
                ans = max(ans, (ll)(cntx + cnty) * (x + y));
                break;
            }
            if(i) q.push({i-1,j}); if(j) q.push({i,j-1});
        }
    }

    cout << ans << endl;
}

1200ms没法忍啊,于是看一眼标程,震惊到我!

他居然完全没用优先队列,直接疯狂遍历,只用500ms。。。

几个循环的顺序非常奇妙,稍微改一下就会tle。。。

来自官方题解评论区的tips:注意for(cnty)那里,从1遍历到cntx复杂度就是对的,从cntx遍历到n的话复杂度就会变成 n*sqrt(n) (真神奇)

进一步解释:对每一个cntx,遍历cnty,则每次的cnty最多有cntx个,一共就 \(\sum cnt_x\) 次遍历,即 \(n\)

然后只有对非空的vecx才要进一步遍历y,所以那两个for的顺序不能改

然后尽量用vector+二分代替map,省很多空间,常数应该也会更小

int n, m;
void sol() {
    map<int, int> c;
    cin >> n >> m; for(int i = 1, x; i <= n; i++) cin >> x, c[x]++;
    vector<PII> bad;
    while(m--) {
        int x, y; cin >> x >> y;
        bad.pb({x,y}), bad.pb({y,x});
    }
    sort(all(bad));

    vector<vector<int>> mp(n);
    for(auto &[val,cnt]: c) mp[cnt].pb(val); //val在c里面就是有序的

    for(auto &ve: mp) reverse(all(ve)); //从大到小

    ll ans = 0;
    for(int cntx = 1; cntx < n; cntx++)
    for(int x : mp[cntx]) //这一行和下一行交换,会tle
        for(int cnty = 1; cnty <= cntx; cnty++) //从cntx遍历到n会tle
        for(int y : mp[cnty])
            if(x != y && !binary_search(all(bad), PII(x,y))) {
                ans = max(ans, (ll)(cntx + cnty) * (x + y));
                break;
            }

    cout << ans << endl;
}
posted @ 2022-05-16 22:25  Bellala  阅读(36)  评论(0)    收藏  举报