【cf1173E】Nauuo and Pictures(概率dp)

传送门

题意:
给出\(n\)个数,每个数有一个权值\(w_i\)和所属集合\(a_i,a_i=0,1\)
现在执行\(m\)次以下操作:

  • 随机选择一个数,每个数选择的概率为\(\displaystyle p=\frac{w_i}{sum_w}\),若\(a_i=1\),那么权值加一;否则减一;

问最后每个数的期望权值为多少。

思路:
先说以下\(easy\)版本:
\(easy\)版本的限制为\(n,m\leq 50\)。那么这就很好做了。我们直接对每个数单独考虑,然后跑\(O(n^3)\)\(dp\)
\(dp\)定义如下:\(dp_{i,j,k}\)表示考虑了前\(i\)次操作,有\(j\)次操作为当前所在集合,其中有\(k\)次操作于当前数的概率。
求出来之后根据概率直接算期望即可。
代码如下:

Code
/*
 * Author:  heyuhhh
 * Created Time:  2020/3/22 9:18:57
 */
#include <iostream>
#include <algorithm>
#include <cstring>
#include <vector>
#include <cmath>
#include <set>
#include <map>
#include <queue>
#include <iomanip>
#include <assert.h>
#define MP make_pair
#define fi first
#define se second
#define pb push_back
#define sz(x) (int)(x).size()
#define all(x) (x).begin(), (x).end()
#define INF 0x3f3f3f3f
#define Local
#ifdef Local
  #define dbg(args...) do { cout << #args << " -> "; err(args); } while (0)
  void err() { std::cout << '\n'; }
  template<typename T, typename...Args>
  void err(T a, Args...args) { std::cout << a << ' '; err(args...); }
  template <template<typename...> class T, typename t, typename... A> 
  void err(const T <t> &arg, const A&... args) {
  for (auto &v : arg) std::cout << v << ' '; err(args...); }
#else
  #define dbg(...)
#endif
using namespace std;
typedef long long ll;
typedef pair<int, int> pii;
//head
const int N = 50 + 5, MOD = 998244353;
int qpow(ll a, ll b) {
    ll res = 1;
    while(b) {
        if(b & 1) res = res * a % MOD;
        a = a * a % MOD;
        b >>= 1;   
    }
    return res;   
}
int fac[N], inv[N];
void init() {
    fac[0] = 1;
    for(int i = 1; i < N; i++) fac[i] = 1ll * fac[i - 1] * i % MOD;
    inv[N - 1] = qpow(fac[N - 1], MOD - 2);
    for(int i = N - 2; i >= 0; i--) inv[i] = 1ll * inv[i + 1] * (i + 1) % MOD;
}

int n, m;
int dp[N][N][N];
int a[N], w[N];

void add(int &x, int y) {
    x += y;
    if(x >= MOD) x -= MOD;
}

int solve(int t) {
    memset(dp, 0, sizeof(dp));
    dp[0][0][0] = 1;
    int sum = 0, sa = 0, sb = 0;
    for(int i = 1; i <= n; i++) {
        sum += w[i];
        if(a[i] > 0) sa += w[i]; else sb += w[i];
    }
    for(int i = 0; i < m; i++) {
        for(int j = 0; j <= i; j++) {
            for(int k = 0; k <= j; k++) {
                int tsum = sum + a[t] * (j - (i - j));
                int now = w[t] + a[t] * k;
                int others = (a[t] > 0 ? sa + j - now : sb - j - now);
                add(dp[i + 1][j][k], 1ll * dp[i][j][k] * (tsum - others - now) % MOD * qpow(tsum, MOD - 2) % MOD);
                add(dp[i + 1][j + 1][k + 1], 1ll * dp[i][j][k] * now % MOD * qpow(tsum, MOD - 2) % MOD);
                add(dp[i + 1][j + 1][k], 1ll * dp[i][j][k] * others % MOD * qpow(tsum, MOD - 2) % MOD);
            }
        }   
    }
    int res = 0;
    for(int j = 0; j <= m; j++) {
        for(int k = 0; k <= j; k++) {
            int tt = w[t] + a[t] * k;
            res = (res + 1ll * tt * dp[m][j][k] % MOD) % MOD;
        }   
    }
    cout << res << '\n';
}

void run() {
    cin >> n >> m;
    for(int i = 1; i <= n; i++) {
        cin >> a[i];
        if(a[i] == 0) a[i] = -1;
    }
    for(int i = 1; i <= n; i++) cin >> w[i];
    for(int i = 1; i <= n; i++) solve(i);
}

int main() {
    ios::sync_with_stdio(false);
    cin.tie(0); cout.tie(0);
    cout << fixed << setprecision(20);
    init();
    run();
    return 0;
}

