HD 2025 春季联赛 5

01

DP 入门题

03

线性基 线段树

解题思路

解决这题就必须要学下异或线性基。

对于子树的询问,我们可以转化成在 DFS 序中的区间询问:区间的左端点就是子树的根在 DFS 序中的位置,区间长度就是子树的大小。于是我们要做的就是实现单点修改和区间查询(查询区间的异或线性基)。

直接线段树就好了(下文展示的线段树是左闭右开线段树)。

具体的,对于单点修改,我们将值插入线段树中所有包含该点的区间,实现就是在找这个点的过程中插入就好了:

void upd(int cur, int l, int r, int pos, int val) {
    insert(tr[cur], val); // 将值插入所有包含目标点pos的区间
    if (l + 1 == r) {
        return;
    }

    int m = r + l >> 1;
    if (pos < m) {
        upd(cur << 1, l, m, pos, val);
    }
    else {
        upd(cur << 1 | 1, m, r, pos, val);
    }
}

对于区间查询,我们只用实现两个线性基德合并就好了,我的做法是直接重载 + 号:

// 合并
std::array<int, L> operator+ (std::array<int, L> u, const std::array<int, L> &v) {
    for (int i = 0; i < L; i++) {
        if (v[i]) {
            insert(u, v[i]);
        }
    }
    return u;
}

std::array<int, L> quiry(int cur, int l, int r, int sl, int sr) {
    if (sl <= l && r <= sr) {
        return tr[cur];
    }

    int m = r + l >> 1;
    if (sr <= m) {
        return quiry(cur << 1, l, m, sl, sr);
    }
    else if (m <= sl) {
        return quiry(cur << 1 | 1, m, r, sl, sr);
    }
    else {
        return quiry(cur << 1, l, m, sl, sr) + quiry(cur << 1 | 1, m, r, sl, sr);
    }
}
CODE
std::vector<int> g[N + 5];

int siz[N + 5];
int pos[N + 5], tim;
void dfs(int cur, int fa) {
    siz[cur] = 1;
    pos[cur] = tim++;
    
    for (auto &to : g[cur]) {
        if (to == fa) {
            continue;
        }

        dfs(to, cur);
        siz[cur] += siz[to];
    }
}

void insert(std::array<int, L> &base, int val) {
    for (int i = L - 1; ~i; i--) {
        if (!(val >> i & 1)) {
            continue;
        }

        if (base[i]) {
            val ^= base[i];
        }
        else {
            for (int j = j - 1; ~j; j--) {
                if (val >> j & 1) {
                    val ^= base[j];
                }
            }
            for (int j = i + 1; j < L; j++) {
                if (base[j] >> i & 1) {
                        base[j] ^= val;
                }
            }
            base[i] = val;
            return;
        }
    }
}

std::array<int, L> operator+ (std::array<int, L> u, const std::array<int, L> &v) {
    for (int i = 0; i < L; i++) {
        if (v[i]) {
            insert(u, v[i]);
        }
    }
    return u;
}

std::array<int, L> tr[N << 2];
void build(int cur, int l, int r) {
    for (int i = 0; i < L; i++) {
        tr[cur][i] = 0;
    }
    if (l + 1 == r) {
        return;
    }

    int m = l + r >> 1;
    build(cur << 1, l, m);
    build(cur << 1 | 1, m, r);
}

void upd(int cur, int l, int r, int pos, int val) {
    insert(tr[cur], val);
    if (l + 1 == r) {
        return;
    }

    int m = r + l >> 1;
    if (pos < m) {
        upd(cur << 1, l, m, pos, val);
    }
    else {
        upd(cur << 1 | 1, m, r, pos, val);
    }
}

std::array<int, L> quiry(int cur, int l, int r, int sl, int sr) {
    if (sl <= l && r <= sr) {
        return tr[cur];
    }

    int m = r + l >> 1;
    if (sr <= m) {
        return quiry(cur << 1, l, m, sl, sr);
    }
    else if (m <= sl) {
        return quiry(cur << 1 | 1, m, r, sl, sr);
    }
    else {
        return quiry(cur << 1, l, m, sl, sr) + quiry(cur << 1 | 1, m, r, sl, sr);
    }
}

