傅里叶和矩阵

傅里叶和矩阵

这是一道有意思的题目,给我指明了多项式的考试方向?

题目描述

一个 \(m\) 行 \(n\) 列的矩阵,第 \(i\) 列有一个权值 \(a_i\)。从 \((1,1)\) 出发,每步向下或向右走一格,最终走到 \((k,n)\);一条路径的权值是它经过的每个格子 \((x,y)\) 对应的 \(a_y\) 之积,求所有路径的权值之和。共有 \(q\) 次询问,每次给定一个 \(k\),求对应的答案。答案对 \(998244353\) 取模。

其中 \(1\le k\le m\le 10^{18}\),\(qn\le 10^6\),\(n\le 10^5\)。

题目解答

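路径在每一列至少经过一个格子,除此之外还多出 \(k-1\) 个格子,由这 \(n\) 列来分:若第 \(i\) 列多分到 \(c_i\) 个格子,它就贡献 \(a_i^{c_i+1}\),对应的生成函数是 \(\sum_{c\ge 0}a_i^{c+1}x^c=\frac{a_i}{1-a_ix}\)。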

求的是

\[[x^{k-1}]\prod _{i=1}^n\frac{a_i}{1-a_ix} \]

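对原式做部分分式分解。注意到 \(\frac{a_i}{1-a_ix}=-\frac{1}{x-a_i^{-1}}\),所以原式等于 \([x^{k-1}](-1)^n\prod_{i=1}^n\frac{1}{x-a_i^{-1}}\),把这个乘积拆成部分分式: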

\[[x^{k-1}] (-1)^n \sum_{i=1}^n \frac{K_i}{x-a_i^{-1}} \]

考虑 \(\frac{1}{x-a_i^{-1}}\) 展开后是什么。因为 \(\frac{1}{1-ax}=\sum _{j\ge 0} (ax) ^ j\),所以它应该是

\[\frac{1}{x-a_i^{-1}}=-a_i\frac{1}{1-a_ix}=-\sum_{j\ge 0} a_i ^{j+1} x^j \]

取 \(j=k-1\) 的系数,原式就是 \((-1)^{n+1}\sum_{i=1}^n a_i ^ k K_i\)。

我们考虑如何求出来 \(K_i\),我们考虑式子 \(\sum _{i=1}^n \frac{K_i}{x-r_i}=\prod _{i=1}^ n \frac{1}{x-r_i}\) ,两边同时乘上 \(\prod_{i=1}^ n (x-r_i)\),得到 \(\sum _{i=1} ^ n K_i\prod _{j\not = i} (x-r_j)=1\)

这个式子很像拉格朗日插值里的插值多项式。定义 \(g(x)=\prod_{i=1}^ n (x-r_i)\),在上式中代入 \(x=r_i\),除第 \(i\) 项以外的各项都含因子 \((x-r_i)\) 而变成 \(0\),因此 \(K_i\) 就是 \(\frac{g(x)}{x-r_i}\) 在 \(r_i\) 处的点值的逆。

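这个点值不好直接算。由于 \(g(r_i)=0\),用洛必达法则(或导数的定义)可得 \(\lim_{x\to r_i}\frac{g(x)}{x-r_i}=\lim_{x\to r_i}\frac{g(x)-g(r_i)}{x-r_i}=g'(r_i)\),所以 \(K_i=\frac{1}{g'(r_i)}\)。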

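于是先分治乘出 \(g(x)\),求导后对 \(g'(x)\) 在所有 \(r_i\) 处做多点求值,即可得到全部 \(K_i\);每次询问再用快速幂在 \(O(n\log p)\)(\(p=998244353\))内算出 \(\sum a_i^kK_i\)。(代码中当 \(m\le 10^5\) 时,改为直接用多项式求逆求出 \(\frac{1}{g(x)}\) 的前 \(m\) 项来回答询问。)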

AC代码

#include <bits/stdc++.h>
using std::cin;
using std::cout;
using std::vector;
using std::copy;
using std::reverse;
using std::sort;
using std::get;
using std::unique;
using std::swap;
using std::array;
using std::cerr;
using std::function;
using std::map;
using std::set;
using std::pair;
using std::mt19937;
using std::make_pair;
using std::tuple;
using std::make_tuple;
using std::uniform_int_distribution;
using ll = long long;
namespace qwq {
	// Shared pseudo-random engine (default-seeded until init() is called).
	std::mt19937 eng;

	// Reseed the shared engine with `Seed`.
	void init(int Seed) {
		eng.seed(Seed);
	}

	// Uniformly distributed integer in the closed range [l, r], drawn from `eng`.
	int rnd(int l = 1, int r = 1000000000) {
		std::uniform_int_distribution<int> dist(l, r);
		return dist(eng);
	}
}
// Smaller of x and y, returned by value; y is returned when they compare equal.
template <typename T>
inline T min(const T &x, const T &y) {
	if (x < y) return x;
	return y;
}
// Larger of x and y (compared with operator>), returned by value; y on ties.
template<typename T>
inline T max(const T &x, const T &y) {
	if (x > y) return x;
	return y;
}
// Read the next (optionally negative) decimal integer from stdin into x.
// Non-digit characters before the number are skipped; a '-' counts as a sign only
// when it immediately precedes the digits.  If EOF is reached before any digit,
// x is left as 0 instead of looping forever.
template <typename T>
inline void read(T &x) {
	x = 0;
	bool f = 0;
	// int, not char: keeps EOF distinguishable and keeps isdigit()'s argument in
	// the range it is defined for (unsigned char values or EOF).
	int ch = getchar();
	while (ch != EOF && !isdigit(ch)) {
		f = ch == '-';
		ch = getchar();
	}
	while (ch != EOF && isdigit(ch)) {
		// Subtract '0' first so the intermediate sum cannot overflow before it.
		x = x * 10 + (ch - '0');
		ch = getchar();
	}
	if (f) x = -x;
}
// Read several integers from stdin, filling the arguments from left to right with
// the single-argument read() overload.
template<typename T, typename ...Arg>
inline void read(T &x, Arg &... y) {
	read(x);
	// Braced-init-list elements are evaluated strictly left to right.
	int expand[] = {0, (read(y), 0)...};
	(void) expand;
}
// Debug helper: prints an expression's source text and its value to stderr.
#define O(x) cerr << #x << " : " << x << '\n'
// pi, computed as arccos(-1).
const double Pi = std::acos(-1.0);
// Capacity (2^18) of the NTT work buffers and subproduct-tree arrays.
const int MAXN = 1 << 18;
// NTT-friendly prime 998244353 = 119 * 2^23 + 1.
const int MOD = 998244353;
// Multiplicative inverse of 2 modulo MOD.
const int inv2 = (MOD + 1) / 2;
// Large sentinels whose bytes are all 0x3f.
const int I32_INF = 0x3f3f3f3f;
const long long I64_INF = 0x3f3f3f3f3f3f3f3f;
// Modular exponentiation x^y mod MOD by binary exponentiation.
// A negative exponent is first mapped into (0, MOD - 1] using Fermat's little
// theorem, so Ksm(x, -1) is the inverse of x when x is non-zero modulo MOD.
auto Ksm = [] (int x, int y) -> int {
	if (y < 0) y = y % (MOD - 1) + (MOD - 1);
	long long base = x;
	long long result = 1;
	while (y) {
		if (y & 1) result = result * base % MOD;
		base = base * base % MOD;
		y /= 2;
	}
	return (int) result;
};
// Apply at most one correction of +/- MOD, bringing a value that lies in
// [-MOD, 2 * MOD) into the canonical range [0, MOD).
auto Mod = [] (int x) -> int {
	if (x < 0) return x + MOD;
	if (x >= MOD) return x - MOD;
	return x;
};
// Left child of node k in a 1-based implicit binary tree (children 2k and 2k + 1).
inline int ls(int k) { return k * 2; }
// Right child of node k in the same tree.
inline int rs(int k) { return k * 2 + 1; }
// Factorial / inverse-factorial tables of size MAXN.
// NOTE(review): nothing in this file fills or reads them; they look like leftover
// template code and could probably be removed.
int fac[MAXN], ifac[MAXN];
namespace POLY {
// Tables shared by every NTT call:
//   SZ  - length the R / W tables are currently built for (0 until the first INIT);
//   R   - bit-reversal permutation for length SZ;
//   W   - twiddle factors; for a butterfly level of half-size j, W[j + k] = w_{2j}^k;
//   INV - INV[i] = i^{-1} mod MOD for 1 <= i <= MAXN, filled by POLYINIT().
int SZ, R[MAXN], W[MAXN], INV[MAXN + 1];
// Fill INV[1..MAXN] in linear time via inv(i) = -(MOD / i) * inv(MOD % i) (mod MOD).
// Must run before any inverse NTT or Int(), both of which read INV.
void POLYINIT() {
	INV[1] = 1;
	for (int i = 2; i <= MAXN; ++i) INV[i] = (long long) (MOD - MOD / i) * INV[MOD % i] % MOD;
}
// Build R and W for a transform of length `len` (a power of two, at most MAXN).
// The tables are cached (SZ), so repeated calls with the same length are free.
void INIT(int len) {
	if (SZ == len) return;
	SZ = len;
	for (int i = 1; i < len; ++i) R[i] = (R[i >> 1] >> 1) | (i & 1 ? (len >> 1) : 0);
	// 3 is a primitive root of 998244353, so wn is a primitive len-th root of unity.
	int wn = Ksm(3, (MOD - 1) / len);
	// Top level (half-size len / 2) stores wn^0, wn^1, ... in W[len/2 .. len-1].
	W[len >> 1] = 1;
	for (int i = (len >> 1) + 1; i < len; ++i) W[i] = (long long) W[i - 1] * wn % MOD;
	// Each smaller level reuses every other twiddle of the level above it.
	for (int i = (len >> 1) - 1; i > 0; --i) W[i] = W[i << 1];
}
// In-place NTT of F[0 .. limit): type 1 is the forward transform, -1 the inverse
// (including the 1 / limit scaling).  INIT(limit) must have been called, and F must
// hold at least `limit` values in [0, MOD).
void Ntt(vector<int>& F, int limit, int type) {
	// Butterflies work on unsigned 64-bit values with lazy reduction: each level adds
	// less than MOD to an entry, so after t levels entries are below (t + 1) * MOD,
	// and c * W stays below 2^64 as long as limit <= 2^18 (= MAXN).
	static unsigned long long c[MAXN];
	copy(F.begin(), F.begin() + limit, c);
	for (int i = 1; i < limit; ++i) if (i < R[i]) swap(c[i], c[R[i]]);
	for (int o = 2, j = 1; o <= limit; o <<= 1, j <<= 1) {
		for (int i = 0; i < limit; i += o) {
			for (int k = 0; k < j; ++k) {
				unsigned long long OI = c[i + j + k] * W[k + j] % MOD;
				c[i + j + k] = c[i + k] + MOD - OI;
				c[i + k] += OI;
			}
		}
	}
	if (type == -1) {
		// Inverse transform = forward transform with entries 1..limit-1 reversed, / limit.
		reverse(c + 1, c + limit);
		int inv = INV[limit];
		for (int i = 0; i < limit; ++i) c[i] = c[i] % MOD * inv % MOD;
	}
	for (int i = 0; i < limit; ++i) F[i] = c[i] % MOD;
}
// --- Cipolla's algorithm for modular square roots ---
// An element of F_MOD[sqrt(w)] is stored as (a, b), meaning a + b * sqrt(w); `w` is
// the quadratic non-residue chosen by Sqrt() and is read by operator * below.
int w;
typedef std::pair<int, int> complex;
complex operator + (complex &x, complex &y) {
	return std::make_pair((x.first + y.first) % MOD, (x.second + y.second) % MOD);
}
complex operator - (complex &x, complex &y) {
	return std::make_pair((x.first - y.first + MOD) % MOD, (x.second - y.second + MOD) % MOD);
}
// (a + b s)(c + d s) = (ac + w bd) + (ad + bc) s, where s^2 = w.
complex operator * (complex &x, complex &y) {
	return std::make_pair(((long long) x.first * y.first + (long long) w * x.second % MOD * y.second) % MOD, ((long long) x.first * y.second + (long long) y.first * x.second) % MOD);
}
// x^y in F_MOD[sqrt(w)] by binary exponentiation.
complex ksm(complex x, int y) {
	complex res(1, 0);
	for (; y; y /= 2, x = x * x) if (y & 1) res = res * x;
	return res;
}
// Euler's criterion: true iff x^((MOD - 1) / 2) == 1, i.e. x is a non-zero residue.
bool check(int x) { return Ksm(x, (MOD - 1) / 2) == 1; }
// The smaller of the two square roots of x modulo MOD (0 for x == 0).
// Assumes x is a quadratic residue.  Draws random a via qwq::rnd() until a^2 - x is
// a non-residue, then returns (a + sqrt(a^2 - x))^((MOD + 1) / 2).
int Sqrt(int x) {
	if (!x) return x;
	else {
		int a = qwq::rnd() % MOD;
		while (check(((long long) a * a + MOD - x) % MOD)) a = qwq::rnd() % MOD;
		w = ((long long) a * a + MOD - x) % MOD;
		complex b(a, 1);
		int ans1(ksm(b, (MOD + 1) / 2).first);
		return min(ans1, MOD - ans1);
	}
}
// Dense polynomial over Z / MOD: v[i] is the coefficient of x^i.
struct Poly {
	vector<int> v;
	int& operator [] (const int &pos) { return v[pos]; }
	int len() { return v.size(); }
	void set(int l) { return v.resize(l); }
	// Drop high-order zero coefficients, always keeping at least one coefficient.
	void adjust() { while (v.size() > 1 && !v.back()) v.pop_back(); }
	void rev() { reverse(v.begin(), v.end()); }
	// Resize to 2^L coefficients (truncating or zero-padding) and transform in place.
	void Ntt(int L, int type) {
		int limit = 1 << L;
		INIT(limit);
		set(limit);
		POLY::Ntt(v, limit, type);
	}
	// Replace *this by its square.
	void Squ() {
		int L = ceil(log2(len())) + 1, limit = 1 << L;
		Ntt(L, 1);
		for (int i = 0; i < limit; ++i) v[i] = (long long) v[i] * v[i] % MOD;
		Ntt(L, -1);
		adjust();
	}
	// In-place coefficient-wise sum; the result is trimmed with adjust().
	void operator += (Poly &x) {
		if (len() < x.len()) set(x.len());
		for (int i = 0; i < x.len(); ++i)
			v[i] = Mod(v[i] + x[i]);
		adjust();
	}
	// In-place coefficient-wise difference; the result is trimmed with adjust().
	void operator -= (Poly &x) {
		if (len() < x.len()) set(x.len());
		for (int i = 0; i < x.len(); ++i) v[i] = Mod(v[i] - x[i]);
		adjust();
	}
	// Product *this * x.  Both operands are transformed in place and then restored,
	// so neither is changed on return.
	Poly operator * (Poly &x) {
		// Self-multiplication would transform the shared storage twice; use a copy.
		if (this == &x) {
			Poly other = x;
			return *this * other;
		}
		Poly ret, tmp0 = *this, tmp1 = x;
		int L = ceil(log2(tmp0.len() + tmp1.len() - 1)), n = 1 << L;
		Ntt(L, 1);
		x.Ntt(L, 1);
		ret.set(n);
		for (int i = 0; i < n; ++i) ret[i] = (long long) x[i] * v[i] % MOD;
		ret.Ntt(L, -1);
		ret.adjust();
		*this = tmp0;
		x = tmp1;
		return ret;
	}
	// Coefficient-wise difference with max(len(), x.len()) terms (not trimmed).
	Poly operator - (Poly &x) {
		Poly ret;
		ret.set(max(len(), x.len()));
		for (int i = 0; i < len(); ++i) ret[i] = v[i];
		for (int i = 0; i < x.len(); ++i) ret[i] = Mod(ret[i] - x[i]);
		return ret;
	}
	// Coefficient-wise sum with max(len(), x.len()) terms (not trimmed).
	Poly operator + (Poly &x) {
		Poly ret;
		ret.set(max(len(), x.len()));
		for (int i = 0; i < len(); ++i) ret[i] = v[i];
		for (int i = 0; i < x.len(); ++i) ret[i] = Mod(ret[i] + x[i]);
		return ret;
	}
	// *this = *this * x (trimmed with adjust()); x is restored afterwards.
	void operator *= (Poly &x) {
		// Self-multiplication would transform the shared storage twice; use a copy.
		if (this == &x) {
			Poly other = x;
			*this *= other;
			return;
		}
		Poly tmp = x;
		int L = ceil(log2(len() + x.len() - 1)), n = 1 << L;
		Ntt(L, 1);
		x.Ntt(L, 1);
		for (int i = 0; i < n; ++i) v[i] = (long long) v[i] * x[i] % MOD;
		Ntt(L, -1);
		adjust();
		x = tmp;
	}
	// First `deg` coefficients of 1 / *this (needs v[0] != 0), by Newton iteration
	// g <- g * (2 - f * g).  A negative deg (the default) means deg = len().
	Poly GetInv(int deg = -1) {
		if (deg < 0) deg = len();
		if (deg == 0) return Poly();
		if (deg == 1) return {{Ksm(v[0], MOD - 2)}};
		Poly ret = GetInv((deg + 1) / 2), tmp;
		int L = ceil(log2(deg)) + 1, n = 1 << L, mx = min(len(), deg);
		tmp.set(deg);
		for (int i = 0; i < mx; ++i) tmp[i] = v[i];
		tmp.Ntt(L, 1);
		ret.Ntt(L, 1);
		for (int i = 0; i < n; ++i) ret[i] = (2 - (long long) tmp[i] * ret[i] % MOD + MOD) * ret[i] % MOD;
		ret.Ntt(L, -1);
		ret.set(deg);
		return ret;
	}
	// Division with remainder by x, using the reversal trick
	// rev(Q) = rev(*this) * rev(x)^{-1} mod t^(len() - x.len() + 1).
	// Returns {quotient, remainder}; the remainder has x.len() - 1 terms.  When x is
	// longer than *this, the quotient is {0} and the remainder is *this unchanged.
	pair<Poly, Poly> operator % (Poly &x) {
		if (x.len() > len()) return {{{0}}, *this};
		Poly tmp0 = *this, tmp1 = x;
		tmp0.rev();
		tmp1.rev();
		tmp1 = tmp1.GetInv(len() - x.len() + 1);
		tmp0 *= tmp1;
		tmp0.set(len() - x.len() + 1);
		tmp0.rev();
		tmp1 = tmp0 * x;
		Poly ret = *this - tmp1;
		ret.set(x.len() - 1);
		return {tmp0, ret};
	}
	// Evaluate *this at every x[i] (multipoint evaluation).  It builds a subproduct
	// tree of (X - x[i]) in a function-local static array, then pushes remainders
	// down the tree.  Results are returned in the order of x; x must be non-empty.
	vector<int> getmulpointvalue(vector<int> &x) {
		static Poly tmp[MAXN * 4];
		function<void(int, int, int)> get = [&] (int u, int l, int r) -> void {
			if (l == r) {
				tmp[u] = {{Mod(-x[l]), 1}};
				return;
			}
			int mid = (l + r) / 2;
			get(ls(u), l, mid);
			get(rs(u), mid + 1, r);
			tmp[u] = tmp[ls(u)] * tmp[rs(u)];
		};
		get(1, 0, x.size() - 1);
		vector<int> ret(x.size());
		function<void(int, int, Poly, int)> solve = [&] (int l, int r, Poly f, int u) -> void {
			if (l == r) {
				ret[l] = f[0];
				return;
			}
			int mid = (l + r) / 2;
			solve(l, mid, (f % tmp[ls(u)]).second, ls(u));
			solve(mid + 1, r, (f % tmp[rs(u)]).second, rs(u));
		};
		solve(0, x.size() - 1, (*this % tmp[1]).second, 1);
		return ret;
	}
	// Derivative (at least one coefficient); deg >= 0 resizes the result to deg terms.
	Poly Dif(int deg = -1) {
		Poly tmp;
		tmp.set(max(len() - 1, 1));
		for (int i = 0; i < len() - 1; ++i) tmp[i] = v[i + 1] * (i + 1LL) % MOD;
		if (~deg) tmp.set(deg);
		return tmp;
	}
	// First `deg` coefficients of sqrt(*this) by Newton iteration
	// g <- (g + f / g) / 2; v[0] must be a quadratic residue.  A negative deg (the
	// default) means deg = len().  The result always has exactly deg coefficients.
	Poly GetSqrt(int deg = -1) {
		if (deg < 0) deg = len();
		if (deg == 0) return Poly();
		if (deg == 1) return {{POLY::Sqrt(v[0])}};
		Poly ret = GetSqrt((deg + 1) / 2), tmp0 = ret.GetInv(deg), tmp1;
		int mx = min(len(), deg);
		tmp1.set(deg);
		for (int i = 0; i < mx; ++i) tmp1[i] = v[i];
		tmp0 *= tmp1;
		tmp0.set(deg);
		ret += tmp0;
		for (auto &i: ret.v) i = (long long) i * inv2 % MOD;
		// += trims trailing zeros; pad back so callers can index [0, deg).
		ret.set(deg);
		return ret;
	}
	// Integral with zero constant term; deg >= 0 resizes the result to deg terms.
	Poly Int(int deg = -1) {
		Poly tmp;
		tmp.set(len() + 1);
		for (int i = 1; i < tmp.len(); ++i) tmp[i] = (long long) v[i - 1] * INV[i] % MOD;
		if (~deg) tmp.set(deg);
		return tmp;
	}
	// First `deg` coefficients of ln(*this) = integral(f' / f); needs v[0] == 1.
	// A negative deg (the default) means deg = len().
	Poly GetLn(int deg = -1) {
		if (deg < 0) deg = len();
		if (deg == 0) return Poly();
		Poly tmp0 = Dif(deg), tmp1 = GetInv(deg);
		tmp0 *= tmp1;
		tmp0.set(deg);
		return tmp0.Int(deg);
	}
	// First `deg` coefficients of exp(*this) by Newton iteration
	// g <- g * (1 - ln g + f); needs v[0] == 0.  A negative deg (the default) means
	// deg = len(); coefficients of *this past len() are treated as zero.
	Poly GetExp(int deg = -1) {
		if (deg < 0) deg = len();
		if (deg == 0) return Poly();
		if (deg == 1) return {{1}};
		Poly tmp0 = GetExp((deg + 1) / 2), tmp1 = tmp0.GetLn(deg);
		for (int i = 0; i < deg; ++i) tmp1[i] = Mod((i < len() ? v[i] : 0) - tmp1[i]);
		tmp1[0] = Mod(tmp1[0] + 1);
		tmp0 *= tmp1;
		tmp0.set(deg);
		return tmp0;
	}
	// First `deg` coefficients of (*this)^K computed as exp(K ln f); needs v[0] == 1.
	Poly GetKsm(int deg, int K) {
		Poly tmp0 = GetLn(deg);
		for (auto &i: tmp0.v) i = (long long) i * K % MOD;
		return tmp0.GetExp(deg);
	}
};
// Lagrange interpolation: given points (x_i, y_i) with distinct x_i, return the
// coefficients of the polynomial of degree < n through them, computed as
// sum_i y_i / g'(x_i) * g(X) / (X - x_i) with g = prod (X - x_i).  x must be non-empty.
vector<int> Interpolation(vector<pair<int, int>> &x) {
	static Poly tmp[MAXN * 4];
	vector<int> tmpx;
	for (auto &i: x) tmpx.push_back(i.first);
	function<void(int, int, int)> get = [&] (int u, int l, int r) -> void {
		if (l == r) {
			tmp[u] = {{Mod(-tmpx[l]), 1}};
			return;
		}
		int mid = (l + r) / 2;
		get(ls(u), l, mid);
		get(rs(u), mid + 1, r);
		tmp[u] = tmp[ls(u)] * tmp[rs(u)];
	};
	get(1, 0, tmpx.size() - 1);
	Poly ret = tmp[1];
	ret = ret.Dif();
	vector<int> tmpy = ret.getmulpointvalue(tmpx);
	for (int i = 0; i < (int) x.size(); ++i) tmpy[i] = (long long) Ksm(tmpy[i], MOD - 2) * x[i].second % MOD;
	function<Poly(int, int, int)> solve = [&] (int u, int l, int r) -> Poly {
		if (l == r) return {{tmpy[l]}};
		int mid = (l + r) / 2;
		Poly tmp0 = solve(ls(u), l, mid) * tmp[rs(u)];
		Poly tmp1 = solve(rs(u), mid + 1, r) * tmp[ls(u)];
		return tmp0 + tmp1;
	};
	return solve(1, 0, x.size() - 1).v;
}
}
using namespace POLY;
// Problem input: N columns, Q queries, M rows (up to 1e18, hence long long).
int N;
int Q;
long long M;
// Return prod_i (x - f[i]), multiplied by balanced divide and conquer so each
// multiplication sees operands of similar size.  f must be non-empty.
Poly solve(vector<int> &f) {
	function<Poly(int, int)> build = [&f, &build] (int lo, int hi) -> Poly {
		if (lo == hi) {
			Poly leaf = {{Mod(-f[lo]), 1}};
			return leaf;
		}
		const int mid = (lo + hi) / 2;
		Poly left = build(lo, mid);
		Poly right = build(mid + 1, hi);
		return left * right;
	};
	const int last = (int) f.size() - 1;
	return build(0, last);
}
int main() {
	// freopen("C.in", "r", stdin);
	qwq::init(20050112);
	POLYINIT();
	// cout << (-3 / 2);
	read(M, N, Q);
	vector<int> r(N), A(N);
	for (int i = 0; i < N; ++i) {
		read(A[i]);
		r[i] = Ksm(A[i], MOD - 2);
	}
	Poly f = solve(r);
	if (M <= 100000) {
		f = f.GetInv(M);
		for (int i = 1, k; i <= Q; ++i) {
			read(k);
			printf("%d\n",  Mod(N & 1 ? -f[k - 1] : f[k - 1]));
		}
		return 0;
	}
	f = f.Dif();
	vector<int> cyc = f.getmulpointvalue(r);
	for (auto &i: cyc) i = Ksm(i, MOD - 2);
	long long k;
	for (int i = 1; i <= Q; ++i) {
		read(k);
		k %= MOD - 1;
		int ans = 0;
		for (int j = 0; j < N; ++j) ans = (ans + (ll) cyc[j] * Ksm(A[j], k)) % MOD;
		printf("%d\n", Mod(N & 1 ? ans : -ans));
	}
	cerr << ((double) clock() / CLOCKS_PER_SEC) << '\n';
	return (0-0);
}
posted @ 2022-03-26 23:16  siriehn_nx  阅读(100)  评论(0)    收藏  举报