然后就是\(hard\)版本:
\(hard\)版本的限制为\(n\leq 10^6,m\leq 10^3\),显然我们上面的做法时间复杂度完全不能承受。但总操作数比较小,我们可以从总操作数方面考虑。
考虑一个简单情况:一个集合中如果所有数初始权值相同的话,那么如果最后总和的期望为\(s\),那么每个数的期望为\(\displaystyle\frac{s}{c},c\)为集合元素个数。
那么对于权值为\(a_i\)的数,我们可以拆分为\(a_i\)\(1\),那么现在所有数的权值相等。用\(E(x)\)表示权值为\(x\)时的期望的话,\(E(1)\)则可以通过最终总和的期望计算出来,那么\(E(x)=xE(1)\)就可以直接计算。
那么我们将两个集合当作两个数,通过简单\(O(m^2)dp\)计算出每个数最终的期望,然后按初始权值进行分配即可。
至于详细的证明可参见官方中文题解

一点题外话:
这个题最终期望按照权值分配这一点其实一直感觉有点怪怪的,下午也一直在尝试有没有什么比较好的思考方法,但并没有什么其它的思路,倒是发现了每次选择一个数的概率都为\(\displaystyle \frac{w_i}{sum_w}\),这和期望均匀分配又有什么关系呢。
如果有大佬知道的话希望能够不吝赐教。
代码如下:

Code
/*
 * Author:  heyuhhh
 * Created Time:  2020/3/22 19:30:09
 */
#include <iostream>
#include <algorithm>
#include <cstring>
#include <vector>
#include <cmath>
#include <set>
#include <map>
#include <queue>
#include <iomanip>
#include <assert.h>
#define MP make_pair
#define fi first
#define se second
#define pb push_back
#define sz(x) (int)(x).size()
#define all(x) (x).begin(), (x).end()
#define INF 0x3f3f3f3f
#define Local
#ifdef Local
  #define dbg(args...) do { cout << #args << " -> "; err(args); } while (0)
  void err() { std::cout << '\n'; }
  template<typename T, typename...Args>
  void err(T a, Args...args) { std::cout << a << ' '; err(args...); }
  template <template<typename...> class T, typename t, typename... A> 
  void err(const T <t> &arg, const A&... args) {
  for (auto &v : arg) std::cout << v << ' '; err(args...); }
#else
  #define dbg(...)
#endif
using namespace std;
typedef long long ll;
typedef pair<int, int> pii;
//head
const int N = 2e5 + 5, M = 3005, MOD = 998244353;
int qpow(ll a, ll b) {
    ll res = 1;
    while(b) {
        if(b & 1) res = res * a % MOD;
        a = a * a % MOD;
        b >>= 1;   
    }
    return res;   
}

int n, m;
int a[N], w[N];
int dp[M][M];

void add(int &x, int y) {
    x += y;
    if(x >= MOD) x -= MOD;   
}

void run() {
    cin >> n >> m;
    for(int i = 1; i <= n; i++) cin >> a[i];
    for(int i = 1; i <= n; i++) cin >> w[i];
    int sa = 0, sb = 0;
    for(int i = 1; i <= n; i++) {
        if(a[i]) sa += w[i];
        else sb += w[i];   
    }
    dp[0][0] = 1;
    for(int i = 0; i < m; i++) {
        for(int j = 0; j <= i; j++) {
            int sum = sa + j + sb - (i - j);
            add(dp[i + 1][j + 1], 1ll * dp[i][j] * (sa + j) % MOD * qpow(sum, MOD - 2) % MOD);
            add(dp[i + 1][j], 1ll * dp[i][j] * (sb - (i - j)) % MOD * qpow(sum, MOD - 2) % MOD);
        }   
    }
    int ea = 0, eb = 0;
    for(int j = 0; j <= m; j++) {
        add(ea, 1ll * dp[m][j] * (sa + j) % MOD);
        add(eb, 1ll * dp[m][j] * (sb - (m - j)) % MOD);
    }
    sa = qpow(sa, MOD - 2), sb = qpow(sb, MOD - 2);
    for(int i = 1; i <= n; i++) {
        int res;
        if(a[i]) res = 1ll * ea * sa % MOD * w[i] % MOD;
        else res = 1ll * eb * sb % MOD * w[i] % MOD;
        cout << res << '\n';   
    }
}

int main() {
    ios::sync_with_stdio(false);
    cin.tie(0); cout.tie(0);
    cout << fixed << setprecision(20);
    run();
    return 0;
}
posted @ 2020-03-22 21:17  heyuhhh  阅读(233)  评论(0编辑  收藏  举报