[题解]CF1774D Same Count One

思路

首先记所有 \(1\) 的数量为 \(num\),那么显然有当 \(n \bmod num \neq 0\) 时无解。那么考虑有解的时候该怎么办。

显然对于每一个 \(a_i\) 序列中,最终 \(1\) 的数量为 \(\frac{num}{n}\),记作 \(t\);并记 \(cnt_i\) 表示 \(a_i\) 序列中 \(1\) 的数量。

我们希望最终所有的 \(cnt_i\) 都等于 \(t\),并且希望操作步数最小,我们考虑一个显然的贪心:将 \(cnt_i > t\) 的序列中的 \(1\)\(cnt_j < t\) 缺失的 \(1\)

这样我们每一次的操作都会使 \(\sum_{i = 1}^{n}|cnt_i - t|\) 减少 \(2\),显然是最优的方案。

注意:如果你在交换的时候,一定需要更新 \(a_{i,k}\)\(a_{j,k}\),否则有一个很简单的 Hack。因为你不更新,你的程序会认为 \(a_{3,1}\) 在第一次操作后还是 \(0\) 可以交换。

Code

#include <bits/stdc++.h>  
#define fst first  
#define snd second  
#define re register  
  
using namespace std;  
  
typedef pair<int,int> pii;  
const int N = 1e5 + 10;  
int n,m;  
int cnt[N];  
pii del[N];  
  
struct answer{  
    int a,b,pos;  
};  
  
inline int read(){  
    int r = 0,w = 1;  
    char c = getchar();  
    while (c < '0' || c > '9'){  
        if (c == '-') w = -1;  
        c = getchar();  
    }  
    while (c >= '0' && c <= '9'){  
        r = (r << 3) + (r << 1) + (c ^ 48);  
        c = getchar();  
    }  
    return r * w;  
}  
  
inline void solve(){  
    int num = 0;  
    n = read();  
    m = read();  
    bool vis[n + 10][m + 10];  
    vector<answer> ans;  
    for (re int i = 1;i <= n;i++){  
        cnt[i] = 0;  
        for (re int j = 1;j <= m;j++){  
            int x;  
            x = read();  
            if (x) vis[i][j] = true;  
            else vis[i][j] = false;  
            num += x;  
            cnt[i] += x;  
        }  
    }  
    if (num % n) return puts("-1"),void();  
    num /= n;  
    for (re int i = 1;i <= n;i++) del[i] = {cnt[i] - num,i};  
    sort(del + 1,del + n + 1);  
    for (re int i = 1,j = n;i < j;){  
        int p = del[i].snd,q = del[j].snd;  
        for (re int k = 1;k <= m && del[i].fst && del[j].fst;k++){  
            if (!vis[p][k] && vis[q][k]){  
                del[i].fst++;  
                del[j].fst--;  
                vis[p][k] = true;  
                vis[q][k] = false;  
                ans.push_back({p,q,k});  
            }  
        }  
        if (!del[i].fst) i++;  
        if (!del[j].fst) j--;  
    }  
    printf("%d\n",ans.size());  
    for (auto p:ans) printf("%d %d %d\n",p.a,p.b,p.pos);  
}  
  
int main(){  
    int T;  
    T = read();  
    while (T--) solve();  
    return 0;  
}  
posted @ 2024-06-25 12:26  WBIKPS  阅读(30)  评论(0)    收藏  举报