void solve()
{
    int n = 0, q = 0;
    std::cin >> n >> q;
    for (int i = 1; i <= n; i++) {
        g[i].clear();
        siz[i] = 0;
        tim = 0;
    }
    
    for (int i = 0; i < n - 1; i++) {
        int u = 0, v = 0;
        std::cin >> u >> v;
        g[u].push_back(v);
        g[v].push_back(u);
    }

    dfs(1, 0);
    build(1, 0, n);
    while (q--) {
        int op = 0, x = 0, y = 0;
        std::cin >> op >> x >> y;
        if (op == 1) {
            upd(1, 0, n, pos[x], y);
        }
        else {
            std::vector<int> base;
            for (auto &p : quiry(1, 0, n, pos[x], pos[x] + siz[x])) {
                if (p) {
                    base.push_back(p);
                }
            }
            int ans = 0;
            if (y >= (1 << base.size())) {
                ans = -1;
            }
            else {
                int i = 0;
                while (y) {
                    if (y & 1) {
                        ans ^= base[i];
                    }
                    i++;
                    y >>= 1;
                }
            }
            std::cout << ans << '\n';
        }
    }
}

04

数学 倍增 二进制

解题思路

我们可以先对序列中所有的数对 \(m\) 取模,那么得到的序列一定是以 \(m\) 为周期循环的。

单考虑区间 \([1, m]\),由于 \(m\) 很小,所以我们可以以 \(O(m^2)\) 的时间复杂度求得这个区间内子序列和为 \(x\)\(0 \leq x < m\))的方案数 \(f[x]\)。那么得到了这么一个区间的 \(f[x]\) 有什么用呢?我们可以由这个区间的 \(f[x]\) 求出长度为 \(2m\) 的区间的 \(f[x]\),进而可以求出所有长度是 \(2^im\) 的区间的 \(f[x]\)。然后我们可以从大到小贪心地选择已经计算过 \(f[x]\) 的区间来组成长度为 \(n\) 的区间,最后可能会剩下一段长度不足 \(m\) 的区间,单独考虑就好了。具体实现看代码:

CODE
std::vector<i64> base(int mx, int m) {
    std::vector f(2, std::vector(m, 0ll));
    f[0][0] = 1;
    for (int i = 1, u = 1; i <= mx; i++, u ^= 1) {
        for (int j = 0; j < m; j++) {
            f[u][(j + i) % m] = (f[u ^ 1][j] + f[u ^ 1][(j + i) % m]) % Mod;
        }
    }
    return f[mx & 1];
}

i64 solve()
{
    i64 n = 0, m = 0;
    std::cin >> n >> m;
    std::vector res(2, std::vector(m, 0ll));
    // 求最后一块长度不足 m 的区间
    res[0] = base(n % m, m);
    if (n < m) {
        return res[0][0] - 1;
    }

    n /= m;
    int len = std::log2(n);
    std::vector f(len + 1, std::vector(m, 0ll));
    f[0] = base(m, m); // 初始长度为 n 的区间
	// 递推至所有长度为 2^im 的区间
    for (int l = 1; l <= len; l++) {
        for (int i = 0; i < m; i++) {
            for (int j = 0; j < m; j++) {
                (f[l][(i + j) % m] += f[l - 1][i] * f[l - 1][j] % Mod) %= Mod;
            }
        }
    }

    int p = 0;
    // 贪心地选择区间
    for (int l = len; ~l && n; l--) {
        if ((1ll << l) > n) {
            continue;
        }

        n -= (1ll << l); 
        for (int i = 0; i < m; i++) {
            res[p ^ 1][i] = 0;
        }
        for (int i = 0; i < m; i++) {
            for (int j = 0; j < m; j++) {
                (res[p ^ 1][(i + j) % m] += res[p][i] * f[l][j] % Mod) %= Mod;
            }
        }
        p ^= 1;
    }
    // 最后减去空集
    return (res[p][0] - 1 + Mod) % Mod;
}

06

签到

07

