I-Increasing Subsequence 期望DP+逆元

题目传送门

题意

  给定一个长度为\(n\),元素为\(1,2,\dots,n\)的某种序列,\(Alice\)\(Bob\)轮流取数,要求每次取的数都要大于先前两个人曾选过的数,并且当前选择的数要在前一位玩家选择位置的右边,当前可以选择的数都将会等概率被选择,求出能够进行游戏轮次的期望值。

思路

  由于存在两个人选择,因此我们可以使用二维数组来分别表示两个人选择的位置,有两个人,所以可以用两个二维数组来分别表示是谁在做选择,当然也可以增加一维来表示。于是我们想到可以使用动态规划的方式来计算期望值。

 暴力计算\(O(N^3)\)

  假设我们分别使用\(g[i][j],f[i][j]\)表示当前是\(Alice选\)和当前是\(Bob\)选择,那么两个状态之间的转移就取决于当前能够选择数的个数,期望的贡献就等于从另一个状态转移到当前状态的期望总和。

  参考代码
点此展开
//Author:Daneii
//O(n^3)算法,必定超时
#include <bits/stdc++.h>

using namespace std;

#define in(x) scanf("%d",&x)
#define lin(x) scanf("%lld",&x)
#define din(x) scanf("%lf",&x)

typedef long long ll;
typedef long double ld;
typedef pair<int,int> PII;

const int N=5010;
const int mod=998244353;

int n;
ll a[N];
ll f[N][N],g[N][N];
ll inv[N];

//O(N)时间内预处理出1...n中每个数的逆元
void inv_init()
{
    inv[0]=inv[1]=1;
    for(int i=2;i<=n;i++)
        inv[i]=(mod-mod/i)*inv[mod%i]%mod;
}

int main()
{
    #ifdef LOCAL
    freopen("D:/VSCodeField/C_Practice/.input/a.in", "r", stdin);
    #endif

    in(n);
    inv_init();
    for(int i=1;i<=n;i++) lin(a[i]);

    for(int i=1;i<=n;i++)
    {
        for(int j=0;j<=n;j++)
        {
            if(i==j) continue;
            if(j==0)
                g[i][0]=inv[n];
            //枚举k
            //首先是现在是第一个人来选
            int cnt=0;//统计个数
            for(int k=j+1;k<=n;k++)
            {
                if(a[k]>a[i]&&a[k]>a[j])//满足每次选择的数都比先前的要大
                    cnt++;
            }
            //第一个人选择后转移到第二个人
            for(int k=j+1;k<=n;k++)
            {
                if(a[k]>a[i]&&a[k]>a[j])
                {
                    f[i][k]=(f[i][k]+inv[cnt]*g[i][j])%mod;
                }
            }

            //第二个人选
            cnt=0;
            for(int k=i+1;k<=n;k++)
            {
                if(a[k]>a[j]&&a[k]>a[i])
                {
                    cnt++;
                }
            }
            for(int k=i+1;k<=n;k++)
            {
                if(a[k]>a[j]&&a[k]>a[i])
                {
                    g[k][j]=(g[k][j]+inv[cnt]*f[i][j])%mod;
                }
            }
        }
    }

    ll ans=0;
	//将所有情况累加起来就是结果
    for(int i=1;i<=n;i++)
        for(int j=0;j<=n;j++)
            if(i==j) continue;
            else ans=((ans+f[i][j])%mod+g[i][j])%mod;
        
    printf("%lld\n",ans);

    return 0;
}

 优化计算\(O(N^2)\)

  在暴力算法中,我们发现每次都需要计算比当前选择数大的数的个数,也就是\(k\)的个数,由于每次都需要遍历一遍数组,增加了一维循环。我们考虑是否能够将这一维优化掉。由于\(k\)的个数,只与当前的数有关系,因此我们可以预处理一下\(k\)与每个数的大小关系,从而对于个数查询实现\(O(1)\)的查询。
  我们会发现,由于\(k\)必须要大于\(i\)或者\(j\),相当于每次都是看右边的数组,因此我们可以修改数组的定义,让\(f[i][j]\)\(Alice\)上一轮选择了\(i\),\(Bob\)上一轮选择了\(j\),那么\(i\)\(j\)位置上元素的大小关系就可以反应是谁在选择,而\(k\)的个数只需要我们从后向前计算,累计比当前\(i\)大和比当前\(j\)大的个数就可以实现\(O(1)\)的查询。
  现在考虑状态转移,如果当前是\(a[i]>a[j]\),说明当前\(Bob\)选,因此状态转移就是\(f[i][j]=1+\sum_{k>j,a[k]>a[i]}f[i][k]\),如果是\(a[i]<a[j]\),说明当前是\(Alice\)选,状态转移为\(f[i][j]=1+\sum_{k>i,a[k]>a[j]}f[k][j]\).

  参考代码
点此展开
//Author:Daneii
#include <bits/stdc++.h>

using namespace std;

#define in(x) scanf("%d",&x)
#define lin(x) scanf("%lld",&x)
#define din(x) scanf("%lf",&x)

typedef long long ll;
typedef long double ld;
typedef pair<int,int> PII;

const int N=5010;
const int mod=998244353;

int n;
ll p[N];
ll cnt[N];//用于统计比p[i]大的对应cnt[j]的个数
ll sum[N];//用于计算前缀和
ll inv[N];//预处理逆元
ll f[N][N];//状态定义:Alice上轮选择了i,Bob上轮选择了j
//因此状态转移要注意p[i]与p[j]之间的大小关系

void inv_init()
{
    inv[1]=1;
    for(int i=2;i<=n;i++)
    {
        inv[i]=(mod-mod/i)*inv[mod%i]%mod;
    }
}

int main()
{
    #ifdef LOCAL
    freopen("D:/VSCodeField/C_Practice/.input/a.in", "r", stdin);
    #endif

    in(n);
    inv_init();
    for(int i=1;i<=n;i++) lin(p[i]);

    for(int i=n;i>=1;i--)
    {
        ll bcnt=0,pre=0;
        for(int j=n;j>=0;j--)//j等于0即无法进行
        {
            if(i==j) continue;//不存在这样的局面
            if(p[i]>p[j])
            {
                //说明目前状态时从Alice选
                f[i][j]=((pre*inv[bcnt])%mod+1)%mod;
                sum[j]=(sum[j]+f[i][j])%mod;//统计此时比j大的前缀和
                cnt[j]++;//记录下比j大的个数
            }
            else
            {   //此时是Bob选
                f[i][j]=(sum[j]*inv[cnt[j]]%mod+1)%mod;
                pre=(pre+f[i][j])%mod;
                bcnt++;
            }
        }
    }

    ll ans=0;
    for(int i=1;i<=n;i++)
        ans=(ans+f[i][0])%mod;
    printf("%lld",ans*inv[n]%mod);


    return 0;
}

posted @ 2021-08-06 13:57  Daneii  阅读(45)  评论(0)    收藏  举报