LOJ 2304 「NOI2017」泳池——思路+DP+常系数线性齐次递推

 

题目:https://loj.ac/problem/2304

看了各种题解……

\( dp[i][j] \) 表示有 i 列、第 j 行及以下默认合法,第 j+1 行至少有一个非法格子的概率,满足最大合法矩形面积 <= lm。其中第 j 行及以下的部分的贡献是 1 而不是 q 的几次方。

那么有 \( dp[i][j]=dp[i][j+1]*p^i + \sum\limits_{k=1}^{i}dp[k-1][j+1]*p^{k-1}*(1-p)*dp[i-k][j] \)

注意到当 i>k 的时候,最底下一行必然有至少一个位置是非法的。所以令 \(ans_i\) 表示 i 列的概率,有 \( ans_i = \sum\limits_{j=1}^{i}ans_{j-1}*(1-p)*dp[i-j][1]*p^{i-j} \)

\(ans_i\) 的初值就是 dp[i][0] 。注意 dp[0][*]=1 。然后可以用常系数线性齐次递推的知识优化。

注意清空数组。注意别把 n 的值真的改掉。

#include<cstdio>
#include<cstring>
#include<algorithm>
#define ll long long
using namespace std;
const int N=1005,M=N<<1,mod=998244353;
int upt(int x){while(x>=mod)x-=mod;while(x<0)x+=mod;return x;}
int pw(int x,int k)
{int ret=1;while(k){if(k&1)ret=(ll)ret*x%mod;x=(ll)x*x%mod;k>>=1;}return ret;}

int n,q,q2,f[N][N],bin[N],a[N],ans[M],b[M],c[M],lm;
void Mul(int *u,int *v)//(lm-1)'
{
  memset(c,0,sizeof c);
  for(int i=0;i<lm;i++)
    for(int j=0;j<lm;j++)
      c[i+j]=(c[i+j]+(ll)u[i]*v[j])%mod;

  for(int i=2*(lm-1);i>=lm;i--)
    if(c[i])
      for(int j=1;j<=lm;j++)
    c[i-j]=(c[i-j]+(ll)c[i]*a[j])%mod;
  memcpy(u,c,sizeof 4*lm);//0~lm-1
}
int solve(int tmp)
{
  lm=tmp; memset(f,0,sizeof f);
  for(int j=0;j<=lm+1;j++)f[0][j]=1;//lm+1 not lm!!!
  for(int i=1;i<=lm;i++)
    for(int j=lm/i;j>=0;j--)
      {
    int tp=(ll)f[i][j+1]*bin[i]%mod;
    for(int k=1;k<=i;k++)
      {
        int ml=(ll)f[k-1][j+1]*f[i-k][j]%mod;
        ml=(ll)ml*q2%mod*bin[k-1]%mod;
        tp=upt(tp+ml);
      }
    f[i][j]=tp;
      }
  if(n<=lm)return f[n][0]; lm++;

  for(int i=1;i<=lm;i++)
    {
      int tp=(ll)f[i-1][1]*bin[i-1]%mod;
      a[i]=(ll)tp*q2%mod;//not lm-i
    }
  memset(ans,0,sizeof ans);////
  memset(b,0,sizeof b);////
  ans[0]=b[1]=1; int tn=n;//////
  while(tn)
    {
      if(tn&1)Mul(ans,b); Mul(b,b); tn>>=1;
    }
  int ret=0;
  for(int i=0;i<lm;i++)
    ret=(ret+(ll)ans[i]*f[i][0])%mod;
  return ret;
}
int main()
{
  int x,y,k;scanf("%d%d%d%d",&n,&k,&x,&y);
  q=(ll)x*pw(y,mod-2)%mod; q2=upt(1-q);
  bin[0]=1;
  for(int i=1;i<=k;i++)bin[i]=(ll)bin[i-1]*q%mod;
  printf("%d\n",upt(solve(k)-solve(k-1)));
  return 0;
}

 

posted on 2019-06-09 12:03  Narh  阅读(146)  评论(0编辑  收藏  举报

导航