题解:P7468 [NOI Online 2021 提高组] 愤怒的小 N

题意:有一个由以下方式生成的字符串:

  • 初始为 \(s=a\),每次令 \(s'\)\(s\)\(a\rightarrow b,b\rightarrow a\),然后令 \(s=s+s'\),重复无限次该操作。

然后给出一个数 \(n\) 和一个 \(k\) 项的多项式 \(f(x)\),求

\[\sum_{i=0}^{n-1} [s_i=b]f(i) \]

\(n\le 2^{5\times 10^5},k\le 500\)

做法:

首先一个朴素的想法,\(F_i(x)\) 代表 \(\sum\limits_{j=0}^{2^i-1}f(x+j)[s_{j}=a]\),类似定义 \(G\) 代表 \(b\) 的。那么很容易列出来一个倍增的转移式:

\[F_i(x) = F_{i-1}(x) + G_{i-1}(x+2^{i-1}) \]

\[G_i(x) = G_{i-1}(x) + F_{i-1}(x+2^{i-1}) \]

初值为 \(F_0(x) = f(x), G_0(x) = 0\)

答案计算就是从高位往低位考虑,第一段取 \(G\),第二段取 \(F\) 递归就可以了。

直接去做值域平移可以做到 \(O(k^2\log n)\)

考虑如何优化,这一步感觉太神秘了。考虑定义 \(H_i(x) = F_{i}(x) - G_{i}(x)\),那么可以写出来 \(H\) 的转移式:

\[H_i = H_{i-1}(x) - H_{i-1}(x+2^{i-1}) \]

发现里面出现了一个减法。这意味着什么呢,因为我值域平移是不影响最高位系数的,所以两个一减最高位就被吃掉了!所以意味着非 \(0\)\(H\) 其实只有 \(k\) 项,后面的 \(F,G\) 其实是一样的。

我们考虑用 \(H\) 带回去表示 \(F\),那么 \(F\) 就是 \(\frac{F+G+H}2\)\(G\) 就是 \(\frac{F+G-H}2\)。我们把贡献分成两部分,一部分是 \(F+G\),一部分是 \(H\)

\(H\) 很好算,就是我们上面说的暴力,复杂度 \(O(k^3)\)\(F+G\) 本质上加总展开就是 \(\sum\limits_{i=0}^{n-1}f(i)\)。注意到这个东西是个 \(k\) 项的多项式的前缀和,所以应该也是一个 \(k+1\) 项的多项式,直接暴力求 \(k+1\) 个出来插值就可以,复杂度 \(O(k^2)\)

总复杂度 \(O(k^3+n)\)

代码:

#include <bits/stdc++.h>
using namespace std;
#define int long long
const int maxn = 505, mod = 1e9 + 7, maxN = 5e5 + 5;
string s;
int n, m, C[maxn][maxn], pw[maxn], jc[maxn], revjc[maxn], pre[maxn], suf[maxn], inv[maxn];
struct Poly {
	vector<int> a;
	void resize(int N) {
		a.resize(N);
	}	
	int size() {
		return a.size();
	}
	int& operator[](int x) {
		return a[x];
	}
	void pop_back() {
		a.pop_back();
	}
	friend Poly operator+(Poly x, Poly y) {
		int n = x.size();
		for (int i = 0; i < n; i++)
			x[i] = (x[i] + y[i]) % mod;
		return x;
	}
	friend Poly operator-(Poly x, Poly y) {
		int n = x.size();
		for (int i = 0; i < n; i++)
			x[i] = (x[i] - y[i] + mod) % mod;
		while(x.size() && !x[x.size() - 1])
			x.pop_back();
		return x;
	}
	Poly shift(int k) {
		int n = size();
		Poly f; f.resize(n);
		pw[0] = 1;
		for (int i = 1; i <= n; i++)
			pw[i] = pw[i - 1] * k % mod;
		for (int i = 0; i < n; i++) 
			for (int j = 0; j <= i; j++) 
				f[j] = (f[j] + C[i][j] * a[i] % mod * pw[i - j] % mod) % mod;
		return f;
	}
	int queryp(int x) {
		int res = 0;
		for (int i = size() - 1; i >= 0; i--)
			res = (res * x + a[i]) % mod;
		return res;
	}
	int query(int x) {
		int n = size();
		if(x < n)
			return a[x];
		pre[0] = 1, suf[n] = 1;
		for (int i = 1; i < n; i++)
			pre[i] = pre[i - 1] * inv[i] % mod;
		for (int i = n - 1; i >= 1; i--)
			suf[i] = mod - suf[i + 1] * inv[(n - i)] % mod;
		int ans = 0, res = 1;
		for (int i = 1; i < n; i++) 
			res = res * (x - i) % mod, pre[i] = pre[i] * res % mod;
		res = 1;
		for (int i = n - 1; i >= 1; i--)
			res = res * (x - i) % mod, suf[i] = suf[i] * res % mod;
		for (int i = 1; i < n; i++)
			ans = (ans + a[i] * pre[i - 1] % mod * suf[i + 1] % mod) % mod;
		return ans;
	}
} f[maxN], t, lg;
int pw2[maxN];
void prepare() {
	C[0][0] = 1;
	for (int i = 1; i <= n; i++) {
		C[i][0] = 1;
		for (int j = 1; j <= i; j++)
			C[i][j] = (C[i - 1][j - 1] + C[i - 1][j]) % mod;
	}
}
signed main() {
	cin >> s >> n;
	reverse(s.begin(), s.end());
	m = s.size(), s = ' ' + s;
	t.resize(n);
	for (int i = 0; i < n; i++)
		cin >> t[i];
	pw2[0] = 1;
	for (int i = 1; i <= m; i++)
		pw2[i] = pw2[i - 1] * 2 % mod;
	f[0] = t;
	prepare();
	for (int i = 1; i <= m; i++) 
		f[i] = f[i - 1] - f[i - 1].shift(pw2[i - 1]);
	int ans = 0, nw = 0, sum = -1, coef = mod - 1;
	for (int i = m; i >= 1; i--) {
		if(s[i] == '1') {
			ans = (ans + coef * f[i - 1].queryp(nw) % mod) % mod;
			coef = mod - coef;
			nw = nw + pw2[i - 1], nw %= mod;
			sum = (nw - 1 + mod) % mod;
		}
	}
	if(sum != -1) {
		lg.resize(n + 2);
		inv[0] = inv[1] = 1;
		for (int i = 2; i <= n + 1; i++)
			inv[i] = (mod - mod / i) * inv[mod % i] % mod;
		lg[0] = t.queryp(0);
		for (int i = 1; i <= n + 1; i++)
			lg[i] = t.queryp(i), lg[i] = (lg[i - 1] + lg[i]) % mod;
		ans = (ans + lg.query(sum)) % mod;
		//cout << ans << endl;
	}
	cout << ans * (mod + 1) / 2 % mod << endl;
	return 0;
}
posted @ 2025-11-05 12:01  LUlululu1616  阅读(10)  评论(0)    收藏  举报