do_while_true

一言(ヒトコト)

转置原理

去年尝试理解过,今年再看才学懂了一点。基本抄的 EI & qwaszx 的课件。

简介

转置原理给出的是,通过 \(\mathbf{b}=A\mathbf{a}\) 的算法来解决 \(\mathbf{\hat b}=A^{T}\mathbf{\hat a}\),这里 \(\mathbf{a}\)\(\mathbf{\hat a}\) 是输入给定的向量,而 \(A\) 是一个常量矩阵。

转置原理要求得到的 \(\mathbf{b}\) 中每个元素都是 \(\mathbf{a}\) 中元素的线性组合,换而言之,\(\mathbf{b}\) 中的元素(也包括算法过程中)不允许出现 \(\mathbf{a}\) 的非一次项(也就是出现 \(a_ia_j,{a_i}^{2}\) 这样的),这代表了这个算法是一个线性算法。

线性算法在分解中 \(A\) 完成计算,即计算 \(A_1\cdots A_m\mathbf{a}\),其中 \(A\) 均为初等矩阵(或者是比较简单的矩阵?)。那么其转置问题即为计算 \({A_m}^T\cdots {A_1}^T\mathbf{a}\),这称之为原算法的转置算法。也就是将原算法的操作逆序执行它们的转置,即得到了转置后的算法。

例子

例如在前缀和算法中,输入给定了需要作前缀和的 \(\mathbf{a}\),将其左乘一个矩阵得到前缀和后的结果 \(A\mathbf{a}=\mathbf{b}\),不难写出 \(A\) 是主对角线及以下均为 \(1\),其余位置都为 \(0\) 的矩阵,而将其转置后 \(A^T\) 是主对角线及以上均为 \(1\),其余位置均为 \(0\) 的矩阵,不难验证 \(A^T\mathbf{a}\) 即为对 \(\mathbf{a}\) 作后缀和,这意味着前缀和算法转置后得到了后缀和算法。

试着写写 \(b_i\gets b_i+cb_{j}\) 的转置:

\[\begin{pmatrix} 1 & c\\ 0 & 1 \end{pmatrix} \begin{bmatrix} b_i \\ b_j \end{bmatrix} = \begin{bmatrix} b_i+cb_j \\ b_j \end{bmatrix} \]

转置后:

\[\begin{pmatrix} 1 & 0\\ c & 1 \end{pmatrix} \begin{bmatrix} b_i \\ b_j \end{bmatrix} = \begin{bmatrix} b_i \\ cb_i+b_j \end{bmatrix} \]

所以 \(b_i\gets b_i+cb_{j}\) 转置后得到了 \(b_{j}\gets b_j+cb_i\)

常见操作的转置:

\(a_i\gets a_i+ ca_j\) \(a_j\gets a_j + ca_i\)
\(swap(a_i,a_j)\) \(swap(a_i,a_j)\)
\(a_i\gets a_j\) \(a_j\gets a_i+a_j,a_i=0\)
\(a_i\gets ca_i\) \(a_i\gets ca_i\)

FFT

我们都知道 FFT 是将 \(n\) 次单位根的若干次幂代入多项式得到点值。不难写出其对应的矩阵 \(A\)\(\omega ^{ij}\),那么就有 \(A=A^T\),这意味着将 FFT 的转置算法和 FFT 的效果一样。转置后 bitrev 依然是 bitrev,但是注意到 bitrev 会在迭代后再进行,而 IDFT 开头也有个 bitrev,那么两个 bitrev 互相抵消,可以直接省略这个 bitrev,从而减少常数。

我将自己本来就跑的不是很快的 NTT 转置后,在洛谷上测试 \(10^6\) 的多项式乘法,每个点快了 150~ 200 ms.

