牛客 11257 D Gambling Monster 题解

传送门


【大意】

初始时,"土块"有一个数字 \(0\) 。每一轮,他有 \(p_i(0\leq i<n=2^k)\) 的概率抽到数字 \(i\) 。若当前他的数字异或上抽中的数字,将会变得更大,那他会异或上这个数字。问他得到 \((n-1)\) 的期望步数。


【分析】

HL:以后遇到概率 dp 通通倒着跑

所以我们设 \(E(x)\) 表示 \(x\)\((n-1)\) 的期望步数,则有:

\(\displaystyle E(x)=\sum_{x\oplus y=z\\x<z}p_y[E(z)+1]+\sum_{x\oplus y=z\\x\geq z}p_y[E(x)+1]\)

即当转到的数字会使得结果更大,就异或上,所以贡献直接由新的结果转移而来;当不会时,就不异或上了,所以贡献由自己转移

为了方便,我们记 \(\displaystyle \sum_{x\oplus y=z\\x<z}p_y=S_x\),则:

\(\displaystyle E(x)=\sum_{x\oplus y=z\\x<z}p_y[E(z)+1]+(1-S_x)[E(x)+1]\)

\(\displaystyle E(x)-(1-S_x)E(x)=\sum_{x\oplus y=z\\x<z}p_yE(z)+S_x+(1-S_x)\)

\(\displaystyle E(x)={1\over S_x}[\sum_{x\oplus y=z\\x<z}p_yE(z)+1]\)

由于卷积是 \(x\oplus y=z\) 的形式,所以考虑使用 FWT

由于只考虑高维对低微的贡献,所以考虑使用 cdq 分治 FWT

而对于 \(S_x\) ,我们考虑 \(\displaystyle S_x=\sum_{x\oplus y=z\\x<z}p_y\)。当且仅当 \(y\) 的最高位在 \(x\) 中为 \(0\) 时,会使得 \(x\oplus y>x\)

我们对所有概率按最高位归纳,再对每个 \(x\) 按位枚举即可

复杂度为 \(O(n\log^2 n)+O(3^{\log_2 n})=O(n\log^2 n)\)


【代码】

#include<bits/stdc++.h>
using namespace std;
typedef long long ll;
typedef pair<ll, ll> pii;
typedef double db;
#define fi first
#define se second
#define lowbit(x) ((x)&(-(x)))
const int MOD=1e9+7, MAXN=1<<16, inv2=MOD+1>>1;
inline int add(int a, int b) { return (a+=b)>=MOD?a-MOD:a; }
inline int dis(int a, int b) { return (a-=b)<0?a+MOD:a; }
inline ll fpow(ll a,ll x) { ll ans=1; for(;x;x>>=1,a=a*a%MOD) if(x&1) ans=ans*a%MOD; return ans; }
inline void FWT(int *a, int len, int o=1){
    for(int k=0; 1<<k<len; ++k) for(int i=0; i<len; ++i) if(~i>>k&1) {
        int j=i^(1<<k), x, y;
        x=add(a[i], a[j]), y=dis(a[i], a[j]);
        if(o==-1) x=(ll)x*inv2%MOD, y=(ll)y*inv2%MOD;
        a[i]=x, a[j]=y;
    }
}
inline void doit(int *a, int *b,int len) {
    FWT(a, len, 1); FWT(b, len, 1);
    for(int i=0;i<len;++i) a[i]=(ll)a[i]*b[i]%MOD;
    FWT(a, len, -1);
}
int a[MAXN], b[MAXN];

int n, p[MAXN], hb[MAXN], e[MAXN], sumit[16];
void solve(int l, int r) {
    if(l==r){
        int s=0;
        for(int i=0, x=~l; 1<<i<n; ++i)
            if( x>>i &1 )
                s=add(s, sumit[i]);
        e[l]=add(e[l], 1)*fpow(s, MOD-2)%MOD;
        return ;
    }
    int m=l+r>>1, len=r-l+1;
    solve(m+1, r);
    memcpy(a, e+l, len*sizeof(e[0]));
    memcpy(b, p, len*sizeof(p[0]));
    memset(a, 0, len*sizeof(a[0])>>1);
    doit(a, b, len);
    for(int i=l, j=0;i<=m; ++i, ++j)
        e[i]=add(e[i], a[j]);
    solve(l, m);
}
inline int ans(){
    cin>>n;
    int tot=0;
    memset(sumit, 0, sizeof(sumit));
    for(int i=0;i<n;++i) cin>>p[i], tot=add(tot, p[i]), e[i]=0;
    p[0]=p[0]*fpow(tot, MOD-2)%MOD;
    for(int i=1, x=fpow(tot, MOD-2); i<n; ++i){
        p[i]=(ll)p[i]*x%MOD;
        sumit[ hb[i] ]=add(sumit[ hb[i] ], p[i]);
    }
    solve(0, n-1);
    return e[0];
}
int main(){
    ios::sync_with_stdio(0);
    cin.tie(0); cout.tie(0);
    for(int i=2;i<1<<16;++i) hb[i]=hb[i-1]+(i==lowbit(i));
    int T; cin>>T;
    while(T--) cout<<ans()<<"\n";
    cout.flush();
    return 0;
}
posted @ 2021-08-04 16:46  JustinRochester  阅读(126)  评论(0编辑  收藏  举报