[PKUWC2018]随机游走——Min-Max容斥+DP+FWT

题面

   LOJ#2542

解析

  求到$S$集合每个点走一次的期望,即求$E(max(S))$,套上$Min-Max$容斥,即是求$E(min(T)),T\subseteq S$

  考虑对每种集合做一次$dp$,外层枚举$P \subseteq U$,$dp[u]$表示点$u$到$P$集合内任意一点的期望时间, $deg[u]$表示点$u$的度数,$fa$为$u$的父亲节点,$v$为$u$的儿子节点。

  若$u \in P$,则$dp[u] = 0$

  否则有:$dp[u] = \frac{1}{deg[u]}(dp[fa]+\sum dp[v])+1$

  然后是一个我没总结过的较常见套路,设$dp[u] = A[u] * dp[fa] + B[u]$,带入上式化简:$$deg[u]*dp[u]=dp[fa]+\sum(A[v]*dp[u]+B[v])+deg[u]$$$$(deg[u]-\sum A[v])*dp[u]=dp[fa]+(\sum B[v])+deg[u]$$$$dp[u]=\frac{1}{deg[u]-\sum A[v]}*dp[fa]+\frac{deg[u]+\sum B[v]}{deg[u]-\sum A[v]}$$

  故:$$A[u]=\frac{1}{deg[u]-\sum A[v]},\ B[u]=\frac{deg[u]+\sum B[v]}{deg[u]-\sum A[v]}$$

  对于根节点,由于其没有父节点,故$dp[u]=B[u]$,也即$B[u]$为所求。

  现在可以求出$E(min(T))$,但我们需要求出$\sum_{T\subseteq S}(-1)^{|T|+1}*E(min(T))$,可以发现其实就是求$S$的子集权值和, 可以用$FWT(or)$预处理出所有$S$的答案,每次询问可以做到$O(1)$回答。

  因$DP$过程中需要求逆元,故时间复杂度为$DP$的时间复杂度:$O(n2^n \log mod)$

 代码:

#include<cstdio>
#include<iostream>
#include<algorithm>
#include<cstring>
#include<vector>
using namespace std;
typedef long long ll;
const int maxn = (1 << 18) + 5, mod = 998244353;

ll qpow(ll x, ll y)
{
    ll ret = 1;
    while(y)
    {
        if(y&1)
            ret = ret * x % mod;
        x = x * x % mod;
        y >>= 1;
    }
    return ret;
}

ll add(ll x, ll y)
{
    return x + y < mod? x + y: x + y - mod;
}

ll rdc(ll x, ll y)
{
    return x - y < 0? x - y + mod: x - y;
}

int n, m, Q, rt, deg[20], num[maxn];
ll f[maxn], A[20], B[20];
vector<int> G[maxn];

void dfs(int x, int fa, int s)
{
    if((s >> (x - 1)) & 1)
    {
        A[x] = B[x] = 0;
        return ;
    }
    ll s1 = 0, s2 = 0;
    for(auto &id: G[x])
    {
        if(id == fa)    continue;
        dfs(id, x, s);
        s1 = add(s1, A[id]);
        s2 = add(s2, B[id]);
    }
    A[x] = qpow(rdc(deg[x], s1), mod - 2);
    B[x] = add(s2, deg[x]) * A[x] % mod;
}

void FWT(ll *x)
{
    for(int i = 1; i <= m; i <<= 1)
        for(int j = 0; j <= m; j += (i << 1))
            for(int k = 0; k < i; ++k)
                x[i+j+k] = add(x[i+j+k], x[j+k]);
}

int main()
{
    scanf("%d%d%d", &n, &Q, &rt);
    int u, v, cnt, sta;
    for(int i = 1; i < n; ++i)
    {
        scanf("%d%d", &u, &v);
        G[u].push_back(v);
        G[v].push_back(u);
        ++ deg[u];
        ++ deg[v];
    }
    m = (1 << n) - 1;
    for(int i = 1; i <= m; ++i)
    {
        dfs(rt, 0, i);
        num[i] = num[i>>1] + (i & 1);
        f[i] = ((num[i] & 1)? 1: mod - 1) * B[rt] % mod;
    }
    FWT(f);
    for(int i = 1; i <= Q; ++i)
    {
        scanf("%d", &cnt);
        sta = 0;
        for(int j = 1; j <= cnt; ++j)
        {
            scanf("%d", &u);
            sta |= (1 << (u - 1));
        }
        printf("%lld\n", f[sta]);
    }
    return 0;
}
View Code

 

posted @ 2020-03-02 19:00  Mr_Joker  阅读(134)  评论(0编辑  收藏  举报