2022牛客暑期多校第一场 H. Fly

2022牛客暑期多校第一场 H. Fly

题意

给出 \(a_1,a_2,\dots,a_n\),以及 \(k\) 个限制 \((b_1, c_1),(b_2, c_2),\dots,(b_k,c_k)\)\((b_i,c_i)\) 表示 \(x_{b_i}\) 的第 \(c_i\) 位(从低位向高位数,最低位为第 \(0\) 位)必须为 \(0\)

给定整数 \(M\),求满足 \(a_1x_1+a_2x_2+\cdots +a_nx_n \leq M\) 且满足上述 \(k\) 个限制的 \((x_1,x_2,\dots,x_n)\) 的方案数。

分析

注意到 \(\sum a_i\)\(n\) 同阶。

考虑按位拆成若干01背包,第 \(i\) 位的背包的考虑的物品为 \(a_1\cdot 2^i, a_2\cdot 2^i, \dots, a_n\cdot 2^i\)

朴素的01背包复杂度是 \(O(n^2)\) 的在这个问题来看复杂度过大,使用数学手段优化。

\(p(i,j)\) 为第 \(i\) 位背包取 \(j \cdot 2^i\) 的方案数。

不考虑限制的话,对于每个 \(i\)\(p(i,j)\) 都应该为 \(\prod\limits_{i=1}^n(1+x^{a_i})\) 展开式的 \(x^j\) 的系数,利用分治NTT可以得出不考虑限制的 \(p(i)\) 数组。复杂度为 \(O(n\log^2n)\)

此时对于每一位分别考虑限制,以第 \(i\) 位为例,针对已经得到的 \(p(i)\) 数组,可以考虑将01背包倒着做,以消除这一位那些不能使用的物品的贡献。这样每一位的背包 \(p(i,j)\) 就考虑完了,需要的复杂度是 \(O(nk)\)

接下来考虑如何合并这若干位背包。

定义 \(f(i,j)\) 表示已经合并 \((0,1,2...,i-1)\) 位背包,取了在如下区间范围的方案数

\[((j-1) \cdot 2^i+M\%2^i,j \cdot 2^i+M\%2^i] \]

定义辅助数组 \(h(i,j)\) 表示已经合并 \((0,1,2...,i-1)\) 位背包,取了在如下区间范围的方案数。

\[((j-1)\cdot2^{i-1}+M\%2^{i-1},j\cdot2^{i-1}+M\%2^{i-1}] \]

显然有 \(h(i,t)=\sum_{j+k=t} f(i-1,j)\cdot p(i-1,k)\)

考虑如何将 \(h\) 数组变换为 \(f\) 数组

  • (M>>(i-1))&1==1
    不妨设 \(t\) 是偶数
    \(h(i,t+1)\)\(h(i,t)\) 都应贡献到 \(f(i,\frac{t}{2})\)中。

    因为

    \(h(i,t+1)\) 所代表区间为

    \[\left(t\cdot2^{i-1}+M\%2^{i-1},(t+1)\cdot2^{i-1}+M\%2^{i-1}\right] \]

    可改写成

    \[\left(\frac{t}{2}\cdot2^{i}+M\%2^{i-1},\frac{t}{2}\cdot2^{i}+2^{i-1}+M\%2^{i-1}\right] \]

    \(h(i,t)\) 所代表区间为

    \[\left((t-1)\cdot2^{i-1}+M\%2^{i-1},t\cdot2^{i-1}+M\%2^{i-1}\right] \]

    可改写成

    \[\left(\left(\frac{t}{2}-1\right)\cdot2^{i}+2^{i-1}+M\%2^{i-1},\frac{t}{2}\cdot2^{i}+M\%2^{i-1}\right] \]

    两者可合并为

    \[\left(\left(\frac{t}{2}-1\right)\cdot2^{i}+2^{i-1}+M\%2^{i-1},\frac{t}{2}\cdot2^{i}+2^{i-1}+M\%2^{i-1}\right] \]

    显然当 (M>>(i-1))&1==1 时,\(2^{i-1}+M\%2^{i-1}=M\%2^i\)

    所以这种情况下 \(h(i,t+1)\)\(h(i,t)\) 所代表区间合并后的区间确实是 \(f(i,\frac{t}{2})\) 所代表的。

  • (M>>(i-1))&1==0
    不妨设 \(t\) 是偶数
    \(h(i,t)\)\(h(i,t-1)\) 都应贡献到 \(f(i,\frac{t}{2})\)中。

    因为

    \(h(i,t)\) 所代表区间为

    \[\left((t-1)\cdot2^{i-1}+M\%2^{i-1},t\cdot2^{i-1}+M\%2^{i-1}\right] \]

    可改写成

    \[\left(\left(\frac{t}{2}-1\right)\cdot2^{i}+2^{i-1}+M\%2^{i-1},\frac{t}{2}\cdot2^{i}+M\%2^{i-1}\right] \]

    \(h(i,t-1)\) 所代表区间为

    \[\left((t-2)\cdot2^{i-1}+M\%2^{i-1},(t-1)\cdot2^{i-1}+M\%2^{i-1}\right] \]

    可改写成

    \[\left(\left(\frac{t}{2}-1\right)\cdot2^{i}+M\%2^{i-1},\left(\frac{t}{2}-1\right)\cdot2^{i}+2^{i-1}+M\%2^{i-1}\right] \]

    两者可合并为

    \[\left(\left(\frac{t}{2}-1\right)\cdot2^{i}+M\%2^{i-1},\frac{t}{2}\cdot2^{i}+M\%2^{i-1}\right] \]

    显然当 (M>>(i-1))&1==0 时,\(M\%2^{i-1}=M\%2^i\)

    所以这种情况下 \(h(i,t)\)\(h(i,t-1)\) 所代表区间合并后的区间确实是 \(f(i,\frac{t}{2})\) 所代表的。

