AT_abc241_h题解
题面:有一些卡牌,每张卡牌上有一个数字,具体的,有 \(b_i\) 张卡牌上的数字为 \(a_i\)。
求出拿走其中 \(m\) 张卡牌的贡献之和。贡献为这些卡牌的乘积。对于本质相同的卡牌组合,只算一次。
- \(n\leq 16,m\leq 10^{18},b_i\leq 10^{17},1\leq a_i<mod\)
\(sol\) :
首先对于每种卡牌构造生成函数:
\[\sum\limits_{j=0}^{b_i} a_i^jx^j=\frac{1-a_i^{b_i+1}x^{b_i+1}}{1-a_ix}
\]
最终答案的就是:
\[\begin{aligned}
[x^m]\prod\limits_{i=1}^n \frac{1-a_i^{b_i+1}x^{b_i+1}}{1-a_ix}\\
=[x^m]\frac{\prod\limits_{i=1}^n1-a_i^{b_i+1}x^{b_i+1}}{\prod\limits_{i=1}^n1-a_ix}
\end{aligned}
\]
这里因为 \(n\) 很小,只有 \(16\) 所以可以\(2^n\)暴力把分子部分展开。假设展开有 \(w\) 项,第 \(i\) 项为 \(c_ix^{d_i}\) ,则累加每一项的贡献得到的答案就是:
\[\begin{aligned}
\sum\limits_{i=1}^wc_i\times[x^{n-d_i}]\frac{1}{\prod\limits_{i=1}^n1-a_ix}
\end{aligned}
\]
现在问题变成了 \(2^n\) 次询问分母某一位的值。处理分母可以考虑待定系数,如果可以写成求和的形式就好做多了,于是我们令:
\[\begin{aligned}
\prod\limits_{i=1}^n\frac1{1-a_ix}=\sum\limits_{i=1}^n \frac{p_i}{1-a_ix}\\
\sum\limits_{i=1}^n \frac{p_i}{1-a_ix}\prod\limits_{j=1}^n 1-a_jx=1\\
\sum\limits_{i=1}^n p_i\prod\limits_{j=1,j\ne i}^n 1-a_jx=1\\
\end{aligned}
\]
一般来说生成函数中的 \(x\) 是没有实际意义的,但是这里我们可以考虑把它当作一个普通的函数,而普通函数有一个优势,可以代值求待定系数。所以我们代入 \(x=\frac1{a_k}\) 那么当 \(j=k\) 时,式子中 \(1-a_jx=0\) ,也就是说只有 \(i=k\) 这一项被保留了下来:
\[\begin{aligned}
p_i\prod\limits_{j=1,j\ne i}^n 1-\frac{a_j}{a_i}=1\\
p_i=\frac1{\prod\limits_{j=1,j\ne i}^n 1-\frac{a_j}{a_i}}
\end{aligned}
\]
于是我们可以求出每一项的待定系数,现在再看一下分母的形式就非常好求了:
\[\begin{aligned}
[x^k]\sum\limits_{i=1}^n \frac{p_i}{1-a_ix}\\
=\sum\limits_{i=1}^n[x^k]\frac{p_i}{1-a_ix}\\
=\sum\limits_{i=1}^np_i\times a_i^k
\end{aligned}
\]
至此,我们得到了一个单次 \(O(n)\) 查询某一位的做法。
最终复杂度:\(O(n2^n)\)
code:
#include<bits/stdc++.h>
using namespace std;
typedef long long ll;
const ll mod=998244353,N=1e6+5;
inline ll rd()
{
char c;ll f=1;
while(!isdigit(c=getchar()))if(c=='-')f=-1;
ll x=c-'0';
while(isdigit(c=getchar()))x=x*10+(c^48);
return x*f;
}
ll qp(ll x,ll y)
{
ll res=1;
while(y)
{
if(y&1)res=res*x%mod;
x=x*x%mod;
y>>=1;
}
return res;
}
ll n,m,a[20],in,b[20],x[20],p[20],nx[N],P[N],w=1;
int main()
{
n=rd(),m=rd();
for(int i=1;i<=n;i++)
a[i]=rd(),b[i]=rd(),x[i]=-qp(a[i],b[i]+1);
nx[w]=1,P[w]=0;
for(int i=1;i<=n;i++)
{
int t=w;
for(int j=1;j<=t;j++)
{
nx[++w]=nx[j]*x[i]%mod;
P[w]=P[j]+b[i]+1;
}
}
for(int i=1;i<=n;i++)
{
p[i]=1,in=qp(a[i],mod-2);
for(int j=1;j<=n;j++) if(i!=j)
p[i]=p[i]*qp(1-in*a[j]%mod+mod,mod-2)%mod;
}
ll ans=0;
for(int i=1;i<=w;i++)
{
if(P[i]>m) continue;
ll now=m-P[i];
ll res=0;
for(int j=1;j<=n;j++)
(res+=(p[j]*qp(a[j],now)%mod))%=mod;
(ans+=(nx[i]*res%mod))%=mod;
}
cout<<ans;
return 0;
}

浙公网安备 33010602011771号