洛谷 P5548 - [BJ United Round #3] 押韵(单位根反演+幂级数)

洛谷题面传送门

毒瘤 EI 题,好不容易看懂了题解就写篇题解纪念下吧(

首先注意到可能的 \(d\) 个数很少,因此考虑对 \(d\) 进行分类讨论。

\(d=1\) 的情况没什么好说的,直接输出 \(k^n\) 即可。

对于 \(d\ge 2\) 的情况,显然有

\[ans=n![x^n](\sum\limits_{d\mid i}\dfrac{x^i}{i!})^k \]

注意到这里涉及对 \(i\bmod d=0\)\(i\) 计算贡献,不难想到单位根反演,调用单位根反演有

\[\begin{aligned} &\sum\limits_{i}\dfrac{x^i}{i!}[i\bmod d = 0]\\ =&\sum\limits_{i}\dfrac{x^i}{i!}\sum\limits_{j=0}^{d-1}\omega_{d}^{ij}\\ =&\sum\limits_{j=0}^{d-1}\sum\limits_{i}\dfrac{(\omega_d^jx)^i}{i!}\\ =&\sum\limits_{j=0}^{d-1}\exp(\omega_d^jx) \end{aligned} \]

\[ans=n![x^n](\sum\limits_{j=0}^{d-1}\exp(\omega_d^jx))^k \]

接下来我们的任务就是如何计算 \(k\) 次幂中的东西。

\(d=2/3\)

注意到我们暴力将括号展开后项数是 \(k^{d-1}\)​ 级别的,因此当 \(d=2\)​ 或 \(3\)​ 时直接枚举都是可以接受的。

时间复杂度 \(k^{d-1}\log n\)

\(d=4\)

\(d\ge 4\) 时,直接枚举复杂度就不能接受了。不过注意到 \(\omega_4^2=-\omega_4^0,\omega_4^3=-\omega_4^1\),也就是说,我们枚举展开式中选了 \(x\)\(\omega_4^0\)\(y\)\(\omega_4^1\)\(z\)\(\omega_4^2\)\(k-x-y-z\)\(\omega_4^3\),那么我们得到的和式 \(x·\omega_4^0+y·\omega_4^1+z·\omega_4^2+(k-x-y-z)·\omega_4^3\),实际上只用用两个数 \(X\omega_4^0+Y\omega_4^1\) 就可以描述,这启发我们枚举最终结果对应的数 \(X\omega_4^0+Y\omega_4^1\)——显然 \(|X|,|Y|\le k\),因此枚举量是 \(k^2\) 的,于是问题转化为有多少种从每个括号里选一项的方法能够得到 \(X\omega_4^0+Y\omega_4^1\)。注意到如果我们将 \((X,Y)\) 视作平面上一个点,那么从每个括号中选 \(\omega_4^0,\omega_4^1,\omega_4^2,\omega_4^3\) 就分别对应向右、向上、向左、向下走一格,问题自然就变为从 \((0,0)\)\((X,Y)\)\(k\) 步到达的方案数,这是一个经典问题,之前模拟赛中也有所提及,将坐标轴旋转 \(45\) 度后两维就独立了,方案数直接 \(\dbinom{k}{\dfrac{k-||X|-|Y||}{2}}·\dbinom{k}{\dfrac{k+|X|+|Y|}{2}}\) 即可。时间复杂度 \(k^2\log n\)

\(d=6\)

仿照 \(d=4\) 的做法,我们还是以 \(\omega_6^0\)\(\omega_6^1\)(以下简称 \(1\)\(\omega\))为基底,那么 \(\omega_6^2,\omega_6^3,\omega_6^4,\omega_6^5\) 就分别对应 \(-1+\omega,-1,-\omega,1-\omega\)。我们还是枚举结果对应的数 \(X+Y\omega\),如果我们能够像 \(d=4\) 一样有比较简便的方式计算出从每个括号中选一个数,最终到达 \((X,Y)\) 的方案数,那么我们直接对每个 \((X,Y)\) 都算一遍,乘上 \((X+Y\omega)^{6n}\) 并求和就做完了。于是问题转化为如何对所有 \((X,Y)\),求出在六边形网格上从 \((0,0)\)\((X,Y)\) 的方案数。

由于六边形网格没有什么特别的公式,我们只好求助于幂级数。显然 \(d=6\) 的情况,每个括号里的东西对应的幂级数就是 \(x^{1}y^0+x^0y^1+x^{-1}y^0+x^0y^{-1}+x^1y^{-1}+x^{-1}y^1\),这里涉及分式,略有点棘手,因此不妨先给它乘个 \(xy\),得到 \(F=x^2y^1+x^1y^2+x^0y^1+x^1y^0+x^2y^0+x^0y^2\),这样问题转化为求 \(G=F^k\)

考虑短多项式幂的求法:先将两边同时求导——由于是二元生成函数,我们对其求偏导,可得 \(\dfrac{\partial G}{\partial y}=k\dfrac{\partial F}{\partial x}·F^{k-1}\),两边同乘 \(F\) 之后得 \(F·\dfrac{\partial G}{\partial y}=k\dfrac{\partial F}{\partial x}·G\),这样我们就可以递推了。

相较于一般的短多项式幂,此题是二元幂级数,因此推起来可能略有点困难,这里稍微讲下推法,考虑上面乘积中 \([x^ny^m]\) 项,左边 \(=(m-1)G_{n,m-1}+(m-1)G_{n-1,m-1}+mG_{n,m}+mG_{n-1,m}+(m+1)G_{n-2,m+1}+(m+1)G_{n-1,m+1}\),右边 \(=kG_{n,m}+kG_{n-2,m}+2kG_{n,m-1}+2kG_{n-1,m-1}\)。一开始我的想法是直接按 \(n\) 递增,\(m\) 递增的顺序解出 \(G_{n,m}\),不过这样会出现分母为 \(0\) 的情况,正确的姿势是按 \(m\) 递增,\(n\) 递增的顺序解出 \(G_{n-1,m+1}\)。这样即可推出全部 \(G_{n,m}\)

时间复杂度 \(k^2\log n\),略有点卡常。

#include <bits/stdc++.h>
// #include <ext/pb_ds/assoc_container.hpp>
// #include <ext/pb_ds/hash_policy.hpp>
// #include <ext/pb_ds/priority_queue.hpp>
using namespace std;
// using namespace __gnu_pbds;
#define fi first
#define se second
#define fill0(a) memset(a, 0, sizeof(a))
#define fill1(a) memset(a, -1, sizeof(a))
#define fillbig(a) memset(a, 63, sizeof(a))
#define pb push_back
#define ppb pop_back
#define mp make_pair
#define mt make_tuple
#define eprintf(...) fprintf(stderr, __VA_ARGS__)
template <typename T1, typename T2> void chkmin(T1 &x, T2 y){
	if (x > y) x = y;
}
template <typename T1, typename T2> void chkmax(T1 &x, T2 y){
	if (x < y) x = y;
}
typedef pair<int, int> pii;
typedef long long ll;
typedef unsigned int u32;
typedef unsigned long long u64;
typedef long double ld;
namespace fastio {
	#define FILE_SIZE 1 << 23
	char rbuf[FILE_SIZE], *p1 = rbuf, *p2 = rbuf, wbuf[FILE_SIZE], *p3 = wbuf;
	inline char getc() {
		return p1 == p2 && (p2 = (p1 = rbuf) + fread(rbuf, 1, FILE_SIZE, stdin), p1 == p2) ? -1: *p1++;
	}
	inline void putc(char x) {(*p3++ = x);}
	template <typename T> void read(T &x) {
		x = 0; char c = getc(); T neg = 0;
		while (!isdigit(c)) neg |= !(c ^ '-'), c = getc();
		while (isdigit(c)) x = (x << 3) + (x << 1) + (c ^ 48), c = getc();
		if (neg) x = (~x) + 1;
	}
	template <typename T> void recursive_print(T x) {
		if (!x) return;
		recursive_print (x / 10);
		putc (x % 10 ^ 48);
	}
	template <typename T> void print(T x) {
		if (!x) putc('0');
		if (x < 0) putc('-'), x = -x;
		recursive_print(x);
	}
	template <typename T> void print(T x,char c) {print(x); putc(c);}
	void readstr(char *s) {
		char c = getc();
		while (c <= 32 || c >= 127) c = getc();
		while (c > 32 && c < 127) s[0] = c, s++, c = getc();
		(*s) = 0;
	}
	void printstr(string s) {
		for (int i = 0; i < s.size(); i++) putc(s[i]);
	}
	void printstr(char *s) {
		int len = strlen(s);
		for (int i = 0; i < len; i++) putc(s[i]);
	}
	void print_final() {fwrite(wbuf, 1, p3 - wbuf, stdout);}
}
const int MOD = 1049874433;
const int MAXK = 4001;
int fac[MAXK + 5], ifac[MAXK + 5], inv[MAXK + 5];
void init_fac(int n) {
	for (int i = (fac[0] = ifac[0] = inv[0] = inv[1] = 1) + 1; i <= n; i++)
		inv[i] = 1ll * inv[MOD % i] * (MOD - MOD / i) % MOD;
	for (int i = 1; i <= n; i++) {
		fac[i] = 1ll * fac[i - 1] * i % MOD;
		ifac[i] = 1ll * ifac[i - 1] * inv[i] % MOD;
	}
}
int qpow(int x, ll e) {
	int ret = 1;
	for (; e; e >>= 1, x = 1ll * x * x % MOD)
		if (e & 1) ret = 1ll * ret * x % MOD;
	return ret;
}
int n, k, d;
namespace sub1 {void solve() {printf("%d\n", qpow(k, n)); exit(0);}}
namespace sub2 {
	void solve() {
		int res = 0;
		for (int i = 0; i <= k; i++) res = (res + 1ll * fac[k] * ifac[i] % MOD * ifac[k - i] % MOD * qpow((i - (k - i) + MOD) % MOD, n * 2)) % MOD;
		res = 1ll * res * qpow(qpow(2, MOD - 2), k) % MOD;
		printf("%d\n", res); exit(0);
	}
}
const int G = 7;
namespace sub3 {
	void solve() {
		int W1 = qpow(G, (MOD - 1) / 3), W2 = qpow(G, (MOD - 1) / 3 * 2);
		int res = 0;
		for (int i = 0; i <= k; i++) for (int j = 0; j + i <= k; j++)
			res = (res + 1ll * fac[k] * ifac[i] % MOD * ifac[j] % MOD * ifac[k - i - j] % MOD *
			qpow((1ll * i * W1 + 1ll * j * W2 + (k - i - j)) % MOD, n * 3ll)) % MOD;
		res = 1ll * res * qpow(qpow(3, MOD - 2), k) % MOD;
		printf("%d\n", res); exit(0);
	}
}
namespace sub4 {
	int binom(int n, int m) {return 1ll * fac[n] * ifac[m] % MOD * ifac[n - m] % MOD;}
	void solve() {
		int W1 = qpow(G, (MOD - 1) / 4), res = 0;
		for (int i = -k; i <= k; i++) for (int j = -k; j <= k; j++) {
			if (abs(i) + abs(j) <= k && ((i + j) & 1) == (k & 1)) {
				int way = 1ll * binom(k, (k + abs(abs(i) - abs(j))) / 2) *
				binom(k, (k + abs(i) + abs(j)) / 2) % MOD;
//				printf("! %d %d %d\n", i, j, way);
				res = (res + 1ll * qpow(((i + 1ll * j * W1) % MOD + MOD) % MOD, n * 4ll) % MOD * way) % MOD;
			}
		}
		res = 1ll * res * qpow(qpow(4, MOD - 2), k) % MOD;
		printf("%d\n", res);
	}
}
namespace sub6 {
	int a[MAXK + 5][MAXK + 5], b[4][4], c[4][4];
	int binom(int n, int m) {return 1ll * fac[n] * ifac[m] % MOD * ifac[n - m] % MOD;}
	void solve() {
		for (int i = -1; i <= 1; i++) for (int j = -1; j <= 1; j++)
			if (i != j) b[i + 1][j + 1]++;
		for (int i = 0; i <= 2; i++) for (int j = 1; j <= 2; j++)
			c[i][j - 1] = b[i][j] * j * k;
//		for (int i = 0; i < 3; i++) for (int j = 0; j < 3; j++) printf("%d%c", c[i][j], " \n"[j == 2]);
		for (int i = 0; i <= k; i++) a[i][k - i] = binom(k, i);
		for (int i = k + 1; i <= k * 2; i++) a[i][0] = a[0][i] = binom(k, i - k);
		for (int i = 0; i < k * 2; i++) for (int j = max(k - i + 1, 2); j <= min(k * 3 - i, k * 2 + 1); j++) {
			// the i-th column, j-th row, find a[j - 1][i + 1]
			int ss = 0;
			for (int p = 0; p < 3; p++) for (int q = 0; q < 3; q++) {
				if (j - p >= 0 && i - q + 1 >= 0 && (p != 1 || q != 0) && b[p][q])
					ss = (ss - 1ll * (i - q + 1) * a[j - p][i - q + 1] * b[p][q] % MOD + MOD) % MOD;
			}
			for (int p = 0; p < 3; p++) for (int q = 0; q < 3; q++)
				if (j - p >= 0 && i - q >= 0 && c[p][q])
					ss = (ss + 1ll * c[p][q] * a[j - p][i - q]) % MOD;
			a[j - 1][i + 1] = 1ll * ss * inv[i + 1] % MOD;
		}
//		for (int i = 0; i <= k * 2; i++) for (int j = 0; j <= k * 2; j++)
//			printf("%d%c", a[i][j], " \n"[j == k * 2]);
		int W1 = qpow(G, (MOD - 1) / 6), res = 0;
		for (int i = 0; i <= k * 2; i++) for (int j = 0; j <= k * 2; j++)
			res = (res + 1ll * a[i][j] * qpow(((i - k + 1ll * W1 * (j - k)) % MOD + MOD) % MOD, n * 6ll)) % MOD;
		res = 1ll * res * qpow(qpow(6, MOD - 2), k) % MOD;
		printf("%d\n", res);
	}
}
int main() {
	scanf("%d%d%d", &n, &k, &d); init_fac(MAXK);
	if (d == 1) sub1 :: solve();
	if (d == 2) sub2 :: solve();
	if (d == 3) sub3 :: solve();
	if (d == 4) sub4 :: solve();
	if (d == 6) sub6 :: solve();
	return 0;
}
posted @ 2022-02-10 00:49  tzc_wk  阅读(97)  评论(0)    收藏  举报