[luogu2503][HAOI2006]均分数据【模拟退火】

题目描述

已知N个正整数:A1、A2、……、An 。今要将它们分成M组,使得各组数据的数值和最平均,即各组的均方差最小。均方差公式如下:

分析

模拟退火学习笔记:https://www.cnblogs.com/chhokmah/p/10529114.html
万物皆可颓火,我们首先将初始的答案当做一半一半的答案,然后我们随机化抽取两个部分的数据。根据题目中的描述,因为两个组别之间数据个数只能是差一,那么差不多就是一半一半的情况。那么我们就只需要分块两部分,然后随机交换,如果两个数据交换之后能使答案能更优,那么就交换,如果不能让我们的答案变得更加优,那么就让随机概率,这个概率很明显是越到后面交换的概率越小,那么我们就是exp(delta) < t * Rand(),那么就交换,否则就不交换。
模拟退火的精髓还是这个调参,这道题目我一遍A掉了,感觉有一点欧皇。
我给出一个比较优秀的随机种子,是ouhuang和6666666的取模,就是15346301。

ac代码

#include <bits/stdc++.h>
#define ms(a,b) memset(a, b, sizeof(a))
#define db double 
using namespace std;
inline char gc() {
    static char buf[1 << 16], *S, *T;
    if (S == T) {
        T = (S = buf) + fread(buf, 1, 1 << 16, stdin);
        if (T == S) return EOF;
    }
    return *S ++;
}
template <typename T>
inline void read(T &x) {
    T w = 1;
    x = 0;
    char ch = gc();
    while (ch < '0' || ch > '9') {
        if (ch == '-') w = -1;
        ch = gc();
    }
    while (ch >= '0' && ch <= '9') x = (x << 1) + (x << 3) + (ch ^ 48), ch = gc();
    x = x * w;
}
template <typename T>
void write(T x) {
    if (x < 0) putchar('-'), x = -x;
    if (x > 9) write(x / 10);
    putchar(x % 10 + 48);
}
#define N 305
db ans = 1e30, ave = 0;
int sum[N], pos[N], a[N];
int n, m;
void SA(db T){
    ms(sum, 0);
    for (int i = 1; i <= n; i ++) {
        pos[i] = rand() % m + 1;
        sum[pos[i]] += a[i];
    }
    db res = 0;
    for (int i = 1; i <= m; i ++) 
        res += (1.0 * sum[i] - ave) * (1.0 * sum[i] - ave);
    while (T > 1e-4) {
        int t = rand() % n + 1, x = pos[t], y;
        if (T > 500) y = min_element(sum + 1, sum + 1 + m) - sum;
        else y = rand() % m + 1;
        if (x == y) continue;
        db tmp = res;
        res -= (sum[x] - ave) * (sum[x] - ave);
        res -= (sum[y] - ave) * (sum[y] - ave);
        sum[x] -= a[t], sum[y] += a[t];
        res += (sum[x] - ave) * (sum[x] - ave);
        res += (sum[y] - ave) * (sum[y] - ave);
        if (res < tmp || rand() % 10000 <= T) pos[t] = y;
        else sum[x] += a[t], sum[y] -= a[t], res = tmp;
        ans = min(ans, res);
        T *= 0.98;
    }
}
int main() {
    srand(20040127);
    read(n); read(m);
    for (int i = 1; i <= n; i ++) {
        read(a[i]);
        ave += 1.0 * a[i];
    }
    ave /= 1.0 * m;
    for (int i = 1; i <= 1500; i ++) SA(10000);
    printf("%.2lf\n", sqrt(ans / m));
    return 0;
}
posted @ 2019-03-14 13:06 chhokmah 阅读(...) 评论(...) 编辑 收藏