题解:loj6878 生不逢时

写这篇题解的时候回酒店电脑崩了直接没了,也算是照应标题生不逢时了……

牛牛题。

题意:给定正整数 n, m 和 n 个区间,第 i 个区间为 \([l _ i, r _ i]\),保证 \(0 \leq l_i \leq r_i < 2^m\)

对于非负整数 \(x\),记 \(S _ m(x)\) 表示 \(x\) 在二进制下最低的 \(m\) 位依次连接成的 \(01\) 串,如果不足 \(m\) 位则在高位补 \(0\)

对于 \(k = 1, 2, \cdots ,n\),求有多少非负整数序列 \(a _ 1, a _ 2, \cdots, a _ k\) 满足下列条件。

  • 对于所有 \(1 \leq i \leq k,l _ i \leq a _ i \leq r _ i\)
  • \(S _ m(a _ 1 \oplus a _ 2 \oplus \cdots \oplus a _ k)\) 是回文串,其中 \oplus 表示按位异或运算。

答案对 \(998244353\) 取模。

做法:

首先考虑处理回文,直接考虑把高的一半位对称下来,那么这样就要求异或和为 \(0\) 且我们以后可以不用再考虑高位的事情。

然后就是考虑把 \([l_i,r_i]\) 分解成 $\log $ 个区间,这些区间形如 \([a_1,a_1+2^{d_1}),[a_2,a_2+ 2^{d_2})\cdots\),并且满足 \(a_i+2^{d_i}=a_{i+1}\)\(d_i\) 为最大的满足 \(2^{d_i}|a_i,a_i+2^{d_i}\le r_i+1\)。发现这些区间在把高位对称下来后还是区间,所以也可以不用管高位的事情了。

注意到这些区间满足前若干位固定,后若干位随意,我们这里用 \((x,l)\) 来指代一个固定位为 \(x\),后 \(l\) 位的区间,所以 \([l_i,r_i]\) 就可以被分解为若干个 \((a_j,d_j)\)

考虑如何合并两个区间的结果,假设为 \((x_1,l_1),(x_2,l_2)\)。发现 \(l\) 应该取 \(\max(l_1,l_2)\),因为短的那一边随便取,长的那一边一定可以任意取。而 \(x\) 就应该取 \(x_1,x_2\) 将后 \(l\) 位置成 \(0\) 后异或在一起,也比较显然。

根据上面这个东西可以直接列出来一个暴力的 dp,\(dp_{i,j,S}\) 代表考虑了前 \(i\) 个,目前得到结果为 \((S, j)\)。转移就是枚举一个区间然后直接合并即可,可以做到 \(O(\log^n V)\)

但是我们考虑一个事情,我们还有另外一种计算的方式,我们可以枚举 \(j\) 再考虑对应的 \(S\) 有多少种。观察到,对于一个 \([l_i,r_i]\),其区间应该是一直加 lowbit 然后再补齐 \(r\) 中少的。对于一个 \(j\),前后两个阶段中只会至多分别有一种 \(S\) 出现。所以如果最后答案 \((x,l)\)\(l=j\),则 \(x\) 其实只有 \(2^n\) 种取法!所以我们直接对于每个位置 \(i\),后 \(j\) 位随意的时暴力算出来高位的情况,然后只要求有多少种方案异或和为 \(0\) 即可。

所以我们现在只要干这么一些事情:

  1. 把所有区间分解一下。

  2. 注意到两个数异或卷积后再末尾置零和置零后再异或没区别,同时注意到补 lowbit 和补 \(r\) 两部分数其实是子集关系,也就是说,在 Trie 树上形如两条链。所以我们卷积的时候直接卷就可以。

  3. 因为我们要求的是 \(\max = j\) 的情况,所以考虑直接枚举确定位为全 \(0\) 的情况即可。

然后直接上一个 meet in middle 就可以做到 \(O(2^{\frac n 2}m)\) 的复杂度,还需要乘上枚举每个区间的 \(n\)。主要就是第三步的时候等于我要合并两个部分,枚举两侧固定为都是 \(x\),然后对于 \(\max\) 的解决考虑容斥,考虑在子树内选两个点的方案数再减去都选在儿子里没有选 \(j\) 的情况即可。细节可以见代码是怎么用类似 Trie 的东西维护这个。

联考中提到了另一种做法,就是把区间差分后改成只有上界的问题。现在主要是怎么用 \(O(m)\) 的时间解决新加入一个区间的问题。注意到,在只有上界的情况下,对于一个 \(j\) 取到 \(\max\) 时,所有数都是固定的,所以直接做可以得到一次 \(O(nm)\) 的复杂度。但是我们已经知道了前面的答案,我们就可以直接做一个前缀和,然后枚举 \(\max\) 再容斥掉没有取到 \(\max\) 的方案数,这样是 \(O(m)\) 的。

代码:

