好题。
题目分析:
首先,总方案数有 \(n^k\) 种,每种的概率都是均等的,为 \(\frac{1}{n^k}\),即 \(E=\frac{1}{n^k}\sum w\),而我们最后要输出 \(n^kE\),那么只用求 \(\sum w\),即所有方案的 \(ans\) 总和。
如果没有修改,那么所求即为 \(kn^{k-1}\sum a\),这是因为 \(n^k\) 种方案,每个方案有 \(k\) 个数,所以共有 \(kn^k\) 个数。而 \(a\) 中 \(n\) 个数的出现次数应该相等,所以每个数出现了 \(kn^{k-1}\) 次。
本题的修改很有意思,可以观察到,一个 \(a_i\) 经过若干次 \(a_i=a_i-(a_i\bmod j)(1\le j\le k)\) 的操作,所能达到的最小值是 \(a_i=a_i-(a_i\bmod M)\),其中 \(M=lcm(1,2,...,k)\)。
然而前 \(17\) 个数的 \(lcm\) 还有点大,可以观察到第 \(17\) 次的修改是无意义的,因为没有第 \(18\) 次的询问了。所以我们实际上只用求前 \(16\) 个数的 \(lcm\),这个 \(M\) 是小于 \(10^6\) 的。
换言之我们令 \(b_i=\left\lfloor \frac{a_i}{M} \right\rfloor\times M\),那么答案里肯定有一部分为 \(kn^{k-1}\sum b_i\)。
那么这个时候自然地想设计一个 \(O(Mk)\) 的 \(dp\)。符合直觉的想法是设 \(dp(i,j)\) 代表前 \(i\) 轮修改过后,所有方案里 \(j\) 的出现个数的总和。
考虑统计完 \(dp(i,j)\) 后对答案的贡献。需要注意,这里不是只有 \(dp(k,x)\) 这样的状态有用的,因为,每一轮的修改前都会有一次对答案的查询(就是 \(ans=ans+a_idx\)),所以对于每个 \(dp(i,j)(0\le i\lt k)\),都应该有贡献。一个 \(dp(i,j)\) 的贡献应该为 \(j\,\times\,dp(i,j)\,\times\,n^{k-i-1}\),这是自然的,其意义是,我第 \(i+1\) 轮,选择一个 \(j\) 就会有 \(j\) 的贡献,在所有局面里共有 \(dp(i,j)\) 种选择方式,第 \(i+2\,\sim k\) 轮的选择方式共有 \(n^{k-i-1}\) 种。
转移是本题的难点。首先 \(dp(0,j)=cnt_j\),\(cnt_j\) 是初始 \(a\) 中,元素 \(j\) 的出现次数。发现不好确定一个状态由哪些状态转移而来,因此考虑刷表:即研究一个状态 \(dp(i,j)\) 对 \(dp(i+1,j')\) 的贡献。
设某个局面有 \(k\) 个值为 \(j\) 的元素,当我选了它,那么下一轮会剩下 \((k-1)\) 个元素 \(j\);当我不选它,下一轮会剩下 \(k\) 个元素 \(j\)。所以这个局面对下一轮的 \(j\) 贡献为 \(k\,\times\,(k-1)+(n-k)\,\times\,k=k\,\times\,(n-1)\)。那么所有局面的 \(k\),加起来,就是 \(dp(i,j)\),也就是说 \(dp(i,j)\,\times\,(n-1)\rightarrow dp(i+1,j)\)。
但是考虑一个事情,当我选了一个值为 \(j\) 的元素中的一个,那么会有 \((k-1)\) 个 \(j\) 不假,但是这个元素本身还会变化,意思是给别的值 \(j'\) 的 \(dp(i+1,j')\) 造成贡献。具体地,如果第 \((k+1)\) 轮 \(j\) 这个值会变成 \(j'\),那么对于一个有 \(k\) 个值为 \(j\) 的元素的局面,我有 \(k\) 种选择,让下一轮的局面多一个 \(j'\) 出来,那么所有局面的 \(k\) 加起来,就是 \(dp(i,j)\),也就是说 \(dp(i,j)\rightarrow dp(i+1,j')\)。
我们发现这样转移是 \(O(1)\) 的,时间复杂度 \(O(Mk)\),空间复杂度可以通过滚动数组优化打到 \(O(M)\)。
#include<bits/stdc++.h>
#define rep(i,a,b) for(int i=(a);i<=(b);i++)
#define per(i,a,b) for(int i=(a);i>=(b);i--)
#define op(x) ((x&1)?x+1:x-1)
#define odd(x) (x&1)
#define even(x) (!odd(x))
#define lc(x) (x<<1)
#define rc(x) (lc(x)|1)
#define lowbit(x) (x&-x)
#define Max(a,b) (a>b?a:b)
#define Min(a,b) (a<b?a:b)
#define next Cry_For_theMoon
#define il inline
#define pb(x) push_back(x)
#define is(x) insert(x)
#define sit set<int>::iterator
#define mapit map<int,int>::iterator
#define pi pair<int,int>
#define ppi pair<int,pi>
#define pp pair<pi,pi>
#define fr first
#define se second
#define vit vector<int>::iterator
#define mp(x,y) make_pair(x,y)
typedef long long ll;
typedef unsigned long long ull;
typedef unsigned int uint;
typedef double db;
using namespace std;
const int MAXN=1e7+10,MAXM=1e6,mod=998244353;
ll n,a[MAXN],x,y,k,M,m;
ll ans,power[20],dp[2][MAXM];
ll gcd(ll a,ll b){if(!b)return a;return gcd(b,a%b);}
int main(){
cin>>n>>a[1]>>x>>y>>k>>M;
power[0]=1;rep(i,1,k)power[i]=power[i-1]*n%mod;
rep(i,2,n)a[i]=(a[i-1]*x+y)%M;
m=1;
rep(i,2,k-1)m=m/gcd(m,i)*i;
rep(i,1,n)ans=(ans+((a[i]/m)*m)*k%mod*power[k-1])%mod,a[i]%=m,dp[0][a[i]]++;
rep(i,0,k-1){
rep(j,0,m-1)ans=(ans+j*dp[i&1][j]%mod*power[k-i-1])%mod;
rep(j,0,m-1){
ll& ret=dp[(i+1)&1][j-j%(i+1)];
ret=(ret+dp[i&1][j])%mod;
ll& ret2=dp[(i+1)&1][j];
ret2=(ret2+dp[i&1][j]*(n-1))%mod;
}
memset(dp[i&1],0,sizeof dp[i&1]);
}
cout<<ans;
return 0;
}

浙公网安备 33010602011771号