#include<iostream>
#include<cstdio>
#include<algorithm>
#include<vector>
#include<bitset>
#define pb emplace_back
#define mp std::make_pair
#define fi first
#define se second
using namespace std;
typedef long long ll;
typedef long double ld;
typedef unsigned long long ull;
typedef std::pair<int, int> pii;
typedef std::vector<int> vi;
const ll mod = 998244353;
ll Add(ll x, ll y) { return (x+y>=mod) ? (x+y-mod) : (x+y); }
ll Mul(ll x, ll y) { return x * y % mod; }
ll Mod(ll x) { return x < 0 ? (x + mod) : (x >= mod ? (x-mod) : x); }
ll cadd(ll &x, ll y) { return x = (x+y>=mod) ? (x+y-mod) : (x+y); }
ll cmul(ll &x, ll y) { return x = x * y % mod; }
template <typename T> T Max(T x, T y) { return x > y ? x : y; }
template<typename T, typename... T2> T Max(T x, T2 ...y) { return Max(x, y...); }
template <typename T> T Min(T x, T y) { return x < y ? x : y; }
template<typename T, typename... T2> T Min(T x, T2 ...y) { return Min(x, y...); }
template <typename T> T cmax(T &x, T y) { return x = x > y ? x : y; }
template <typename T> T cmin(T &x, T y) { return x = x < y ? x : y; }
template <typename T>
T &read(T &r) {
	r = 0; bool w = 0; char ch = getchar();
	while(ch < '0' || ch > '9') w = ch == '-' ? 1 : 0, ch = getchar();
	while(ch >= '0' && ch <= '9') r = r * 10 + (ch ^ 48), ch = getchar();
	return r = w ? -r : r;
}
template<typename T1, typename... T2>
void read(T1 &x, T2& ...y) { read(x); read(y...); }
ll qpow(ll x, ll y) {
	ll s = 1;
	while(y) {
		if(y & 1) s = s * x % mod;
		x = x * x % mod;
		y >>= 1;
	}
	return s;
}
const int N = 4000010;
ll *getw(int n, int type) {
	static ll w[N/2];
	w[0] = 1; w[1] = qpow(type == 1 ? 3 : 332748118, (mod-1)/n);
	for(int i = 2; i < n/2; ++i) w[i] = w[i-1] * w[1] % mod;
	return w;
}
void DFT(ll *a, int n) { //转置
	for(int i = n/2; i; i >>= 1) {
		ll *w = getw(i << 1, 1);
		for(int j = 0; j < n; j += i << 1) {
			ll *b = a + j, *c = b + i;
			for(int k = 0; k < i; ++k) {
				ll u = b[k], v = c[k];
				b[k] = (u + v) % mod;
				c[k] = Add((u * w[k]) % mod, Mod(-v * w[k] % mod));
			}
		}
	}
}
void IDFT(ll *a, int n) {
	for(int i = 1; i < n; i <<= 1) {
		ll *w = getw(i << 1, -1);
		for(int j = 0; j < n; j += i << 1) {
			ll *b = a + j, *c = b + i;
			for(int k = 0; k < i; ++k) {
				ll v = c[k] * w[k] % mod;
				c[k] = Add(b[k], Mod(-v));
				cadd(b[k], v);
			}
		}
	}
	ll inv = qpow(n, mod-2);
	for(int i = 0; i < n; ++i) a[i] = a[i] * inv % mod;
}
int n, m, len = 1, ct;
ll f[N], g[N];
signed main() {
	read(n); read(m);
	for(int i = 0; i <= n; ++i) read(f[i]);
	for(int i = 0; i <= m; ++i) read(g[i]);
	while(len <= n+m) len <<= 1, ++ct;
	DFT(f, len);
	DFT(g, len);
	for(int i = 0; i < len; ++i) f[i] = f[i] * g[i] % mod;
	IDFT(f, len);
	for(int i = 0; i <= n+m; ++i) printf("%lld ", f[i]);
	return 0;
}

Do Use FFT, GYM102978D

给定长为 \(N\) 的序列 \(A, B, C\),对 \(k = 1,\cdots, N\) 求出

\[\sum_{1 \leq i \leq N} \left( C_i \times \prod_{1 \leq j \leq k} (A_i+B_j) \right) \]

\(N\leq 2.5\times 10^5\),模 \(998244353\)

首先搞清楚谁是“输入”,这里只有 \(C\) 作为变量时是仅有一次项,所以将 \(C\) 看作输入,假设得到的答案为 \(q_1,q_2,\cdots ,q_N\),那么变换矩阵即为 \(M_{i,j}=\prod _{k\leq i}(A_j+B_k)\),将其转置得到:

\[C_i=\sum_j \left(q_j\prod_{k\leq j}(A_i+B_k)\right) \]

注意到当 \(i\) 改变的时候仅有 \(A_i\) 这里会改变,所以将其看作一个元 \(x\),则有:

\[C_i=F(A_i),F(x)=\sum_j \left(q_j\prod_{k\leq j}(x+B_k)\right) \]

欲求 \(F\),分治 FFT,分治到 \([l,r]\) 的时候维护 \(\sum_{l\leq j\leq r}\left(q_j\prod_{l\leq k\leq j}(x+B_k) \right)\)\(\prod_{l\leq k\leq r}(x+B_k)\) 即可,然后再对 \(F\) 多点求值即可得到答案。

考虑完转置问题的解决,来考虑原问题,那么将转置问题的算法转置过来即可。时间复杂度 \(\mathcal{O}(n\log^2 n)\)

因为不会多点求值,就不实现了。

多点求值

待学罢/ll/ll/dk

posted @ 2022-06-21 21:12  do_while_true  阅读(325)  评论(0编辑  收藏  举报