#include <bits/stdc++.h>
using namespace std;
const int maxn = 45, N = 1.2e8 + 5, mod = 998244353;
int n;
long long l[maxn], r[maxn], m, all, All, inv[maxn * 2];
long long rev(long long v, int l, int r) {
    long long res = 0;

    for (int i = l; i <= r; i++)
        res <<= 1, res |= ((v >> i) & 1);

    return res;
}
struct node {
    int son[2];
    long long val;
} ;
struct Trie {
    node tr[N];
    queue<int> ers;
    inline int newnode() {
        int p;

        if (!ers.empty())
            p = ers.front(), ers.pop();
        else
            p = ++tot;

        return p;
    }
    int tot, rt1, rt2, p, q;
    void add(long long l, long long r, long long x, long long y, int &t) {
        if (!t)
            t = newnode();

        //  cout << l << " " << r << " " << x << " " << y << endl;
        tr[t].val += y - x + 1, tr[t].val %= mod;

        //  cout << l << " " << r << " " << tr[t].val << endl;
        if (x <= l && r <= y)
            return ;

        long long mid = l + r >> 1;

        if (x <= mid)
            add(l, mid, x, y, tr[t].son[0]);

        if (mid < y)
            add(mid + 1, r, x, y, tr[t].son[1]);
    }
    void mrg(int x, int y, int &rt) {
        if (!x || !y)
            return ;

        if (!rt)
            rt = newnode();

        tr[rt].val = (tr[rt].val + 1ll * tr[x].val * tr[y].val % mod) % mod;

        for (int xx = 0; xx <= 1; xx++)
            for (int yy = 0; yy <= 1; yy++)
                mrg(tr[x].son[xx], tr[y].son[yy], tr[rt].son[xx ^ yy]);
    }
    void update(long long l, long long r, long long x, long long y, int &rt, int d) {
        //  if(l != r)
        //      cout << l << " " << r << " " << x << " " << y << endl;
        if (x <= l && r <= y) {
            long long w = (l & all) ^ rev(l, (m + 1) / 2, m - 1);
            //  cout << l << " " << r << " " << x << " " << y << " " << w << " " << d << " " << rev(l, m / 2, m - 1) << " " << l << " " << m / 2 + 1 << endl;
            w = (w >> d) << d;
            //  cout << w << " " << d << endl;
            add(0, all, w, w + (1ll << d) - 1, rt);
            return ;
        }

        long long mid = l + r >> 1;

        if (x <= mid)
            update(l, mid, x, y, rt, d - 1);

        if (mid < y)
            update(mid + 1, r, x, y, rt, d - 1);
    }
    int query(long long l, long long r, int p, int q, int d) {
        //  cout << l << " " << r << " " << p << " " << q << " " << tr[p].val << " " << tr[q].val << " " << tr[p].son[0] << " " << tr[q].son[0] << endl;
        if (!p || !q)
            return 0;

        long long mid = l + r >> 1;
        return ((query(l, mid, tr[p].son[0], tr[q].son[0], d - 1) + query(mid + 1, r, tr[p].son[1], tr[q].son[1],
                 d - 1)) % mod +
                (1ll * tr[p].val * tr[q].val % mod -
                 1ll * (tr[tr[p].son[0]].val + tr[tr[p].son[1]].val) % mod * (tr[tr[q].son[0]].val + tr[tr[q].son[1]].val) %
                 mod) * inv[d + 1] % mod + mod) % mod;
    }
    void clear(int u) {
        if (tr[u].son[0])
            clear(tr[u].son[0]);

        if (tr[u].son[1])
            clear(tr[u].son[1]);

        tr[u].son[0] = tr[u].son[1] = tr[u].val = 0;
    }
} tree;
signed main() {
    cin >> n >> m;
    inv[0] = 1;

    for (int i = 1; i <= m; i++)
        inv[i] = inv[i - 1] * (mod + 1) / 2 % mod;

    for (int i = 1; i <= n; i++)
        cin >> l[i] >> r[i];

    all = (1ll << m / 2) - 1, All = (1ll << m) - 1;
    tree.add(0, all, 0, 0, tree.rt1), tree.add(0, all, 0, 0, tree.rt2);

    for (int i = 1; i <= n; i++) {
        tree.p = tree.q = 0;
        swap(tree.rt1, tree.rt2);
        //  cout << 123 << endl;
        tree.update(0, All, l[i], r[i], tree.p, m);
        //  cout << 123 << endl;
        tree.mrg(tree.rt1, tree.p, tree.q);
        swap(tree.rt1, tree.q);
        tree.clear(tree.p), tree.clear(tree.q);
        cout << tree.query(0, all, tree.rt1, tree.rt2, m / 2 - 1) << endl;
    }

    return 0;
}
posted @ 2025-11-07 22:14  LUlululu1616  阅读(3)  评论(0)    收藏  举报