根据上面的理论,将 \(h\) 转化为对应 \(f\) 显然是 \(O(n)\) 的。

而由于这个式子 \(h(i,t)=\sum_{j+k=t} f(i-1,j)\cdot p(i-1,k)\),显然合并一次需要利用卷积进行优化,一次复杂度为 \(O(n \log n)\),一共要合并 \(O(\log M)\) 次,这部分复杂度为 \(O(n \log n\log M)\)

总复杂度为 \(O(n\log^2n + nk + n \log n\log M)\)

代码

#include <algorithm>
#include <iostream>
#include <set>
#include <vector>
using namespace std;
namespace NTT {
typedef int Lint;
typedef long long LLint;
// 2的幂次
const int maxn = (1 << 21) + 10;
const Lint mod = 998244353;
const Lint g = 3;
Lint fpow(Lint a, Lint b, Lint mod) {
    Lint res = 1;
    for (; b; b >>= 1) {
        if (b & 1)
            res = (LLint)res * a % mod;
        a = (LLint)a * a % mod;
    }
    return res;
}
inline Lint add(Lint a, Lint b) {
    a += b;
    return a >= mod ? a - mod : a;
}
inline Lint mul(Lint a, Lint b) {
    return (LLint)a * b % mod;
}
int r[maxn];
void cal_r(int n) {
    for (int i = 0; i < n; i++) {
        r[i] = 0;
        r[i] = (i & 1) * (n >> 1) + (r[i >> 1] >> 1);
    }
}
void dft(Lint* a, int n, int type) {
    for (int i = 0; i < n; i++)
        if (i < r[i])
            swap(a[i], a[r[i]]);
    for (int i = 1; i < n; i <<= 1) {
        int p = i << 1;
        Lint w = fpow(g, (mod - 1) / p, mod);
        if (type == -1)
            w = fpow(w, mod - 2, mod);
        for (int j = 0; j < n; j += p) {
            Lint t = 1;
            for (int k = 0; k < i; k++, t = mul(t, w)) {
                Lint tmp = mul(a[j + k + i], t);
                a[j + k + i] = add(a[j + k], mod - tmp);
                a[j + k] = add(a[j + k], tmp);
            }
        }
    }
    if (type == -1) {
        Lint inv = fpow(n, mod - 2, mod);
        for (int i = 0; i < n; i++)
            a[i] = mul(a[i], inv);
    }
}
Lint p[maxn], q[maxn];
vector<Lint> poly_mul(const vector<Lint>& a, const vector<Lint>& b) {
    vector<Lint> res;
    int n = a.size(), m = b.size();
    res.resize(n + m - 1);
    int len = n + m - 1;
    int lim = 1;
    while (lim < len)
        lim <<= 1;
    copy(a.begin(), a.end(), p);
    fill(p + n, p + lim, 0);
    copy(b.begin(), b.end(), q);
    fill(q + m, q + lim, 0);
    cal_r(lim);
    dft(p, lim, 1), dft(q, lim, 1);
    for (int i = 0; i < lim; i++)
        p[i] = mul(p[i], q[i]);
    dft(p, lim, -1);
    for (int i = 0; i < n + m - 1; i++)
        res[i] = p[i];
    return res;
}
};  // namespace NTT

typedef long long Lint;
const int maxn = 4e4 + 10;
int n, k;
Lint m;
int a[maxn];
vector<int> get_poly_mul(int l, int r) {
    if (l == r) {
        vector<int> res;
        res.resize(a[l] + 1);
        res[0] = 1;
        res[a[l]] = 1;
        return res;
    }
    int mid = l + r >> 1;
    return NTT::poly_mul(get_poly_mul(l, mid), get_poly_mul(mid + 1, r));
}
set<int> S[60];
int main() {
    ios::sync_with_stdio(0);
    cin.tie(0);
    cin >> n >> m >> k;
    for (int i = 1; i <= n; i++) {
        cin >> a[i];
    }
    vector<int> t = get_poly_mul(1, n);
    for (int i = 1; i <= k; i++) {
        int b, c;
        cin >> b >> c;
        S[c].insert(b);
    }
    vector<int> f(1);
    f[0] = 1;
    for (int i = 0; m; i++) {
        vector<int> p = t;
        for (int x : S[i]) {
            for (int j = 0; j + a[x] < t.size(); j++) {
                p[j + a[x]] = NTT::add(p[j + a[x]], NTT::mod - p[j]);
            }
        }
        f = NTT::poly_mul(f, p);
        vector<int> tmp;
        if (m & 1) {
            tmp.resize(((f.size() - 1) >> 1) + 1);
            for (int j = 0; j < f.size(); j++) {
                tmp[j >> 1] = NTT::add(tmp[j >> 1], f[j]);
            }
        } else {
            tmp.resize((f.size() >> 1) + 1);
            for (int j = 0; j < f.size(); j++) {
                tmp[j + 1 >> 1] = NTT::add(tmp[j + 1 >> 1], f[j]);
            }
        }
        f = move(tmp);
        m >>= 1;
    }
    cout << f[0] << '\n';
    return 0;
}
posted @ 2022-07-21 21:37  聆竹听风  阅读(196)  评论(0)    收藏  举报