loj #6485. LJJ 学二项式定理 (模板qwq)

$ \color{#0066ff}{ 题目描述 }$

LJJ 学完了二项式定理,发现这太简单了,于是他将二项式定理等号右边的式子修改了一下,代入了一定的值,并算出了答案。

但人口算毕竟会失误,他请来了你,让你求出这个答案来验证一下。

一共有 \(T\) 组数据,每组数据如下:

输入以下变量的值:\(n, s , a_0 , a_1 , a_2 , a_3\),求以下式子的值:

\(\begin{aligned}\Large \left[ \sum_{i=0}^n \left( {n\choose i} \cdot s^{i} \cdot a_{i\bmod 4} \right) \right] \bmod 998244353\end{aligned}\)

其中 \(n\choose i\) 表示 \(\frac{n!}{i!(n-i)!}\)

\(\color{#0066ff}{输入格式}\)

第一行一个整数 \(T\),之后 \(T\) 行,一行六个整数 \(n, s, a_0, a_1, a_2, a_3\)

\(\color{#0066ff}{输出格式}\)

一共 \(T\) 行,每行一个整数表示答案。

\(\color{#0066ff}{输入样例}\)

6
1 2 3 4 5 6
2 3 4 5 6 1
3 4 5 6 1 2
4 5 6 1 2 3
5 6 1 2 3 4
6 1 2 3 4 5

\(\color{#0066ff}{输出样例}\)

11
88
253
5576
31813
232

\(\color{#0066ff}{数据范围与提示}\)

对于 \(50\%\) 的数据,\(T \times n \leq 10^5\)

对于 \(100\%\) 的数据,\(1 \leq T \leq 10^5, 1 \leq n \leq 10 ^ {18}, 1 \leq s, a_0, a_1, a_2, a_3 \leq 10^{8}\)

\(\color{#0066ff}{题解}\)

一个有关n次单位根的公式

\[[n|k]=\frac 1 n \sum_{i=0}^{n-1}\omega_n^{ki} \]

就不证明了不会

因此,有

\[\sum_{i=0}^na_i[N|i]=\frac{\sum_{i=0}^{n}a_i\sum_{j=0}^{N-1}\omega_N^{ij}}{N}=\frac{\sum_{i=0}^{N-1}f(\omega_N^i)}{N} \]

其中\(f\)是数列a的生成函数

对于本题,我们考虑把4种情况分开处理,即

\[\sum_{i=0}^3a_i\sum_{j=0}^n[j\bmod4 = i]C_n^i*s^i \]

构造生成函数

\[\sum_{i=0}^3a_i\sum_{j=0}^nC_n^i*s^i * x^i*1^{n-i} \]

\[(sx+1)^n \]

但是,对于\(i\in[1,3]\)怎么处理呢?

考虑平移,把多项式整体乘上一个自变量,便是向右平移了一次

因此,只需变为\(\frac {f(\omega_N^j)}{\omega_N^{ij\bmod 4}}\)即可

#include<bits/stdc++.h>
#define LL long long
LL in() {
	char ch; LL x = 0, f = 1;
	while(!isdigit(ch = getchar()))(ch == '-') && (f = -f);
	for(x = ch ^ 48; isdigit(ch = getchar()); x = (x << 1) + (x << 3) + (ch ^ 48));
	return x * f;
}
const int mod = 998244353;
LL w[4], a[4], n, s, c[4];
LL ksm(LL x, LL y) {
	LL re = 1LL;
	while(y) {
		if(y & 1) re = re * x % mod;
		x = x * x % mod;
		y >>= 1;
	}
	return re;
}
int main() {
	w[0] = 1;
	LL g = ksm(3, (mod - 1) / 4);
	for(int i = 1; i <= 3; i++) w[i] = w[i - 1] * g % mod;
	for(int T = in(); T --> 0;) {
		n = in(), s = in(), a[0] = in(), a[1] = in(), a[2] = in(), a[3] = in();
		LL ans = 0;
		for(int i = 0; i < 4; i++) {
			c[i] = 0;
			for(int j = 0; j < 4; j++) (c[i] += ksm((s * w[j] + 1) % mod, n) * ksm(w[i * j % 4], mod - 2) % mod) %= mod;
			(c[i] *= ksm(4, mod - 2)) %= mod;
			(ans += a[i] * c[i] % mod) %= mod;
		}
		printf("%lld\n", ans);
	}
	return 0;
}
posted @ 2019-02-26 16:55  olinr  阅读(440)  评论(0编辑  收藏  举报