bzoj 3027 [Ceoi2004] Sweet —— 生成函数

题目:https://www.lydsy.com/JudgeOnline/problem.php?id=3027

就是 (1+x+x2+...+xm[i]) 乘起来;

原来想和背包一样做,然而时限很短,数组也开不了很多,本来以为勉强一下也可以,后来突然发现不行...

#include<iostream>
#include<cstdio>
#include<cstring>
#include<algorithm>
using namespace std;
int const xn=1e7+5,mod=2004;
int n,a,b,s[xn],m;
int rd()
{
  int ret=0,f=1; char ch=getchar();
  while(ch<'0'||ch>'9'){if(ch=='-')f=0; ch=getchar();}
  while(ch>='0'&&ch<='9')ret=(ret<<3)+(ret<<1)+ch-'0',ch=getchar();
  return f?ret:-ret;
}
int upt(int x){while(x>=mod)x-=mod; while(x<0)x+=mod; return x;}
int main()
{
  n=rd(); a=rd(); b=rd(); m=rd();
  for(int i=0;i<=m;i++)s[i]=i+1;
  for(int i=m+1;i<=b;i++)s[i]=s[m];//
  for(int i=2;i<=n;i++)
    {
      m=rd();
      //mx=min(mx+m[i],b);
      for(int j=b;j>=0;j--)//
      if(j-m-1>=0)s[j]=upt(s[j]-s[j-m-1]);
      for(int j=1;j<=b;j++)s[j]=upt(s[j]+s[j-1]);
    }
  int ans=s[b];
  if(a)ans=upt(ans-s[a-1]);
  printf("%d\n",ans);
  return 0;
}
TLE

首先,要化简这个多项式,得到 ∏(1-xm[i]+1) / (1-x)n

可以把分子和分母分开,分母就是熟悉的 ∑ C(n+i-1,n-1)*xi

而分子一共只有 n 项,可以 2n 搜出每个系数;

然后把二者组合在一起,对于搜出的 k * xy ,对答案有贡献还需要把 xy 变成 xa ~ xb

所以对应分母多项式的 xa-y ~ xb-y 的系数,是连续的组合数求和,杨辉三角里的一列;

但是模数不是质数,所以组合数不好算;

参考TJ,竟然可以对组合数和模数都乘 n!,就可以 O(n) 直接乘得到组合数了,最后把答案除以 n! 即可;

如果把搜到的系数存下来,最后遍历,复杂度反而成了 O(bn) ... 不如直接在搜索里计算,有值才算上,复杂度 O(n*2n)。

代码如下:

#include<iostream>
#include<cstdio>
#include<cstring>
#include<algorithm>
using namespace std;
typedef long long ll;
int const xn=15,xm=1e7+5,mod=2004;
int n,a,b,m[xn];
ll fac,p,ans;
int rd()
{
  int ret=0,f=1; char ch=getchar();
  while(ch<'0'||ch>'9'){if(ch=='-')f=0; ch=getchar();}
  while(ch>='0'&&ch<='9')ret=(ret<<3)+(ret<<1)+ch-'0',ch=getchar();
  return f?ret:-ret;
}
ll upt(ll x){while(x>=mod)x-=mod; while(x<0)x+=mod; return x;}
ll C(int n,int m)
{
  if(n<m)return 0;//!
  ll ret=1;
  for(int i=n-m+1;i<=n;i++)ret=(ret*i)%p;
  return ret;
}
void dfs(int x,int s,int t)
{
  if(x==n+1)
    {
      ans+=s*(C(n+b-t,n)-C(n+a-t-1,n)); ans=ans%p;
      return;
    }
  dfs(x+1,s,t);
  dfs(x+1,-s,t+m[x]+1);
}
int main()
{
  n=rd(); a=rd(); b=rd(); fac=1;
  for(int i=1;i<=n;i++)m[i]=rd(),fac*=i;
  p=(ll)fac*mod;
  dfs(1,1,0);
  /*
  for(int y=0;y<=b;y++)
    {
      ll tmp=upt(C(n+b-y,n)-C(n+a-y-1,n));
      ans=(ans+tmp*f[y])%p;
    }
  */
  if(ans<0)ans+=p;
  printf("%lld\n",ans/fac);
  return 0;
}

 

posted @ 2018-11-27 20:50  Zinn  阅读(180)  评论(0编辑  收藏  举报