2025.04.23 CW 模拟赛 C. 徽章

C. 徽章

题面描述

Kaguya 是一个还没能辟谷的女孩子。

有一天,Kaguya 来到了食堂。食堂的队伍好长好长,居然长达 \(n\) 个同学。Kaguya 学过一点信息学,所以她将队伍中的同学依次编号为 \(1 \ldots n\)。其中,有 \(n\) 个区间 \([l_i, r_i]\) 引起了她的兴趣。

Kaguya 拿出了 \(m\) 个徽章,并将第 \(i\) (\(1 \leq i \leq m\)) 个徽章送给了第 \(x_i\) 个人。

Kaguya 不喜欢奇数。她希望知道,\([l_1, r_1] \ldots [l_n, r_n]\) 中,有多少区间 \([l, r]\) 满足:第 \(l\) 个人到第 \(r\) 个人得到的徽章数目总和是奇数。

由于 Kaguya 非常可爱,所以你需要回答她 \(q\) 次同样形式的询问。


思路

考虑一个最朴素的暴力.

  • 对于每次询问我们做一次前缀和, 枚举每一个区间计算贡献. 时间复杂度 \(\mathcal{O}(nq)\).

可以发现如果每次 \(m > \sqrt{n}\), 那么 \(q < \sqrt{n}\), 总的时间复杂度就是 \(\mathcal{O}(n \sqrt{n})\).

由此我们可以考虑根号分治. 对于 \(m > \sqrt{n}\) 的情况, 直接暴力即可. 接下来考虑 \(m \le \sqrt{n}\) 的情况.

我们考虑这样一个容斥:

pEIq4Hg.png

如图, 假设现在有 \(5\) 个点. 我们考虑包含 \(c\) 个徽章的区间的容斥系数. 首先对于 \(c = 1\) 的区间, 我们的系数是 \(1\). \(c = 2\) 时不合法区间在 \(c = 1\) 时一定被统计了至少 \(2\) 次, 所以容斥系数为 \(-2\). 同理可得, \(c = 3\) 时, 容斥系数为 \(2\), \(c = 4\) 时, 容斥系数为 \(-2\)...

所以容斥系数为 \(1, -2, 2, -2, 2, \dots\)

对于每一个用于容斥的区间 \([l, r]\) , 我们需要统计左端点 \(\le l\) 且右端点 \(\ge r\) 的区间的个数, 这是一个二维偏序问题, 使用主席树预处理即可.

所有用于容斥的区间个数是 \(\mathcal{O}(m^2)\) 级别的, 因为 \(m \le \sqrt{n}\), 所以总的时间复杂度应该为 \(\mathcal{O}(n \sqrt{n} \log n)\).

#include <iostream>
#include <algorithm>
#include <vector>

using namespace std;

char buf[1 << 20], *p1, *p2;
#define getchar() (p1 == p2 and (p2 = (p1 = buf) + fread(buf, 1, 1 << 20, stdin), p1 == p2) ? 0 : *p1++)

int read() {
    int x = 0; char c = getchar();
    while (c < '0' or c > '9') {
        c = getchar();
    }
    while (c >= '0' and c <= '9') {
        x = x * 10 + (c & 15);
        c = getchar();
    }
    return x;
}

constexpr int N = 500002, M = 15000001;
constexpr int B = 710;

int n, q, pre[N], rt[N], c[B], arr[N];
struct Node {
    int l, r;
    friend bool operator<(Node x, Node y) {
        return x.l < y.l;
    }
} a[N];

class SegmentTree {
private:
    int tot;
    int tr[M], ls[M], rs[M];

public:
    int build(int l, int r) {
        int u = ++tot;
        if (l == r) {
            return u;
        }
        int mid = (l + r) >> 1;
        ls[u] = build(l, mid);
        rs[u] = build(mid + 1, r);
        return u;
    }

    int update(int pre, int l, int r, int x) {
        int u = ++tot;
        tr[u] = tr[pre] + 1;
        ls[u] = ls[pre], rs[u] = rs[pre];
        if (l == r) {
            return u;
        }
        int mid = (l + r) >> 1;
        x <= mid ? ls[u] = update(ls[pre], l, mid, x) : rs[u] = update(rs[pre], mid + 1, r, x);
        return u;
    }

    int query(int x, int y, int l, int r, int k) {
        if (l == r) {
            return tr[y] - tr[x];
        }
        int mid = (l + r) >> 1, res;
        if (k <= mid) {
            res = tr[rs[y]] - tr[rs[x]];
            res += query(ls[x], ls[y], l, mid, k);
        }
        else {
            res = query(rs[x], rs[y], mid + 1, r, k);
        }
        return res;
    }

} seg;

void init() {
    c[1] = 1, c[2] = -2;
    for (int i = 3; i < B; ++i) {
        c[i] = -c[i - 1];
    }
    n = read(), q = read();
    for (int i = 1; i <= n; ++i) {
        a[i] = {read(), read()};
    }
}

void calculate() {
    rt[0] = seg.build(1, n);
    sort(a + 1, a + n + 1);
    for (int i = 1; i <= n; ++i) {
        arr[i] = a[i].l;
        rt[i] = seg.update(rt[i - 1], 1, n, a[i].r);
    }
    while (q--) {
        int m = read(), ans = 0;
        if (m >= B) {
            fill(pre, pre + n + 1, 0);
            for (int i = 1; i <= m; ++i) {
                int x = read();
                pre[x] ^= 1;
            }
            for (int i = 1; i <= n; ++i) {
                pre[i] ^= pre[i - 1];
            }
            for (int i = 1; i <= n; ++i) {
                ans += pre[a[i].r] ^ pre[a[i].l - 1];
            }
        }
        else {
            vector<int> x(m);
            for (int& i : x) {
                i = read();
            }
            sort(x.begin(), x.end());
            for (int i = 0; i < m; ++i) {
                int p = upper_bound(arr + 1, arr + n + 1, x[i]) - arr - 1;
                for (int j = i; j < m; ++j) {
                    ans += c[j - i + 1] * seg.query(rt[0], rt[p], 1, n, x[j]);
                }
            }
        }
        printf("%d\n", ans);
    }
}

void solve() {
    init();
    calculate();
}

int main() {
    solve();
    return 0;
}
posted @ 2025-04-23 20:55  Steven1013  阅读(48)  评论(0)    收藏  举报