大模拟,不如直接看标程序

08

数学

解题思路

最考验观察力的一集。

因为我们要选择两个不相交的序列,所以单个序列的长度最多为总长度的一半,当这个序列中全部都是最大值时,我们就构造出了值最大的序列。那么这一点有什么用呢?如果所有长度是总长度的一半的序列的数量大于这个最大值,则必然会有值相同的序列(应为所有序列的值都被限制在最大值之下)。所以解出 \(max \times {n \over 2} < C_{n}^{n \over 2}\) 得到的范围就是必然有解的范围,解出来是 \(24 \leq n\) 时必然有解,对于 \(n < 24\) 的情况,就直接暴力枚举。我的做法类似于状态压缩,把序列的长度和序列之和压缩成一个数放到 map 里面。

CODE
bool solve()
{
    int n = 0, m = 0;
    std::cin >> n >> m;
    std::vector a(n, 0);
    for (auto &i : a) {
        std::cin >> i;
    }
    if (n >= 24) {
        return true;
    }

    std::map<int, bool> vis;
    for (int i = 1; i < (1 << n); i++) {
        int sum = 0, cnt = 0;
        for (int j = 0; j < n; j++) {
            if (i >> j & 1) {
                sum += a[j];
                cnt++;
            }
        }
        sum |= (cnt << n);
        if (vis[sum]) {
            return true;
        }
        vis[sum] = true;
    }
    return false;
}

09

打表找规律

解题思路

直接根据题意打表就可以发现规律。

(打表程序是代码中注释的部分)

CODE
i64 qpow(i64 a, i64 p) {
    i64 res = 1;
    while (p) {
        if (p & 1) {
            (res *= a) %= Mod;
        }
        (a *= a) %= Mod;
        p >>= 1;
    }
    return res;
}
i64 inv(i64 a) {
    return qpow(a, Mod - 2);
}

void solve()
{
    int n = 0;
    std::cin >> n;
    int k = n / 5, m = n % 5;
    i64 ans = 0;
    if (m == 0 || m == 2) {
        ans = 1;
    }
    else if (m == 1) {
        ans = Mod - inv(qpow(2, std::max(1, k))) + 1;
    }
    else if (m == 3) {
        ans = Mod - inv(qpow(2, k + 2)) + 1; 
    }
    else {
        ans = Mod - inv(qpow(2, k + 1)) + 1;
    }
    std::cout << ans << '\n';
    // std::vector<double> p = { 1, 0.5, 1, 0.75, 0.5, 1, 0.5, 1 };
    // for (int i = 8; i < 100; i++) {
    //     p.push_back(0.5 * (std::max(p[i - 2], p[i - 5]) + std::max(p[i - 8], p[i - 5])));
    // }
    // for (int i = 0; i < 100; i++) {
    //     if (i % 5 == 0) {
    //         std::cout << "\n";
    //     }
    //     std::cout << i << '\t' <<  p[i] << '\n';
    // }
    // return;
}

10

注意到所有合法的三元组里面的点到原点的距离奇偶性相同就好了

CODE
std::vector<std::array<int, 2>> g[N + 5];
std::array<int, 2> cnt;

void dfs(int cur, int fa, int d) {
    cnt[d]++;
    for (auto &[to, val] : g[cur]) {
        if (to != fa) {
            dfs(to, cur, d ^ val);
        }
    }
}

void solve()
{
    int n = 0;
    std::cin >> n;
    
    cnt[0] = cnt[1] = 0;
    for (int i = 1; i <= n; i++)  {
        g[i].clear();
    }

    for (int i = 0; i < n - 1; i++) {
        int u = 0, v = 0, w = 0;
        std::cin >> u >> v >> w;
        w %= 2;
        g[u].push_back({ v, w });
        g[v].push_back({ u, w });
    }

    dfs(1, 0, 0);
    std::cout << 1ll * cnt[0] * cnt[0] * cnt[0] + 1ll * cnt[1] * cnt[1] * cnt[1] << '\n';
    return;
}
posted @ 2025-04-11 20:56  Young_Cloud  阅读(15)  评论(0)    收藏  举报