2022牛客暑期多校第一场 H. Fly
2022牛客暑期多校第一场 H. Fly
题意
给出 \(a_1,a_2,\dots,a_n\),以及 \(k\) 个限制 \((b_1, c_1),(b_2, c_2),\dots,(b_k,c_k)\),\((b_i,c_i)\) 表示 \(x_{b_i}\) 的第 \(c_i\) 位(从低位向高位数,最低位为第 \(0\) 位)必须为 \(0\)。
给定整数 \(M\),求满足 \(a_1x_1+a_2x_2+\cdots +a_nx_n \leq M\) 且满足上述 \(k\) 个限制的 \((x_1,x_2,\dots,x_n)\) 的方案数。
分析
注意到 \(\sum a_i\) 与 \(n\) 同阶。
考虑按位拆成若干01背包,第 \(i\) 位的背包的考虑的物品为 \(a_1\cdot 2^i, a_2\cdot 2^i, \dots, a_n\cdot 2^i\)。
朴素的01背包复杂度是 \(O(n^2)\) 的在这个问题来看复杂度过大,使用数学手段优化。
设 \(p(i,j)\) 为第 \(i\) 位背包取 \(j \cdot 2^i\) 的方案数。
不考虑限制的话,对于每个 \(i\),\(p(i,j)\) 都应该为 \(\prod\limits_{i=1}^n(1+x^{a_i})\) 展开式的 \(x^j\) 的系数,利用分治NTT可以得出不考虑限制的 \(p(i)\) 数组。复杂度为 \(O(n\log^2n)\)
此时对于每一位分别考虑限制,以第 \(i\) 位为例,针对已经得到的 \(p(i)\) 数组,可以考虑将01背包倒着做,以消除这一位那些不能使用的物品的贡献。这样每一位的背包 \(p(i,j)\) 就考虑完了,需要的复杂度是 \(O(nk)\)
接下来考虑如何合并这若干位背包。
定义 \(f(i,j)\) 表示已经合并 \((0,1,2...,i-1)\) 位背包,取了在如下区间范围的方案数
定义辅助数组 \(h(i,j)\) 表示已经合并 \((0,1,2...,i-1)\) 位背包,取了在如下区间范围的方案数。
显然有 \(h(i,t)=\sum_{j+k=t} f(i-1,j)\cdot p(i-1,k)\)
考虑如何将 \(h\) 数组变换为 \(f\) 数组
-
当
(M>>(i-1))&1==1
不妨设 \(t\) 是偶数
\(h(i,t+1)\) 和 \(h(i,t)\) 都应贡献到 \(f(i,\frac{t}{2})\)中。因为
\(h(i,t+1)\) 所代表区间为
\[\left(t\cdot2^{i-1}+M\%2^{i-1},(t+1)\cdot2^{i-1}+M\%2^{i-1}\right] \]可改写成
\[\left(\frac{t}{2}\cdot2^{i}+M\%2^{i-1},\frac{t}{2}\cdot2^{i}+2^{i-1}+M\%2^{i-1}\right] \]\(h(i,t)\) 所代表区间为
\[\left((t-1)\cdot2^{i-1}+M\%2^{i-1},t\cdot2^{i-1}+M\%2^{i-1}\right] \]可改写成
\[\left(\left(\frac{t}{2}-1\right)\cdot2^{i}+2^{i-1}+M\%2^{i-1},\frac{t}{2}\cdot2^{i}+M\%2^{i-1}\right] \]两者可合并为
\[\left(\left(\frac{t}{2}-1\right)\cdot2^{i}+2^{i-1}+M\%2^{i-1},\frac{t}{2}\cdot2^{i}+2^{i-1}+M\%2^{i-1}\right] \]显然当
(M>>(i-1))&1==1
时,\(2^{i-1}+M\%2^{i-1}=M\%2^i\)。所以这种情况下 \(h(i,t+1)\) 和 \(h(i,t)\) 所代表区间合并后的区间确实是 \(f(i,\frac{t}{2})\) 所代表的。
-
当
(M>>(i-1))&1==0
不妨设 \(t\) 是偶数
\(h(i,t)\) 和 \(h(i,t-1)\) 都应贡献到 \(f(i,\frac{t}{2})\)中。因为
\(h(i,t)\) 所代表区间为
\[\left((t-1)\cdot2^{i-1}+M\%2^{i-1},t\cdot2^{i-1}+M\%2^{i-1}\right] \]可改写成
\[\left(\left(\frac{t}{2}-1\right)\cdot2^{i}+2^{i-1}+M\%2^{i-1},\frac{t}{2}\cdot2^{i}+M\%2^{i-1}\right] \]\(h(i,t-1)\) 所代表区间为
\[\left((t-2)\cdot2^{i-1}+M\%2^{i-1},(t-1)\cdot2^{i-1}+M\%2^{i-1}\right] \]可改写成
\[\left(\left(\frac{t}{2}-1\right)\cdot2^{i}+M\%2^{i-1},\left(\frac{t}{2}-1\right)\cdot2^{i}+2^{i-1}+M\%2^{i-1}\right] \]两者可合并为
\[\left(\left(\frac{t}{2}-1\right)\cdot2^{i}+M\%2^{i-1},\frac{t}{2}\cdot2^{i}+M\%2^{i-1}\right] \]显然当
(M>>(i-1))&1==0
时,\(M\%2^{i-1}=M\%2^i\)。所以这种情况下 \(h(i,t)\) 和 \(h(i,t-1)\) 所代表区间合并后的区间确实是 \(f(i,\frac{t}{2})\) 所代表的。
根据上面的理论,将 \(h\) 转化为对应 \(f\) 显然是 \(O(n)\) 的。
而由于这个式子 \(h(i,t)=\sum_{j+k=t} f(i-1,j)\cdot p(i-1,k)\),显然合并一次需要利用卷积进行优化,一次复杂度为 \(O(n \log n)\),一共要合并 \(O(\log M)\) 次,这部分复杂度为 \(O(n \log n\log M)\)。
总复杂度为 \(O(n\log^2n + nk + n \log n\log M)\)
代码
#include <algorithm>
#include <iostream>
#include <set>
#include <vector>
using namespace std;
namespace NTT {
typedef int Lint;
typedef long long LLint;
// 2的幂次
const int maxn = (1 << 21) + 10;
const Lint mod = 998244353;
const Lint g = 3;
Lint fpow(Lint a, Lint b, Lint mod) {
Lint res = 1;
for (; b; b >>= 1) {
if (b & 1)
res = (LLint)res * a % mod;
a = (LLint)a * a % mod;
}
return res;
}
inline Lint add(Lint a, Lint b) {
a += b;
return a >= mod ? a - mod : a;
}
inline Lint mul(Lint a, Lint b) {
return (LLint)a * b % mod;
}
int r[maxn];
void cal_r(int n) {
for (int i = 0; i < n; i++) {
r[i] = 0;
r[i] = (i & 1) * (n >> 1) + (r[i >> 1] >> 1);
}
}
void dft(Lint* a, int n, int type) {
for (int i = 0; i < n; i++)
if (i < r[i])
swap(a[i], a[r[i]]);
for (int i = 1; i < n; i <<= 1) {
int p = i << 1;
Lint w = fpow(g, (mod - 1) / p, mod);
if (type == -1)
w = fpow(w, mod - 2, mod);
for (int j = 0; j < n; j += p) {
Lint t = 1;
for (int k = 0; k < i; k++, t = mul(t, w)) {
Lint tmp = mul(a[j + k + i], t);
a[j + k + i] = add(a[j + k], mod - tmp);
a[j + k] = add(a[j + k], tmp);
}
}
}
if (type == -1) {
Lint inv = fpow(n, mod - 2, mod);
for (int i = 0; i < n; i++)
a[i] = mul(a[i], inv);
}
}
Lint p[maxn], q[maxn];
vector<Lint> poly_mul(const vector<Lint>& a, const vector<Lint>& b) {
vector<Lint> res;
int n = a.size(), m = b.size();
res.resize(n + m - 1);
int len = n + m - 1;
int lim = 1;
while (lim < len)
lim <<= 1;
copy(a.begin(), a.end(), p);
fill(p + n, p + lim, 0);
copy(b.begin(), b.end(), q);
fill(q + m, q + lim, 0);
cal_r(lim);
dft(p, lim, 1), dft(q, lim, 1);
for (int i = 0; i < lim; i++)
p[i] = mul(p[i], q[i]);
dft(p, lim, -1);
for (int i = 0; i < n + m - 1; i++)
res[i] = p[i];
return res;
}
}; // namespace NTT
typedef long long Lint;
const int maxn = 4e4 + 10;
int n, k;
Lint m;
int a[maxn];
vector<int> get_poly_mul(int l, int r) {
if (l == r) {
vector<int> res;
res.resize(a[l] + 1);
res[0] = 1;
res[a[l]] = 1;
return res;
}
int mid = l + r >> 1;
return NTT::poly_mul(get_poly_mul(l, mid), get_poly_mul(mid + 1, r));
}
set<int> S[60];
int main() {
ios::sync_with_stdio(0);
cin.tie(0);
cin >> n >> m >> k;
for (int i = 1; i <= n; i++) {
cin >> a[i];
}
vector<int> t = get_poly_mul(1, n);
for (int i = 1; i <= k; i++) {
int b, c;
cin >> b >> c;
S[c].insert(b);
}
vector<int> f(1);
f[0] = 1;
for (int i = 0; m; i++) {
vector<int> p = t;
for (int x : S[i]) {
for (int j = 0; j + a[x] < t.size(); j++) {
p[j + a[x]] = NTT::add(p[j + a[x]], NTT::mod - p[j]);
}
}
f = NTT::poly_mul(f, p);
vector<int> tmp;
if (m & 1) {
tmp.resize(((f.size() - 1) >> 1) + 1);
for (int j = 0; j < f.size(); j++) {
tmp[j >> 1] = NTT::add(tmp[j >> 1], f[j]);
}
} else {
tmp.resize((f.size() >> 1) + 1);
for (int j = 0; j < f.size(); j++) {
tmp[j + 1 >> 1] = NTT::add(tmp[j + 1 >> 1], f[j]);
}
}
f = move(tmp);
m >>= 1;
}
cout << f[0] << '\n';
return 0;
}