分治杂记

分治杂记

分治(Divide and Conquer),就是把一个复杂的问题分成若干子问题,分别求解。本质是缩小了问题的规模。

普通的分治

[ABC373G] No Cross Matching

给定平面上的 \(n\) 个黑点和 \(n\) 个白点,构造一种方案,将黑白点两两匹配并连线段,使得任意两条线段不相交。

\(n \leq 100\) ,保证无三点共线,保证有解

找到最左下角的黑点,找到一个白点使得这两个点连线的两边内黑白点数量分别相等,然后分治即可。

找白点可以用极角排序,时间复杂度 \(O(n^2 \log n)\)

#include <bits/stdc++.h>
using namespace std;
const int N = 3e2 + 7;

struct Point {
    int x, y, id;
} p[N << 1];

int match[N];

int n;

void solve(int l, int r) {
    if (l > r)
        return;

    sort(p + l, p + r + 1, [](const Point &a, const Point &b) { return a.x < b.x; });

    for (int i = l + 1; i <= r; ++i)
        p[i].x -= p[l].x, p[i].y -= p[l].y;

    sort(p + l + 1, p + r + 1, [](const Point &a, const Point &b) {
        return atan2(a.x, a.y) > atan2(b.x, b.y);
    });

    for (int i = l + 1, cnt = 0; i <= r; cnt += (p[i++].id <= n ? 1 : -1))
        if (!cnt && ((p[i].id <= n) ^ (p[l].id <= n))) {
            match[min(p[l].id, p[i].id)] = max(p[l].id, p[i].id);
            solve(l + 1, i - 1), solve(i + 1, r);
            return;
        }
}

signed main() {
    scanf("%d", &n);

    for (int i = 1; i <= n * 2; ++i)
        scanf("%d%d", &p[i].x, &p[i].y), p[i].id = i;

    solve(1, n * 2);

    for (int i = 1; i <= n; ++i)
        printf("%d ", match[i] - n);

    return 0;
}

一维点对

主要就是讨论点对分别在左右区间产生的贡献。

CF459D Pashmak and Parmida's problem

给定 \(a_{1 \sim n}\) ,设 \(f(l, r, x)\) 表示 \(x\)\(a_{l \sim r}\) 中的出现次数。求有多少对 \(i < j\) 满足 \(f(1, i, a_i) > f(j, n, a_j)\)

\(n \leq 10^6\)

预处理 \(f(1, i, a_i)\)\(f(i, n, a_i)\) 后就变为逆序对问题,可以归并做到 \(O(n \log n)\) ,比 DS 好写多了。

#include <bits/stdc++.h>
typedef long long ll;
using namespace std;
const int N = 1e6 + 7;

map<int, int> mp;

int a[N], pre[N], suf[N];

ll ans;
int n;

void solve(int l, int r) {
    if (l == r)
        return;

    int mid = (l + r) >> 1;
    solve(l, mid), solve(mid + 1, r);

    for (int i = l, j = mid; i <= mid; ++i) {
        while (j < r && pre[i] > suf[j + 1])
            ++j;

        ans += j - mid;
    }

    inplace_merge(pre + l, pre + mid + 1, pre + r + 1);
    inplace_merge(suf + l, suf + mid + 1, suf + r + 1);
}

signed main() {
    scanf("%d", &n);

    for (int i = 1; i <= n; ++i)
        scanf("%d", a + i), pre[i] = ++mp[a[i]];

    mp.clear();

    for (int i = n; i; --i)
        suf[i] = ++mp[a[i]];

    solve(1, n);
    printf("%lld", ans);
    return 0;
}

友好点对

给出二维平面上的 \(n\) 个不同点。若存在一个矩形只包含两个点(包括边界),则称这两个点为友好点对。求友好点对的数量。

\(n \leq 10^5\)

不难发现若两个点是友好点对,则最优的矩形一定是使得这两个点在对角线两端。

考虑按横坐标分治。以左下-右上为例,左上-右下是对称的。处理跨过中点的贡献时,固定左侧的某个点,则右侧能选择的点是一个上升序列。分别对左右区间维护单调栈即可。

P7883 平面最近点对(加强加强版)

给定 \(n\) 个二维平面上的点,求最近点对距离的平方值。

\(n \leq 4 \times 10^5\)

考虑按 \(x\) 坐标分治,记分治的分界线为 \(mid\) ,左右区间内部的最近点对距离的较小值为 \(d\)

考虑处理左右区间之间对答案的贡献,首先发现一个点 \((x, y)\) 满足 \(|x - mid| \leq d\) 时才有可能更新答案。将这些点拿出来,并按 \(y\) 坐标排序。对于其中的一个点 \((x, y)\) ,可能更新答案的点 \((x', y')\) 必然满足 \(|y - y'| \leq d\) ,那么只要找到附近满足该条件的点更新即可。直觉可以发现这样的点是常数级别的,不然内部就可以更新了。

时间复杂度 \(O(n \log n)\)

#include <bits/stdc++.h>
typedef long long ll;
using namespace std;
const ll inf = 1e18;
const int N = 4e5 + 7;

struct Point {
    int x, y;
} p[N], tmp[N];

int n;

inline ll dist(const Point &a, const Point &b) {
    return 1ll * (a.x - b.x) * (a.x - b.x) + 1ll * (a.y - b.y) * (a.y - b.y);
}

inline ll solve(int l, int r) {
    if (l == r)
        return inf;
        
    int mid = (l + r) >> 1, midx = p[mid].x;
    ll d = min(solve(l, mid), solve(mid + 1, r));
    inplace_merge(p + l, p + mid + 1, p + r + 1, 
        [](const Point &a, const Point &b) { return a.y < b.y; });
    int len = 0;
    
    for (int i = l; i <= r; ++i)
        if (1ll * (midx - p[i].x) * (midx - p[i].x) < d)
            tmp[++len] = p[i];
    
    for (int i = 1; i <= len; ++i)
        for (int j = i + 1; j <= len && 1ll * (tmp[j].y - tmp[i].y) * (tmp[j].y - tmp[i].y) < d; ++j)
            d = min(d, dist(tmp[i], tmp[j]));
    
    return d;
}

signed main() {
    scanf("%d", &n);
    
    for (int i = 1; i <= n; ++i)
        scanf("%d%d", &p[i].x, &p[i].y);
    
    stable_sort(p + 1, p + 1 + n, [](const Point &a, const Point &b) { return a.x < b.x; });
    printf("%lld", solve(1, n));
    return 0;
}

CF429D Tricky Function

给定 \(a_{1 \sim n}\) ,令 \(f(i, j) = (i - j)^2 + g^2(i, j)\) ,其中 \(g(i, j) = \sum_{k = \min(i, j) + 1}^{\max(i, j)} a_k\) ,求 \(\min_{1 \leq i < j \leq n} f(i, j)\)

\(n \leq 10^5\)

首先可以发现 \(f(i, j) = f(j, i)\) ,钦定 \(j < i\) ,则 \(g(i, j) = s_i - s_j\) ,其中 \(s\) 为前缀和,因此 \(f(i, j) = (i - j)^2 + (s_i - s_j)^2\)

发现这个形式很像两点距离公式,问题转化为给定平面上 \(n\) 个点 \((i, s_i)\) ,求最近点对。不难用分治做到 \(O(n \log n)\)

#include <bits/stdc++.h>
typedef long long ll;
using namespace std;
const ll inf = 1e18;
const int N = 1e5 + 7;

struct Point {
    int x, y;
} p[N], tmp[N];

int n;

inline ll dist(const Point &a, const Point &b) {
    return 1ll * (a.x - b.x) * (a.x - b.x) + 1ll * (a.y - b.y) * (a.y - b.y);
}

inline ll solve(int l, int r) {
    if (l == r)
        return inf;
        
    int mid = (l + r) >> 1, midx = p[mid].x;
    ll d = min(solve(l, mid), solve(mid + 1, r));
    inplace_merge(p + l, p + mid + 1, p + r + 1, 
        [](const Point &a, const Point &b) { return a.y < b.y; });
    int len = 0;
    
    for (int i = l; i <= r; ++i)
        if (1ll * (midx - p[i].x) * (midx - p[i].x) < d)
            tmp[++len] = p[i];
    
    for (int i = 1; i <= len; ++i)
        for (int j = i + 1; j <= len && 1ll * (tmp[j].y - tmp[i].y) * (tmp[j].y - tmp[i].y) < d; ++j)
            d = min(d, dist(tmp[i], tmp[j]));
    
    return d;
}

signed main() {
    scanf("%d", &n);
    
    for (int i = 1; i <= n; ++i)
        scanf("%d", &p[i].y), p[i].x = i, p[i].y += p[i - 1].y;
    
    printf("%lld", solve(1, n));
    return 0;
}

基于中点的序列分治

主要就是讨论跨过中点的区间的贡献,一般的维护信息方式是从中间向两边扩展。

CF549F Yura and Developers

给定数组 \(a_{1 \sim n}\) 以及模数 \(k\) ,求满足以下条件的区间 \([l, r]\) 的数量:

  • \(r - l + 1 \geq 2\)
  • \(\sum_{i = l}^r a_i \equiv \max_{i = l}^r a_i \pmod{k}\)

\(n \leq 3 \times 10^5\)\(k \leq 10^6\)

考虑分治,讨论如何计算跨过中点的贡献。先求出以 \(i\) 为左端点且区间最大值出现在左区间的所有右端点的答案,可以用指针维护可行右端点的区间。那么得到 \(s_r \equiv mx + s_{i - 1}\) ,直接用桶存一下即可。

右区间同理,注意不能统计左区间也有最大值的情况,时间复杂度 \(O(n \log n)\)

#include <bits/stdc++.h>
typedef long long ll;
using namespace std;
const int N = 3e5 + 7, K = 1e6 + 7;

int a[N], s[N], cnt[K];

ll ans;
int n, k;

void solve(int l, int r) {
    if (l == r)
        return;
    
    int mid = (l + r) >> 1;
    solve(l, mid), solve(mid + 1, r);
    int j = mid + 1;
    
    for (int i = mid, mx = a[i]; i >= l; mx = max(mx, a[--i])) {
        for (; j <= r && a[j] <= mx; ++j)
            ++cnt[s[j]];
        
        ans += cnt[(s[i - 1] + mx) % k];
    }
    
    for (--j; j > mid; --j)
        --cnt[s[j]];
    
    for (int i = mid + 1, mx = a[i]; i <= r; mx = max(mx, a[++i])) {
        for (; j >= l && a[j] < mx; --j)
            ++cnt[s[j - 1]];
        
        ans += cnt[(s[i] - mx % k + k) % k];
    }

    for (++j; j <= mid; ++j)
        --cnt[s[j - 1]];
}

signed main() {
    scanf("%d%d", &n, &k);
    
    for (int i = 1; i <= n; ++i)
        scanf("%d", a + i), s[i] = (s[i - 1] + a[i]) % k;
    
    solve(1, n);
    printf("%lld", ans);
    return 0;
}

P8317 [FOI2021] 幸运区间

给出长度为 \(n\) 的序列,每个位置上有 \(d\) 个数。需要选出 \(k\) 个数。

如果一个区间内每个位置上的 \(d\) 个数至少有一个出现在选出的 \(k\) 个数中,则是一个幸运区间。

求最长的幸运区间。

\(n \leq 10^5\)\(d \leq 4\)\(k \leq 3\)

考虑分治,每次从中点开始,暴力向两端扩展,不能扩展时加入一个数继续扩展。由于 \(d, k\) 都很小,加入的数直接暴搜即可,时间复杂度 \(O(nd \log n + nd^k)\)

#include <bits/stdc++.h>
using namespace std;
const int N = 1e5 + 7, D = 5;

int a[N][D];
bool chose[N];

int n, d, k, ansl, ansr;

void dfs(const int L, const int R, int cnt, int l, int r) {
    while (l > L) {
        bool flag = false;

        for (int i = 1; i <= d; ++i)
            flag |= chose[a[l - 1][i]];

        if (flag)
            --l;
        else
            break;
    }

    while (r < R) {
        bool flag = false;

        for (int i = 1; i <= d; ++i)
            flag |= chose[a[r + 1][i]];

        if (flag)
            ++r;
        else
            break;
    }

    if (r - l + 1 == ansr - ansl + 1 ? l < ansl : r - l + 1 > ansr - ansl + 1)
        ansl = l, ansr = r;

    if (cnt == k)
        return;

    if (l > L) {
        for (int i = 1; i <= d; ++i)
            chose[a[l - 1][i]] = true, dfs(L, R, cnt + 1, l - 1, r), chose[a[l - 1][i]] = false;
    }

    if (r < R) {
        for (int i = 1; i <= d; ++i)
            chose[a[r + 1][i]] = true, dfs(L, R, cnt + 1, l, r + 1), chose[a[r + 1][i]] = false;
    }
}

void solve(int l, int r) {
    if (l > r)
        return;

    int mid = (l + r) >> 1;
    solve(l, mid - 1), solve(mid + 1, r);

    for (int i = 1; i <= d; ++i)
        chose[a[mid][i]] = true, dfs(l, r, 1, mid, mid), chose[a[mid][i]] = false;
}

signed main() {
    int T;
    scanf("%d", &T);

    for (int task = 1; task <= T; ++task) {
        scanf("%d%d%d", &n, &d, &k);

        for (int i = 1; i <= n; ++i)
            for (int j = 1; j <= d; ++j)
                scanf("%d", a[i] + j);

        ansl = ansr = n + 1, solve(1, n);
        printf("Case #%d: %d %d\n", task, ansl - 1, ansr - 1);
    }

    return 0;
}

CF526F Pudding Monsters

给定一个 \(n \times n\) 的棋盘,其中有 \(n\) 个棋子,每行每列恰好有一个棋子。

对于所有的 \(1 \leq k \leq n\),求有多少个 \(k \times k\) 的子棋盘中恰好有 \(k\) 个棋子。

\(n \le 3 \times 10^5\)

首先转化为统计值域连续区间的数量,即有多少区间满足 \(\max - \min = r - l\)

考虑直接分治:

  • 最大最小值均在左半段:\(mx - mn = j - i \Rightarrow j = i + mx - mn\)
  • 最小值在左半段,最大值在右半段:\(mx_j - mn_i = j - i \Rightarrow mn_i - i = mx_j - j\)
  • 最大最小值均在右半段:与上面对称。
  • 最大值在左半段,最小值在右半段:与上面对称。

时间复杂度 \(O(n \log n)\)

#include <bits/stdc++.h>
typedef long long ll;
using namespace std;
const int N = 3e5 + 7;

int a[N], mx[N], mn[N], cnt[N << 1];

ll ans;
int n;

void solve(int l, int r) {
    if (l == r) {
        ++ans;
        return;
    }
    
    int mid = (l + r) >> 1;
    solve(l, mid), solve(mid + 1, r);
    mn[mid] = mx[mid] = a[mid];

    for (int i = mid - 1; i >= l; --i)
        mn[i] = min(mn[i + 1], a[i]), mx[i] = max(mx[i + 1], a[i]);
    
    mn[mid + 1] = mx[mid + 1] = a[mid + 1];
    
    for (int i = mid + 2; i <= r; ++i)
        mn[i] = min(mn[i - 1], a[i]), mx[i] = max(mx[i - 1], a[i]);
    
    for (int i = mid; i >= l; --i) {
        int j = i + mx[i] - mn[i];
        
        if (mid + 1 <= j && j <= r && mn[i] < mn[j] && mx[j] < mx[i])
            ++ans;
    } // min in left, max in left
    
    for (int j = mid + 1; j <= r; ++j) {
        int i = j - mx[j] + mn[j];
        
        if (l <= i && i <= mid && mn[j] < mn[i] && mx[i] < mx[j])
            ++ans;
    } // min in right, max in right
    
    for (int i = mid, k = mid + 1, j = mid + 1; i >= l; --i) {
        for (; j <= r && mn[j] > mn[i]; ++j)
            ++cnt[mx[j] - j + N];
        
        for (; k < j && mx[k] < mx[i]; ++k)
            --cnt[mx[k] - k + N];
        
        ans += cnt[mn[i] - i + N];
    } // min in left, max in right
    
    for (int i = mid + 1; i <= r; ++i)
        cnt[mx[i] - i + N] = 0;
    
    for (int j = mid + 1, k = mid, i = mid; j <= r; ++j) {
        for (; i >= l && mn[i] > mn[j]; --i)
            ++cnt[mx[i] + i];
        
        for (; k > i && mx[k] < mx[j]; --k)
            --cnt[mx[k] + k];
        
        ans += cnt[mn[j] + j];
    } // min in right, max in left
    
    for (int i = l; i <= mid; ++i)
        cnt[mx[i] + i] = 0;
}

signed main() {
    scanf("%d", &n);
    
    for (int i = 1; i <= n; ++i) {
        int x, y;
        scanf("%d%d", &x, &y);
        a[x] = y;
    }
    
    solve(1, n);
    printf("%lld", ans);
    return 0;
}

基于断点的序列分治

断点:当前分治区间内合法子区间一定不包含的点。

主要思想就是每次按断点将区间分治处理,若没有端点则该区间合法。

寻找断点常用中途相遇法,找一个目标位置时可以从两边开始找,复杂度分析和启发式分裂是一样的。

若还要维护每个点的贡献,则可以钦定处理当前区间时仅存在当前区间的贡献,并在回溯时删除。

每次先删除断点和小区间的贡献,再递归大区间,再加入小区间的贡献,再递归小区间。

值得一提的是这样维护可以支持下标的大小关系不改变。

UVA1608 不无聊的序列 Non-boring sequences

给定 \(a_{1 \sim n}\) ,判断是否存在一个区间,其不存在唯一元素。

\(n \leq 2 \times 10^5\)

考虑分治,每次用唯一元素的位置将其分为两个区间。

采用中途相遇法,从两端同时开始找,若前驱后继都不在该区间内,则其是该区间的唯一元素。

复杂度分析和启发式分裂的复杂度一样,为 \(O(n \log n)\)

#include <bits/stdc++.h>
using namespace std;
const int N = 2e5 + 7;

int a[N], lst[N], nxt[N];

int n;

bool solve(int l, int r) {
    if (l >= r)
        return false;

    int pl = l, pr = r, mid = -1;

    for (int pl = l, pr = r; pl <= pr && mid == -1; ++pl, --pr) {
        if (lst[pl] < l && nxt[pl] > r)
            mid = pl;
        else if (lst[pr] < l && nxt[pr] > r)
            mid = pr;
    }

    return ~mid ? solve(l, mid - 1) || solve(mid + 1, r) : true;
}

signed main() {
    int T;
    scanf("%d", &T);

    while (T--) {
        scanf("%d", &n);
        map<int, int> mp;

        for (int i = 1; i <= n; ++i) {
            scanf("%d", a + i);
            lst[i] = (mp.find(a[i]) == mp.end() ? 0 : mp[a[i]]), mp[a[i]] = i;
        }

        mp.clear();

        for (int i = n; i; --i)
            nxt[i] = (mp.find(a[i]) == mp.end() ? n + 1 : mp[a[i]]), mp[a[i]] = i;

        puts(solve(1, n) ? "boring" : "non-boring");
    }

    return 0;
}

金牌歌手

给定序列 \(a_{1 \sim n}, b_{1 \sim n}\) ,其中 \(b\) 单调不升。求最长区间 \([l, r]\) ,满足区间内的任意元素在区间内的出现次数均 \(\geq b_{r - l + 1}\)

\(n \leq 10^6\)

由于 \(b\) 单调不升,若一个数字使得 \([l, r]\) 不合法,那么 \([l, r]\) 内所有包含该数字的子区间必然不合法。

考虑分治,每次找到一个不合法的位置将区间分为两个部分,问题转化为求当前分治区间内某个数的出现次数。

考虑维护一个桶 \(cnt\) ,钦定每次分治处理时 \(cnt\) 恰为当前区间内数字的出现次数,并在回溯时清空。那么只要每次将小区间里的数字一个个删除,然后递归大区间,再一个个加入小区间里的数,再递归小区间即可。

时间复杂度 \(O(n \log n)\)

#include <bits/stdc++.h>
using namespace std;
const int N = 1e6 + 7;

int a[N], b[N], cnt[N];

int n;

int solve(int l, int r) {
    if (l > r)
        return 0;

    int mid = -1;

    for (int pl = l, pr = r; pl <= pr && mid == -1; ++pl, --pr) {
        if (cnt[a[pl]] < b[r - l + 1])
            mid = pl;
        else if (cnt[a[pr]] < b[r - l + 1])
            mid = pr;
    }

    if (mid == -1) {
        for (int i = l; i <= r; ++i)
            --cnt[a[i]];

        return r - l + 1;
    } else if (mid - l > r - mid) {
        for (int i = mid; i <= r; ++i)
            --cnt[a[i]];

        int res = solve(l, mid - 1);

        for (int i = mid + 1; i <= r; ++i)
            ++cnt[a[i]];

        return max(res, solve(mid + 1, r));
    } else {
        for (int i = l; i <= mid; ++i)
            --cnt[a[i]];

        int res = solve(mid + 1, r);

        for (int i = l; i < mid; ++i)
            ++cnt[a[i]];

        return max(res, solve(l, mid - 1));
    }
}

signed main() {
    scanf("%d", &n);

    for (int i = 1; i <= n; ++i)
        scanf("%d", a + i), ++cnt[a[i]];

    for (int i = 1; i <= n; ++i)
        scanf("%d", b + i);

    printf("%d", solve(1, n));
    return 0;
}

二维分治

一般的分治方式是每次切割边长较大者的中线。

CF364E Empty Rectangles

有一个 \(n \times m\) 的01矩阵,询问有多少个子矩阵和为 \(k\)

\(n, m \leq 2500\)\(k \leq 6\)

考虑跨中线的答案,以竖直切割为例。枚举 \(y\) 坐标作为上下边界,维护 \(left_i, right_i\) 分别表示中线左右矩形中 \(1\) 的个数 \(< i\) 时最多扩展到的位置,统计答案时枚举 \(k\)\(1\) 的分布即可。

#include <bits/stdc++.h>
typedef long long ll;
using namespace std;
const int N = 2.5e3 + 7, K = 11;

int s[N][N];
char str[N];

ll ans;
int n, m, k;

inline int query(int xl, int xr, int yl, int yr) {
    return s[xr][yr] - s[xl - 1][yr] -  s[xr][yl - 1] + s[xl - 1][yl - 1];
}

void solve(int xl, int xr, int yl, int yr) {
    if (xl == xr && yl == yr) {
        ans += (query(xl, xr, yl, yr) == k);
        return;
    }
    
    if (xr - xl > yr - yl) {
        int mid = (xl + xr) >> 1;
        solve(xl, mid, yl, yr), solve(mid + 1, xr, yl, yr);
        
        for (int l = yl; l <= yr; ++l) {
            vector<int> left(k + 2, xl), right(k + 2, xr);
            left[0] = mid + 1, right[0] = mid;
            
            for (int r = l; r <= yr; ++r) {
                for (int i = 1; i <= k + 1; ++i) {
                    while (query(left[i], mid, l, r) >= i)
                        ++left[i];
                    
                    while (query(mid + 1, right[i], l, r) >= i)
                        --right[i];
                }
                
                for (int i = 0; i <= k; ++i)
                    ans += 1ll * (left[i] - left[i + 1]) * (right[k - i + 1] - right[k - i]);
            }
        }
    } else {
        int mid = (yl + yr) >> 1;
        solve(xl, xr, yl, mid), solve(xl, xr, mid + 1, yr);
        
        for (int l = xl; l <= xr; ++l) {
            vector<int> down(k + 2, yl), up(k + 2, yr);
            down[0] = mid + 1, up[0] = mid;
            
            for (int r = l; r <= xr; ++r) {
                for (int i = 1; i <= k + 1; ++i) {
                    while (query(l, r, down[i], mid) >= i)
                        ++down[i];
                    
                    while (query(l, r, mid + 1, up[i]) >= i)
                        --up[i];
                }
                
                for (int i = 0; i <= k; ++i)
                    ans += 1ll * (down[i] - down[i + 1]) * (up[k - i + 1] - up[k - i]);
            }
        }
    }
}

signed main() {
    scanf("%d%d%d", &n, &m, &k);
    
    for (int i = 1; i <= n; ++i) {
        scanf("%s", str + 1);
        
        for (int j = 1; j <= m; ++j)
            s[i][j] = s[i - 1][j] + s[i][j - 1] - s[i - 1][j - 1] + (str[j] == '1');
    }
    
    solve(1, n, 1, m);
    printf("%lld", ans);
    return 0;
}

笛卡尔树分治

即最值分治,每次按最值的位置将区间分为两部分,可以较好地处理区间最值的限制。

需要保证每层的复杂度仅与小区间有关,否则会退化到平方级别。

通常是按照启发式分裂的套路,枚举短区间,算长区间的贡献。

P4755 Beautiful Pair

给定序列 \(a_{1 \sim n}\) ,求有多少对 \(i \leq j\) 满足 \(a_i \times a_j \leq \max_{k = i}^j a_k\)

\(n \leq 10^5\)

构建笛卡尔树结构,设当前分治到区间 \([l, r]\) ,当前笛卡尔树上的节点在原序列上的编号是 \(p\)

根据分治套路,我们只需要计算跨过 \(p\) 的答案,然后 \([l, p - 1]\)\([p + 1, r]\) 分治处理。

以左边为短区间为例,考虑枚举左边,算右边的贡献。可以枚举点对的左端点 \(l\) ,则 \(a_r \leq \frac{a_p}{a_l}\)

时间复杂度 \(O(n \log^2 n)\)

#include <bits/stdc++.h>
typedef long long ll;
using namespace std;
const int N = 1e5 + 7;

int a[N];

ll ans;
int n;

template <class T = int>
inline T read() {
    char c = getchar();
    bool sign = c == '-';
    
    while (c < '0' || c > '9')
        c = getchar(), sign |= c == '-';
    
    T x = 0;
    
    while ('0' <= c && c <= '9')
        x = (x << 1) + (x << 3) + (c & 15), c = getchar();
    
    return sign ? (~x + 1) : x;
}

namespace SMT {
const int S = N << 5;

int s[S], lc[S], rc[S];
int rt[N];

int tot;

int insert(int x, int nl, int nr, int k) {
    int y = ++tot;
    lc[y] = lc[x], rc[y] = rc[x], s[y] = s[x] + 1;
    
    if (nl == nr)
        return y;
    
    int mid = (nl + nr) >> 1;
    
    if (k <= mid)
        lc[y] = insert(lc[x], nl, mid, k);
    else
        rc[y] = insert(rc[x], mid + 1, nr, k);
    
    return y;
}

int query(int x, int y, int nl, int nr, int k) {
    if (nl == nr)
        return s[y] - s[x];
    
    int mid = (nl + nr) >> 1;
    
    if (k <= mid)
        return query(lc[x], lc[y], nl, mid, k);
    else
        return (s[lc[y]] - s[lc[x]]) + query(rc[x], rc[y], mid + 1, nr, k);
}
} // namespace SMT

namespace CST {
int lc[N], rc[N];

int root;

inline void build() {
    static int sta[N];

    for (int i = 1, top = 0; i <= n; ++i) {
        int k = top;
        
        while (k && a[sta[k]] < a[i])
            --k;
        
        if (k)
            rc[sta[k]] = i;
        else
            root = i;
        
        if (k < top)
            lc[i] = sta[k + 1];
        
        sta[top = ++k] = i;
    }
}

void solve(int x, int l, int r) {
    if (x - l <= r - x) {
        for (int i = l; i <= x; ++i)
            ans += SMT::query(SMT::rt[x - 1], SMT::rt[r], 1, 1e9, a[x] / a[i]);
    } else {
        for (int i = x; i <= r; ++i)
            ans += SMT::query(SMT::rt[l - 1], SMT::rt[x], 1, 1e9, a[x] / a[i]);
    }
    
    if (lc[x])
        solve(lc[x], l, x - 1);
    
    if (rc[x])
        solve(rc[x], x + 1, r);
}
} // namespace CST

signed main() {
    n = read();
    
    for (int i = 1; i <= n; ++i)
        SMT::rt[i] = SMT::insert(SMT::rt[i - 1], 1, 1e9, a[i] = read());
    
    CST::build(), CST::solve(CST::root, 1, n);
    printf("%lld", ans);
    return 0;
}

Safe Partition

给定 \(a_{1 \sim n}\) ,需要将其划分为若干段,每段 \(S\) 均要满足 \(\min_{x \in S} a_x \leq |S| \leq \max_{x \in S} a_x\) ,求划分方案数 \(\bmod 10^9 + 7\)

\(n \leq 5 \times 10^5\)

\(f_i\) 表示 \(1 \sim i\) 的划分方案数,分别考虑两个限制。

对于 \(\min_{x \in S} a_x \leq |S|\) 的限制,扩大 \(S\) 时,前者不升而后者递增,记 \(L_i, R_i\) 表示 \(j \in [1, L_i] \cup [R_i, n]\)\([j, i]\) 合法,由于 \(L_{i - 1} \leq L_i, R_i \leq R_{i + 1}\) ,因此可以维护指针线性求出。

对于 \(|S| \leq \max_{x \in S} a_x\) 的限制,扩大 \(S\) 时,前者不降而后者递增,直接处理是困难的。考虑建出大根笛卡尔树,记当前处理区间为 \([l, r]\) ,最大值在 \(x\) ,需要统计 \(l \leq i \leq x \leq j \leq r\) 且满足 \([i, j]\) 合法的 \(f_{i - 1} \to j\) 的转移。按照套路,需要枚举小区间。

  • 左区间更小:枚举所有 \(i\) ,则 \(j \in [\max(R_i, x), \min(i + a_x - 1, r)]\) ,直接维护差分做区间加即可。
  • 右区间更小:枚举所有 \(j\) ,则 \(i \in [\max(j - a_x + 1, l), \min(L_i, x)]\) ,直接维护前缀和查询区间和即可。

时间复杂度 \(O(n \log n)\)

#include <bits/stdc++.h>
using namespace std;
const int Mod = 1e9 + 7;
const int N = 5e5 + 7;

int a[N], L[N], R[N], f[N], s[N], c[N];

int n;

inline int add(int x, int y) {
    x += y;
    
    if (x >= Mod)
        x -= Mod;
    
    return x;
}

inline int dec(int x, int y) {
    x -= y;
    
    if (x < 0)
        x += Mod;
    
    return x;
}

namespace CST {
int lc[N], rc[N], sta[N];

int root;

inline void build() {
    for (int i = 1, top = 0; i <= n; ++i) {
        int k = top;

        while (k && a[sta[k]] <= a[i])
            --k;

        if (k)
            rc[sta[k]] = i;
        else
            root = i;

        if (k < top)
            lc[i] = sta[k + 1];

        sta[top = ++k] = i;
    }
}

void solve(int x, int l, int r) {
    if (lc[x])
        solve(lc[x], l, x - 1);

    if (x - l < r - x) {
        for (int i = l; i <= x; ++i) {
            int p = max(R[i], x), q = min(i + a[x] - 1, r);

            if (p <= q)
           		c[p] = add(c[p], f[i - 1]), c[q + 1] = dec(c[q + 1], f[i - 1]);
        }
    } else {
        for (int i = x; i <= r; ++i) {
            int p = max(i - a[x] + 1, l), q = min(L[i], x);

            if (p <= q) {
                if (p >= 2)
                    f[i] = add(f[i], dec(s[q - 1], s[p - 2]));
                else
                    f[i] = add(f[i], s[q - 1]);
            }
        }
    }
	
    s[x] = add(s[x - 1], f[x] = add(f[x], c[x] = add(c[x], c[x - 1])));

    if (rc[x])
        solve(rc[x], x + 1, r);
}
} // namespace CST

signed main() {
    scanf("%d", &n);

    for (int i = 1; i <= n; ++i)
        scanf("%d", a + i);

    deque<int> q;

    for (int i = 1, j = 0; i <= n; ++i) {
        while (!q.empty() && a[q.back()] >= a[i])
        	q.pop_back();
        
        q.emplace_back(i);
        
        while (j + 1 <= i && a[q.front()] <= i - j)
        	if (q.front() == (++j))
        		q.pop_front();

        L[i] = j;
    }
    
    q.clear();

    for (int i = n, j = n + 1; i; --i) {
        while (!q.empty() && a[q.back()] >= a[i])
        	q.pop_back();
        
        q.emplace_back(i);
        
        while (j - 1 >= i && a[q.front()] <= j - i)
        	if (q.front() == (--j))
        		q.pop_front();

        R[i] = j;
    }

    CST::build(), s[0] = f[0] = 1, CST::solve(CST::root, 1, n);
    printf("%d", f[n]);
    return 0;
}

线段树分治

线段树分治是一种维护时间区间的数据结构,并利用线段树的性质使得复杂度保证在 \(\log\) 级别。

假设一个操作影响的时间区间是 \([L, R]\) ,将其放到线段树上就会分成 \(O(\log n)\) 段小区间。

对于一个询问,我们只要在线段树上找到设个询问所在的时间点,把根到叶子节点上路径的所有影响合并起来就可以得出答案。

一般选择一些可以支持撤销操作的数据结构,这样遍历完一个子树后撤销操作即可。

由于线段树分治会把一个状态传递给两个儿子(类似可持久化的操作树),而并不是经典的一条时间轴,所以均摊分析在线段树分治中会失效。

P5787 二分图 /【模板】线段树分治

给出一个无向图,每条边都有一个存在时间区间,询问每个时刻该图是不是二分图。

\(n \leq 10^5, m \leq 2 \times 10^5\)

用扩展域并查集判定二分图即可。

#include <bits/stdc++.h>
using namespace std;
const int N = 1e5 + 7, M = 2e5 + 7;

struct Node {
	int x, y, add;
};

struct Edge {
	int u, v;
} e[M];

int n, m, k;

template <class T = int>
inline T read() {
	char c = getchar();
	bool sign = (c == '-');
	
	while (c < '0' || c > '9')
		c = getchar(), sign |= (c == '-');
	
	T x = 0;
	
	while ('0' <= c && c <= '9')
		x = (x << 1) + (x << 3) + (c & 15), c = getchar();
	
	return sign ? (~x + 1) : x;
}

namespace DSU {
stack<pair<int, int> > sta;

int fa[N << 1], siz[N << 1];

inline void prework(int n) {
	iota(fa + 1, fa + 1 + n, 1);
	fill(siz + 1, siz + 1 + n, 1);
}

inline int find(int x) {
    while (x != fa[x])
        x = fa[x];

    return x;
}

inline void merge(int x, int y) {
	x = find(x), y = find(y);
	
	if (siz[x] < siz[y])
		swap(x, y);

	sta.emplace(y, siz[x] == siz[y]);
	siz[x] += (siz[x] == siz[y]), fa[y] = x;
}

inline void restore(int top) {
	while (sta.size() > top) {
		int x = sta.top().first, k = sta.top().second;
		sta.pop();
		siz[fa[x]] -= k, fa[x] = x;
	}
}
} // namespace DSU

namespace SMT {
vector<pair<int, int> > upd[N << 2];

inline int ls(int x) {
	return x << 1;
}

inline int rs(int x) {
	return x << 1 | 1;
}

void update(int x, int nl, int nr, int l, int r, auto k) {
	if (l <= nl && nr <= r) {
		upd[x].push_back(k);
		return;
	}
	
	int mid = (nl + nr) >> 1;
	
	if (l <= mid)
		update(ls(x), nl, mid, l, r, k);
	
	if(r > mid)
		update(rs(x), mid + 1, nr, l, r, k);
}

void dfs(int x, int l, int r) {
	int top = DSU::sta.size();
	bool flag = true;
	
	for (auto it : upd[x]) {
		int u = it.first, v = it.second;

		if (DSU::find(u) == DSU::find(v)) {
			for (int k = l; k <= r; ++k)
				puts("No");
			
			flag = false;
			break;
		}
		
		DSU::merge(u, v + n), DSU::merge(v, u + n);
	}
	
	if (flag) {
		if (l == r)
			puts("Yes");
		else {
			int mid = (l + r) >> 1;
			dfs(ls(x), l, mid), dfs(rs(x), mid + 1, r);
		}
	}
	
	DSU::restore(top);
}
} // namespace SMT

signed main() {
	n = read(), m = read(), k = read();
	
	for (int i = 1; i <= m; ++i) {
		int u = read(), v = read(), l = read(), r = read();
		
		if (l < r)
			SMT::update(1, 1, k, l + 1, r, make_pair(u, v));
	}
	
	DSU::prework(n * 2);
	SMT::dfs(1, 1, k);
	return 0;
}

P5416 [CTSC2016]时空旅行

维护若干集合,每个集合都是由一个编号小于它的集合扩展而来,扩展内容为加入一个二元组 \((x, c)\) 或删除一个二元组 \((x, c)\) 。集合的扩展关系构成一个树形结构。 \(m\) 次询问第 \(s\) 个集合中 \((x - X)^2 + c\) 的最小值。

\(n, m \leq 5 \times 10^5\)

我们将树形结构建立出来之后,可以发现一个元素就是它第一次出现的点的子树中去掉把它删除点的子树,即可以看作一段DFS序连续的区间。

我们把这个东西丢到线段树上,考虑如何计算答案。我们在线段树的每个节点维护一个关于 \((x - X)^2 + c\) 的凸壳,然后把询问都插到线段树里,递归整棵线段树即可。维护凸壳可以在插入前先将凸壳的先后顺序排好再插入。

时间复杂度 \(O(n \log n)\)

#include <bits/stdc++.h>
typedef long long ll;
using namespace std;
const ll inf = 1e18;
const int N = 5e5 + 7;

struct Graph {
	vector<int> e[N];
	
	inline void insert(int u, int v) {
		e[u].emplace_back(v);
	}
} G;

struct Query {
	int x, k, id;
	
	inline bool operator < (const Query &b) const {
		return k < b.k;
	}
} qry[N];

vector<int> w[N << 2];
vector<int> updl[N], updr[N];

ll cost[N], ans[N], val[N];
int L[N << 2], R[N << 2];
int  a[N], dfn[N], tag[N];

int n, m, cnt, dfstime;

template <class T = int>
inline T read() {
	char c = getchar();
	bool sign = (c == '-');
	
	while (c < '0' || c > '9')
		c = getchar(), sign |= (c == '-');
	
	T x = 0;
	
	while ('0' <= c && c <= '9')
		x = (x << 1) + (x << 3) + (c & 15), c = getchar();
	
	return sign ? (~x + 1) : x;
}

void dfs(int u, int fa) {
	dfn[u] = ++dfstime;
	
	if (tag[u] > 0)
		updl[tag[u]].emplace_back(dfstime);
	else if (tag[u] < 0)
		updr[-tag[u]].emplace_back(dfstime - 1);
	
	for (int v : G.e[u])
		if (v != fa)
			dfs(v, u);
	
	if (tag[u] > 0)
		updr[tag[u]].emplace_back(dfstime);
	else if (tag[u] < 0)
		updl[-tag[u]].emplace_back(dfstime + 1);
}

inline int ls(int x) {
	return x << 1;
}

inline int rs(int x) {
	return x << 1 | 1;
}

inline double slope(int x, int y) {
	return (double) (val[x] * val[x] + cost[x] - val[y] * val[y] - cost[y]) / (val[x] - val[y]);
}

void update(int x, int nl, int nr, int l, int r, int k) {
	if (l == nl && nr == r) {
		w[x].resize(R[x] + 2);
		
		if (L[x] <= R[x] && val[w[x][R[x]]] == val[k]) {
			if (cost[w[x][R[x]]] <= cost[k])
				return;
			
			--R[x];
		}
		
		while (L[x] < R[x] && slope(w[x][R[x]], k) < slope(w[x][R[x]], w[x][R[x] - 1]))
			--R[x];
		
		w[x][R[x] + 1] = k, ++R[x];
		return;
	}
	
	int mid = (nl + nr) >> 1;
	
	if (r <= mid)
		update(ls(x), nl, mid, l, r, k);
	else if (l > mid)
		update(rs(x), mid + 1, nr, l, r, k);
	else
		update(ls(x), nl, mid, l, mid, k), update(rs(x), mid + 1, nr, mid + 1, r, k);
}

ll query(int x, int nl, int nr, int pos, int k) {
	while (L[x] < R[x] && slope(w[x][L[x]], w[x][L[x] + 1]) <= 2.0 * k)
		++L[x];

	ll res = L[x] <= R[x] && !w[x].empty() ? 1ll * (k - val[w[x][L[x]]]) * (k - val[w[x][L[x]]]) + cost[w[x][L[x]]] : inf;
	
	if (nl == nr)
		return res;

	int mid = (nl + nr) >> 1;
	
	if (pos <= mid)
		return min(res, query(ls(x), nl, mid, pos, k));
	else
		return min(res, query(rs(x), mid + 1, nr, pos, k));
}

signed main() {
	n = read(), m = read(), cost[0] = read<ll>();
	
	for (int i = 1; i < n; ++i) {
		int op = read(), u = read();
		ll x = read<ll>();
		
		if (op)
			tag[i] = -x;
		else {
			val[x] = read<ll>(), read(), read(), cost[x] = read<ll>();
			tag[i] = x;
		}
		
		G.insert(i, u), G.insert(u, i);
	}
	
	dfs(0, 0);
	memset(R, -1, sizeof(R));
	update(1, 1, n, 1, n, 0);
	vector<int> id(n);
	iota(id.begin(), id.end(), 1);
	sort(id.begin(), id.end(), [](const int &a, const int &b) { return val[a] < val[b]; });
	
	for (int x : id)
		for (int j = 0; j < updl[x].size(); ++j)
			if (updl[x][j] <= updr[x][j])
				update(1, 1, n, updl[x][j], updr[x][j], x);
	
	for (int i = 1; i <= m; ++i)
		qry[i].x = read(), qry[i].k = read(), qry[i].id = i;
	
	sort(qry + 1, qry + m + 1);
	
	for (int i = 1; i <= m; ++i)
		ans[qry[i].id] = query(1, 1, n, dfn[qry[i].x], qry[i].k);
	
	for (int i = 1; i <= m; ++i)
		printf("%lld\n", ans[i]);
	
	return 0;
}

P5227 [AHOI2013] 连通图

给定一张无向图,每次询问删除 \(c_i\) 条边,求整张图是否连通,询问之间独立。

\(n, q \leq 2 \times 10^5\)\(m \leq 2 \times 10^5\)\(c_i \leq 4\)

\(cnt_i\) 表示第 \(i\) 条边被删掉了几次,递归左区间时就将右区间涉及到的边的 \(cnt\) 减一,当 \(cnt\) 变为 \(0\) 时加入这条边,递归处理即可,记得回溯时撤销。

本质和处理每条边的出现时间是一样的,时间复杂度 \(O(qc \log q \log n)\)

#include <bits/stdc++.h>
using namespace std;
const int N = 1e5 + 7, M = 2e5 + 7;

struct Edge {
    int u, v;
} e[M];

struct DSU {
    int fa[N], siz[N], sta[N];

    int top;
    
    inline void prework(int n) {
        iota(fa + 1, fa + n + 1, 1);
        fill(siz + 1, siz + n + 1, 1);
    }
    
    inline int find(int x) {
        while (x != fa[x])
            x = fa[x];
    
        return x;
    }
    
    inline void merge(int x, int y) {
        x = find(x), y = find(y);

        if (x == y)
            return;

        if (siz[x] < siz[y])
            swap(x, y);

        sta[++top] = y, siz[fa[y] = x] += siz[y];
    }

    inline void restore(int k) {
        while (top > k) {
            int x = sta[top--];
            siz[fa[x]] -= siz[x], fa[x] = x;
        }
    }
} dsu;

vector<int> rmv[N];

int cnt[M];

int n, m, q;

template <class T = int>
inline T read() {
    char c = getchar();
    bool sign = (c == '-');
    
    while (c < '0' || c > '9')
        c = getchar(), sign |= (c == '-');
    
    T x = 0;
    
    while ('0' <= c && c <= '9')
        x = (x << 1) + (x << 3) + (c & 15), c = getchar();
    
    return sign ? (~x + 1) : x;
}

void solve(int l, int r) {
    if (dsu.top == n - 1) {
        for (int i = l; i <= r; ++i)
            puts("Connected");

        return;
    }

    if (l == r) {
        puts("Disconnected");
        return;
    }

    int mid = (l + r) >> 1, top = dsu.top;

    for (int i = mid + 1; i <= r; ++i)
        for (int it : rmv[i]) {
            --cnt[it];

            if (!cnt[it])
                dsu.merge(e[it].u, e[it].v);
        }

    solve(l, mid), dsu.restore(top);

    for (int i = mid + 1; i <= r; ++i)
        for (int it : rmv[i])
            ++cnt[it];

    for (int i = l; i <= mid; ++i)
        for (int it : rmv[i]) {
            --cnt[it];

            if (!cnt[it])
                dsu.merge(e[it].u, e[it].v);
        }

    solve(mid + 1, r), dsu.restore(top);

    for (int i = l; i <= mid; ++i)
        for (int it : rmv[i])
            ++cnt[it];
}

signed main() {
    n = read(), m = read();

    for (int i = 1; i <= m; ++i)
        e[i].u = read(), e[i].v = read();

    q = read();

    for (int i = 1; i <= q; ++i) {
        rmv[i].resize(read());

        for (int &it : rmv[i])
            ++cnt[it = read()];
    }

    dsu.prework(n);

    for (int i = 1; i <= m; ++i)
        if (!cnt[i])
            dsu.merge(e[i].u, e[i].v);
        
    solve(1, q);
    return 0;
}

P3206 [HNOI2010] 城市建设

给定一张图支持动态的修改边权,要求在每次修改边权之后输出这张图 MST 的边权和。

\(n \leq 2 \times 10^4\)\(m, q \leq 5 \times 10^4\)

考虑对时间进行线段树分治,假设现在处理 \([l, r]\) ,称 \([l, r]\) 涉及到的边为动态边,其余边为静态边。

每层先用并查集将一定没用的静态边删去,从而简化了静态边集。

接下来考虑如何将 \([l, r]\) 拆分为 \([l, mid]\)\([mid + 1, r]\) 递归处理,可以发现另一个区间的动态边会变为静态边,于是动态边会不断的变成静态边,最后变成一个纯静态的MST。

对于静态边的处理,实现上有两个 trick:

  • 若都不选动态边跑一遍 MST ,此时不在 MST 中的边就可以直接删除了。
    • 这样保证了每一层边的规模与点同阶。
  • 若先选动态边后跑一遍 MST ,此时在 MST 中的静态边递归时一定仍在MST中,因此可以直接将其统计贡献并将这些边所连接的连通块缩点。
    • 这样保证了点的规模为区间长度。

注意左区间对右区间和右区间对左区间的贡献处理有一点小差异,用可撤销并查集维护,时间复杂度 \(O(n \log^2 n)\)

#include <bits/stdc++.h>
typedef long long ll;
using namespace std;
const int N = 2e4 + 7, M = 5e4 + 7, LOGN = 21;

struct DSU {
    int fa[N], siz[N], sta[N];

    int top;

    inline void prework(int n) {
        iota(fa + 1, fa + n + 1, 1), fill(siz + 1, siz + n + 1, 1);
    }

    inline int find(int x) {
        while (x != fa[x])
            x = fa[x];

        return x;
    }

    inline void merge(int x, int y) {
        x = find(x), y = find(y);

        if (x == y)
            return;

        if (siz[x] < siz[y])
            swap(x, y);

        siz[x] += siz[y], fa[y] = x, sta[++top] = y;
    }

    inline void cancel() {
        int y = sta[top--], x = fa[y];
        fa[y] = y, siz[x] -= siz[y];
    }

    inline void restore(int k = 0) {
        while (top > k)
            cancel();
    }
} dsu1, dsu2;

struct Edge {
    int u, v, w, tag;

    inline Edge(int _u = 0, int _v = 0, int _w = 0) : u(_u), v(_v), w(_w), tag(0) {}

    inline bool operator < (const Edge &rhs) const {
        return w < rhs.w;
    }
} e[M];

struct Update {
    int x, k;
} upd[M];

vector<Edge> stc[LOGN], dnc; // static edge / dynamic edge

ll ans[M], res[LOGN];
bool vis[M];

int n, m, q;

inline void pushdown(int d) {
    vector<Edge> vec = stc[d];
    sort(vec.begin(), vec.end());

    for (Edge &it : vec) {
        if (dsu2.find(it.u) == dsu2.find(it.v))
            it.tag = -1; // useless edge
        else
            dsu2.merge(it.u, it.v);
    }

    dsu2.restore();

    for (Edge it : dnc)
        dsu2.merge(it.u, it.v);

    dnc.clear(), res[d + 1] = res[d];

    for (Edge &it : vec) { // essential edge
        if (it.tag == -1 || dsu2.find(it.u) == dsu2.find(it.v))
            continue;

        dsu2.merge(it.u, it.v), dsu1.merge(it.u, it.v);
        it.tag = 1, res[d + 1] += it.w;
    }

    dsu2.restore(), stc[d + 1].clear();

    for (Edge it : vec)
        if (!it.tag && dsu1.find(it.u) != dsu1.find(it.v))
            stc[d + 1].emplace_back(dsu1.find(it.u), dsu1.find(it.v), it.w);
}

void solve(int l, int r, int d) {
    if (l == r) {
        stc[d].emplace_back(dsu1.find(e[upd[l].x].u), dsu1.find(e[upd[l].x].v), upd[l].k);
        e[upd[l].x].w = upd[l].k, pushdown(d), ans[l] = res[d + 1], stc[d].pop_back();
        return;
    }

    int mid = (l + r) >> 1, lsttop = dsu1.top;

    for (int i = l; i <= mid; ++i) // update -> dynamic
        dnc.emplace_back(dsu1.find(e[upd[i].x].u), dsu1.find(e[upd[i].x].v)), vis[upd[i].x] = true;

    for (int i = mid + 1; i <= r; ++i) // dynamic -> static
        if (!vis[upd[i].x])
            stc[d].emplace_back(dsu1.find(e[upd[i].x].u), dsu1.find(e[upd[i].x].v), e[upd[i].x].w);

    pushdown(d);

    for (int i = mid + 1; i <= r; ++i)
        if (!vis[upd[i].x])
            stc[d].pop_back();

    for (int i = l; i <= mid; ++i)
        vis[upd[i].x] = false;

    solve(l, mid, d + 1), dsu1.restore(lsttop);

    for (int i = mid + 1; i <= r; ++i)
        vis[upd[i].x] = true;

    for (int i = l; i <= mid; ++i) // dynamic -> static
        if (!vis[upd[i].x])
            stc[d].emplace_back(dsu1.find(e[upd[i].x].u), dsu1.find(e[upd[i].x].v), e[upd[i].x].w);

    for (int i = mid + 1; i <= r; ++i) // update -> dynamic
        dnc.emplace_back(dsu1.find(e[upd[i].x].u), dsu1.find(e[upd[i].x].v)), vis[upd[i].x] = false;

    pushdown(d), solve(mid + 1, r, d + 1), dsu1.restore(lsttop);
}

signed main() {
    scanf("%d%d%d", &n, &m, &q);

    for (int i = 1; i <= m; ++i)
        scanf("%d%d%d", &e[i].u, &e[i].v, &e[i].w);

    for (int i = 1; i <= q; ++i) {
        scanf("%d%d", &upd[i].x, &upd[i].k);
        dnc.emplace_back(e[upd[i].x]), vis[upd[i].x] = true;
    }

    for (int i = 1; i <= m; ++i)
        if (!vis[i])
            stc[1].emplace_back(e[i]);

    for (int i = 1; i <= q; ++i)
        vis[upd[i].x] = false;

    dsu1.prework(n), dsu2.prework(n), solve(1, q, 1);

    for (int i = 1; i <= q; ++i)
        printf("%lld\n", ans[i]);

    return 0;
}

猫树分治

对于一类题目,其具有如下特征

  • 询问数量庞大:单次询问尽量不能带 \(\log\)
  • 询问可以离线。
  • 空间要求严苛:否则直接建立猫树在线询问即可。

考虑分治区间 \([l, r]\) 的处理。首先 \(l = r\) 的情况是 trivial 的,否则对于 \(i \in [l, mid]\) ,扫一遍处理 \([i, mid]\) 的答案;对于 \(i \in [mid + 1, r]\) ,扫一遍处理 \([mid + 1, r]\) 的答案。对于一个询问 \([ql, qr]\)(这里只考虑跨过 \(mid\) 的情况,否则可以递归处理),将 \([ql, mid]\)\([mid + 1, qr]\) 的答案合并即可。

不难发现,上述做法通过分治把 \(\log\) 的复杂度和单次插入的复杂度摊到了一起,这样每次询问时只用合并一次。记插入的复杂度为 \(T(n)\) ,那么上述算法的时间复杂度就是 \(O(n \log n + q \times T(n))\)

基于中点的序列分治通常是维护全局信息,维护区间信息的分治结构通常被称为猫树分治。

本质就是建出分治树,然后把所有猫树上同一节点的询问一起处理。

P6240 好吃的题目

给定 \(n\) 个物品,每个物品有体积和价值,\(m\) 次询问区间 \([l, r]\) 内的物品做01背包后容量为 \(t\) 的最大价值。

\(n \leq 4 \times 10^4\)\(m \leq 2 \times 10^5\)\(t \leq 200\)

如果直接建出猫树,空间复杂度为 \(O(n \log n \times t)\) ,无法通过。

如果上线段树,时间复杂度为 \(O(m \log n \times t)\) ,无法通过。

采用猫树分治,时间复杂度 \(O(n \log n \times t + qt)\) ,空间复杂度 \(O(nt)\) ,可以通过。

#include <bits/stdc++.h>
using namespace std;
const int N = 4e4 + 7, M = 2e5 + 7, V = 2e2 + 7;

struct Query {
    int l, r, t, id;
};

int f[N][V];
int h[N], w[N], ans[M];

int n, m;

void solve(int l, int r, vector<Query> &qry) {
    if (l == r) {
        for (Query it : qry)
            ans[it.id] = (h[l] <= it.t ? w[l] : 0);

        return;
    }

    int mid = (l + r) >> 1;
    memset(f[mid], 0, sizeof(f[mid])), fill(f[mid] + h[mid], f[mid] + V, w[mid]);

    for (int i = mid - 1; i >= l; --i) {
        memcpy(f[i], f[i + 1], sizeof(f[i]));

        for (int j = V - 1; j >= h[i]; --j)
            f[i][j] = max(f[i + 1][j], f[i + 1][j - h[i]] + w[i]);
    }

    memset(f[mid + 1], 0, sizeof(f[mid + 1])), fill(f[mid + 1] + h[mid + 1], f[mid + 1] + V, w[mid + 1]);

    for (int i = mid + 2; i <= r; ++i) {
        memcpy(f[i], f[i - 1], sizeof(f[i]));

        for (int j = V - 1; j >= h[i]; --j)
            f[i][j] = max(f[i - 1][j], f[i - 1][j - h[i]] + w[i]);
    }

    vector<Query> ql, qr;

    for (Query it : qry) {
        if (it.r <= mid)
            ql.emplace_back(it);
        else if (it.l > mid)
            qr.emplace_back(it);
        else {
            for (int i = 0; i <= it.t; ++i)
                ans[it.id] = max(ans[it.id], f[it.l][i] + f[it.r][it.t - i]);
        }
    }

    solve(l, mid, ql), solve(mid + 1, r, qr);
}

signed main() {
    scanf("%d%d", &n, &m);

    for (int i = 1; i <= n; ++i)
        scanf("%d", h + i);

    for (int i = 1; i <= n; ++i)
        scanf("%d", w + i);

    vector<Query> qry;

    for (int i = 1; i <= m; ++i) {
        int l, r, t;
        scanf("%d%d%d", &l, &r, &t);
        qry.emplace_back((Query) {l, r, t, i});
    }

    solve(1, n, qry);

    for (int i = 1; i <= m; ++i)
        printf("%d\n", ans[i]);

    return 0;
}

P6109 [Ynoi2009] rprmq1

给出一个 \(n \times n\) 的矩阵,先有 \(m\) 次矩形加某数的操作,然后 \(q\) 次查询矩形最大值。

\(n, m \leq 5 \times 10^4\)\(q \leq 5 \times 10^5\)

先将修改差分,用线段树维护 \(y\) 轴。

考虑询问时 \(l_1 = r_1\) 要怎么做,只要先进行 \(x \leq l_1\) 的修改,然后查询 \([l_2, r_2]\) 的区间最大值即可。

然后考虑一般情况,暴力的想法是枚举所有 \(x \in [l_1, r_2]\) ,按 \(l_1 = r_1 = x\) 的方法处理。不难发现,若记 \(x = l_1\) 为初始版本,则将 \(l_1 \sim r_1\) 的修改操作全部完成后,答案即为 \([l_2, r_2]\) 的历史最大值。注意修改需要先负后正才能保证正确性。

由于要查询某一段版本的历史最大值,直接处理是困难的。因为询问数量较大,因此考虑猫树分治。分治时钦定 \(x < l\) 的修改全部加入,考虑处理 \(l_1 \leq mid < r_1\) 的询问。

先加入 \(x \leq mid\) 的修改,记当前版本为初始版本,然后对右半部分的做扫描线,询问就是询问区间历史最大值,即可得出右半边的答案。

接下来考虑左半部分的答案。先将前面所有右半边的修改撤销,记当前版本为初始版本。然后倒序对左半部分扫描线,原来的修改就变成撤销操作,查询同样是区间历史最大值,即可得出左半边的答案。

对于递归区间,由于需要保证 \(< l\) 的修改都加入,可以先递归左区间,然后处理右半部分答案,然后递归右区间,然后处理左半部分的答案。

时间复杂度 \(O(n \log^2 n + q \log n)\)

#include <bits/stdc++.h>
typedef long long ll;
using namespace std;
const int N = 5e5 + 7;

struct Update {
    int l, r, k;

    inline bool operator < (const Update &rhs) const {
        return k < rhs.k;
    }
};

struct Query {
    int xl, yl, xr, yr;
} qry[N];

vector<Update> upd[N];

ll ans[N];

int n, m, q;

template <class T = int>
inline T read() {
    char c = getchar();
    bool sign = (c == '-');
    
    while (c < '0' || c > '9')
        c = getchar(), sign |= (c == '-');
    
    T x = 0;
    
    while ('0' <= c && c <= '9')
        x = (x << 1) + (x << 3) + (c & 15), c = getchar();
    
    return sign ? (~x + 1) : x;
}

namespace SMT {
ll mx[N << 2], hismx[N << 2], tag[N << 2], histag[N << 2];
bool rst[N];

inline int ls(int x) {
    return x << 1;
}

inline int rs(int x) {
    return x << 1 | 1;
}

inline void pushup(int x) {
    mx[x] = max(mx[ls(x)], mx[rs(x)]);
    hismx[x] = max(hismx[ls(x)], hismx[rs(x)]);
}

inline void spread(int x, ll k, ll hisk) {
    histag[x] = max(histag[x], tag[x] + hisk), tag[x] += k;
    hismx[x] = max(hismx[x], mx[x] + hisk), mx[x] += k;
}

inline void reset(int x) {
    spread(ls(x), tag[x], histag[x]), spread(rs(x), tag[x], histag[x]);
    rst[x] = true, hismx[x] = mx[x], histag[x] = tag[x] = 0;
}

inline void pushdown(int x) {
    if (rst[x])
        reset(ls(x)), reset(rs(x)), rst[x] = false;

    spread(ls(x), tag[x], histag[x]), spread(rs(x), tag[x], histag[x]), histag[x] = tag[x] = 0;
}

void update(int x, int nl, int nr, int l, int r, int k) {
    if (l <= nl && nr <= r) {
        spread(x, k, k);
        return;
    }

    pushdown(x);
    int mid = (nl + nr) >> 1;

    if (l <= mid)
        update(ls(x), nl, mid, l, r, k);

    if (r > mid)
        update(rs(x), mid + 1, nr, l, r, k);

    pushup(x);
}

ll query(int x, int nl, int nr, int l, int r) {
    if (l <= nl && nr <= r)
        return hismx[x];

    pushdown(x);
    int mid = (nl + nr) >> 1;

    if (r <= mid)
        return query(ls(x), nl, mid, l, r);
    else if (l > mid)
        return query(rs(x), mid + 1, nr, l, r);
    else
        return max(query(ls(x), nl, mid, l, r), query(rs(x), mid + 1, nr, l, r));
}
} // namespace SMT

inline void insert(int x) {
    for (auto it : upd[x])
        SMT::update(1, 1, n, it.l, it.r, it.k);
}

inline void remove(int x) {
    for (auto it = upd[x].rbegin(); it != upd[x].rend(); ++it)
        SMT::update(1, 1, n, it->l, it->r, -it->k);
}

void solve(int l, int r, vector<int> &ask) {
    if (ask.empty())
        return;

    if (l == r) {
        insert(l), SMT::reset(1);

        for (int it : ask)
            ans[it] = max(ans[it], SMT::query(1, 1, n, qry[it].yl, qry[it].yr));

        remove(l);
        return;
    }

    int mid = (l + r) >> 1;
    vector<int> asknow, askl, askr;

    for (int it : ask) {
        if (qry[it].xr <= mid)
            askl.emplace_back(it);
        else if (qry[it].xl > mid)
            askr.emplace_back(it);
        else
            asknow.emplace_back(it);
    }

    solve(l, mid, askl);

    sort(asknow.begin(), asknow.end(), [](const int &a, const int &b) {
        return qry[a].xr < qry[b].xr;
    });

    for (int i = l; i <= mid; ++i)
        insert(i);

    auto it = asknow.begin();

    for (int i = mid + 1; i <= r; ++i) {
        insert(i);

        if (i == mid + 1)
            SMT::reset(1);

        for (; it != asknow.end() && qry[*it].xr == i; ++it)
            ans[*it] = max(ans[*it], SMT::query(1, 1, n, qry[*it].yl, qry[*it].yr));
    }

    for (int i = r; i > mid; --i)
        remove(i);

    solve(mid + 1, r, askr);

    sort(asknow.begin(), asknow.end(), [](const int &a, const int &b) {
        return qry[a].xl > qry[b].xl;
    });

    SMT::reset(1), it = asknow.begin();

    for (int i = mid; i >= l; --i) {
        for (; it != asknow.end() && qry[*it].xl == i; ++it)
            ans[*it] = max(ans[*it], SMT::query(1, 1, n, qry[*it].yl, qry[*it].yr));

        remove(i);
    }
}

signed main() {
    n = read(), m = read(), q = read();

    for (int i = 1; i <= m; ++i) {
        int xl = read(), yl = read(), xr = read(), yr = read(), k = read();
        upd[xl].emplace_back((Update){yl, yr, k});
        upd[xr + 1].emplace_back((Update){yl, yr, -k});
    }

    for (int i = 1; i <= n; ++i)
        sort(upd[i].begin(), upd[i].end());

    for (int i = 1; i <= q; ++i)
        qry[i].xl = read(), qry[i].yl = read(), qry[i].xr = read(), qry[i].yr = read();

    vector<int> ask(q);
    iota(ask.begin(), ask.end(), 1);
    solve(1, n, ask);

    for (int i = 1; i <= q; ++i)
        printf("%lld\n", ans[i]);

    return 0;
}

cdq 分治

对于一类问题,下标之间转移的贡献形式为左边转移到右边,然后需要查询一些信息。

考虑基于中点的分治结构,分治时分为三个部分处理:左区间内部、左区间对右区间、右区间内部。

一个保险的标准顺序是先处理左区间,再处理左区间对右区间的贡献,最后处理右区间,这样就可以保证时序性了。

注意这种写法在处理左区间对右区间贡献是要先按标号排序分出正确的左右区间,如果是先递归左右区间则不用。

统计点对相关问题

给定一个长度为 \(n\) 的序列,统计有一些特性的点对 \((i, j)\) 的数量或找到一对点 \((i, j)\) 使得函数值最大。

基本流程如下:

  • 找到序列中点 \(mid\)
  • 将所有点对分为三类:
    • \(i, j \in [l, mid]\)
    • \(i, j \in [mid + 1, r]\)
    • \(i \in [l, mid], j \in [mid + 1, r]\)
  • 将前两类分治处理,设法处理最后一类,一般为统计左区间对右区间的贡献。

P3810 【模板】三维偏序(陌上花开)

\(n\) 个元素,每个元素有 \(a_i, b_i, c_i\) 三个属性。设 \(f(i)\) 表示满足 \(a_j \leq a_i \and b_j \leq b_i \and c_j \leq c_i \and i \neq j\)\(j\) 的数量。对于所有 \(d \in [0, n)\) ,求 \(f(i) = d\)\(i\) 的数量。

\(n \leq 10^5\)\(a_i, b_i, c_i \leq 2 \times 10^5\)

先将序列按第一维排序,这样第一维偏序就解决了。

考虑计算 \([l, mid]\)\([mid + 1, r]\) 的贡献,此时只需要满足第二、三维的偏序关系,用 BIT 或再套一个 cdq 分治即可做到 \(O(n \log^2 n)\)

cdq 分治套 BIT:

#include <bits/stdc++.h>
using namespace std;
const int N = 1e5 + 7, V = 2e5 + 7;

struct Node {
    int a, b, c, cnt, ans;
} p[N], nd[N];

int ans[N];

int n, vlim, m;

namespace BIT {
int c[V];

inline void update(int x, int k) {
    for (; x <= vlim; x += x & -x)
        c[x] += k;
}

inline int query(int x) {
    int res = 0;
    
    for (; x; x -= x & -x)
        res += c[x];
    
    return res;
}
} // namespace BIT

void cdq(int l, int r) {
    if (l == r)
        return;
    
    int mid = (l + r) >> 1;
    cdq(l, mid), cdq(mid + 1, r);
    sort(nd + l, nd + mid + 1, [](const Node &a, const Node &b) { return a.b < b.b; });
    sort(nd + mid + 1, nd + r + 1, [](const Node &a, const Node &b) { return a.b < b.b; });
    int j = l;
    
    for (int i = mid + 1; i <= r; ++i) {
        for (; j <= mid && nd[j].b <= nd[i].b; ++j)
            BIT::update(nd[j].c, nd[j].cnt);
            
        nd[i].ans += BIT::query(nd[i].c);
    }
    
    for (--j; j >= l; --j)
        BIT::update(nd[j].c, -nd[j].cnt);
}

signed main() {
    scanf("%d%d", &n, &vlim);
    
    for (int i = 1; i <= n; ++i)
        scanf("%d%d%d", &p[i].a, &p[i].b, &p[i].c);
    
    sort(p + 1, p + 1 + n, [](const Node &a, const Node &b) {
        return a.a == b.a ? (a.b == b.b ? a.c < b.c : a.b < b.b) : a.a < b.a;
    });
    
    for (int i = 1, cnt = 1; i <= n; ++i, ++cnt)
        if (p[i].a != p[i + 1].a || p[i].b != p[i + 1].b || p[i].c != p[i + 1].c)
            nd[++m] = p[i], nd[m].cnt = cnt, cnt = 0;
    
    cdq(1, m);
    
    for (int i = 1; i <= m; ++i)
        ans[nd[i].ans + nd[i].cnt - 1] += nd[i].cnt;
    
    for (int i = 0; i < n; ++i)
        printf("%d\n", ans[i]);
    
    return 0;
}

P3157 [CQOI2011] 动态逆序对

给出一个 \(1 \sim n\) 的排列,按给出顺序依次删除 \(m\) 个元素,求每次删除一个元素之前整个序列的逆序对数。

\(n \leq 10^5\)\(m \leq 5 \times 10^4\)

逆序对本质就是二维偏序,再引入一个时间维即可转化为三维偏序。

\(t_i\) 表示 \(i\) 被删掉的时间,\(p_i\) 表示 \(i\) 的位置,则对于一个 \(i\)\(j\) 会产生贡献当且仅当 \(t_j \geq t_i\) 且满足 \(i > j \and p_i < p_j\)\(i < j \and p_i > p_j\)

不难发现这是一个三维偏序的形式,直接 cdq 统计可以做到 \(O(n \log^2 n)\)

#include <bits/stdc++.h>
typedef long long ll;
using namespace std;
const int N = 2e5 + 7;

struct Node {
    int p, k, t;
} nd[N];

ll ans[N];
int pos[N];

int n, m;

template <class T = int>
inline T read() {
    char c = getchar();
    bool sign = (c == '-');
    
    while (c < '0' || c > '9')
        c = getchar(), sign |= (c == '-');
    
    T x = 0;
    
    while ('0' <= c && c <= '9')
        x = (x << 1) + (x << 3) + (c & 15), c = getchar();
    
    return sign ? (~x + 1) : x;
}

namespace BIT {
int c[N];

inline void update(int x, int k) {
    for (; x <= n; x += x & -x)
        c[x] += k;
}

inline int query(int x) {
    int res = 0;
    
    for (; x; x -= x & -x)
        res += c[x];
    
    return res;
}
} // namespace BIT

void cdq(int l, int r) {
    if (l == r)
        return;

    int mid = (l + r) >> 1;
    cdq(l, mid), cdq(mid + 1, r);
    sort(nd + l, nd + mid + 1, [](const Node &a, const Node &b) { return a.p < b.p; });
    sort(nd + mid + 1, nd + r + 1, [](const Node &a, const Node &b) { return a.p < b.p; });
    int j = l;

    for (int i = mid + 1; i <= r; ++i) {
        for (; j <= mid && nd[j].p < nd[i].p; ++j)
            BIT::update(nd[j].k, 1);

        ans[nd[i].t] += BIT::query(n) - BIT::query(nd[i].k);
    }

    for (--j; j >= l; --j)
        BIT::update(nd[j].k, -1);

    j = mid;

    for (int i = r; i > mid; --i) {
        for (; j >= l && nd[j].p > nd[i].p; --j)
            BIT::update(nd[j].k, 1);

        ans[nd[i].t] += BIT::query(nd[i].k - 1);
    }

    for (++j; j <= mid; ++j)
        BIT::update(nd[j].k, -1);
}

signed main() {
    n = read(), m = read();

    for (int i = 1; i <= n; ++i) {
        int x = read();
        nd[x] = (Node) {i, x, m + 1};
    }

    for (int i = 1; i <= m; ++i)
        nd[read()].t = i;

    sort(nd + 1, nd + 1 + n, [](const Node &a, const Node &b) { return a.t > b.t; });
    cdq(1, n);

    for (int i = m; i; --i)
        ans[i] += ans[i + 1];

    for (int i = 1; i <= m; ++i)
        printf("%lld\n", ans[i]);

    return 0;
}

CF1045G AI robots

给出 \(n\) 个三元组 \((x_i, r_i, q_i)\) 以及常数 \(k\) ,求满足 \(|x_i - x_j| \leq \min(r_i, r_j)\)\(|q_i - q_j| \leq k\)\((i, j)\) 的数量。

\(n \leq 10^5\)\(k \leq 20\)

处理点对问题,考虑 cdq 分治。

先处理掉 \(\min(r_i, r_j)\) 的限制,按 \(r_i\) 降序排序,这样 \(\min(r_i, r_j)\) 就取的是右边的 \(r\)

再考虑 \(|x_i - x_j| \leq \min(r_i, r_j)\) 的限制,直接在 BIT 做一个区间查询即可。

最后考虑 \(|q_i - q_j| \leq k\) 的限制,统计贡献时套一个 two-pointers 即可。

时间复杂度 \(O(n \log^2 n)\)

#include <bits/stdc++.h>
typedef long long ll;
using namespace std;
const int N = 1e5 + 7;

struct Node {
    int x, range, q, l, r;
} nd[N];

ll ans;
int n, k;

namespace BIT {
int c[N];

inline void update(int x, int k) {
    for (; x <= n; x += x & -x)
        c[x] += k;
}

inline int query(int x) {
    int res = 0;
    
    for (; x; x -= x & -x)
        res += c[x];
    
    return res;
}
} // namespace BIT

void cdq(int l, int r) {
    if (l == r)
        return;

    int mid = (l + r) >> 1;
    cdq(l, mid), cdq(mid + 1, r);
    sort(nd + l, nd + mid + 1, [](const Node &a, const Node &b) { return a.q < b.q; });
    sort(nd + mid + 1, nd + r + 1, [](const Node &a, const Node &b) { return a.q < b.q; });
    int jl = l, jr = l - 1;

    for (int i = mid + 1; i <= r; ++i) {
        while (jl <= mid && nd[i].q - nd[jl].q > k)
            BIT::update(nd[jl++].x, -1);

        while (jr < mid && nd[jr + 1].q - nd[i].q <= k)
            BIT::update(nd[++jr].x, 1);

        ans += BIT::query(nd[i].r) - BIT::query(nd[i].l - 1);
    }

    for (int i = jl; i <= jr; ++i)
        BIT::update(nd[i].x, -1);
}

signed main() {
    scanf("%d%d", &n, &k);
    vector<int> vec;

    for (int i = 1; i <= n; ++i) {
        scanf("%d%d%d", &nd[i].x, &nd[i].range, &nd[i].q);
        vec.emplace_back(nd[i].x);
    }

    sort(vec.begin(), vec.end());
    vec.erase(unique(vec.begin(), vec.end()), vec.end());

    for (int i = 1; i <= n; ++i) {
        nd[i].l = lower_bound(vec.begin(), vec.end(), nd[i].x - nd[i].range) - vec.begin() + 1;
        nd[i].r = upper_bound(vec.begin(), vec.end(), nd[i].x + nd[i].range) - vec.begin();
        nd[i].x = lower_bound(vec.begin(), vec.end(), nd[i].x) - vec.begin() + 1;
    }

    sort(nd + 1, nd + 1 + n, [](const Node &a, const Node &b) { return a.range > b.range; });
    cdq(1, n);
    printf("%lld", ans);
    return 0;
}

P4169 [Violet] 天使玩偶/SJY摆棋子

在平面直角坐标系上维护 \(n\) 次操作,操作有:

  • 加入一个点 \((x, y)\)
  • 询问与 \((x, y)\) 曼哈顿距离最小的点。

\(n \leq 3 \times 10^5\)

考虑暴力分类讨论两个点的大小关系,拆掉绝对值取。记 \(t_i\) 表示第 \(i\) 个点的出现时间,则 \(x, y, t\) 构成三维偏序,直接套 cdq 分治板子即可做到 \(O(n \log^2 n)\)

#include <bits/stdc++.h>
using namespace std;
const int inf = 0x3f3f3f3f;
const int N = 3e5 + 7, V = 2e6 + 7;

struct Node {
    int x, y, id;
} nd[4][N << 1];

int ans[N];

int n, m, tot, cntq;

template <class T = int>
inline T read() {
    char c = getchar();
    bool sign = (c == '-');
    
    while (c < '0' || c > '9')
        c = getchar(), sign |= (c == '-');
    
    T x = 0;
    
    while ('0' <= c && c <= '9')
        x = (x << 1) + (x << 3) + (c & 15), c = getchar();
    
    return sign ? (~x + 1) : x;
}

namespace BIT {
int c[V];

inline void prework() {
    memset(c, -inf, sizeof(c));
}

inline void update(int x, int k) {
    for (; x < V; x += x & -x)
        c[x] = max(c[x], k);
}

inline void remove(int x) {
    for (; x < V; x += x & -x)
        c[x] = -inf;
}

inline int query(int x) {
    int res = -inf;
    
    for (; x; x -= x & -x)
        res = max(res, c[x]);
    
    return res;
}
} // namespace BIT

void solve(int l, int r, Node *nd) {
    if (l == r)
        return;
    
    int mid = (l + r) >> 1;
    solve(l, mid, nd), solve(mid + 1, r, nd);
    sort(nd + l, nd + mid + 1, [](const Node &a, const Node &b) { return a.x < b.x; });
    sort(nd + mid + 1, nd + r + 1, [](const Node &a, const Node &b) { return a.x < b.x; });
    int j = l;

    for (int i = mid + 1; i <= r; ++i) {
        for (; j <= mid && nd[j].x <= nd[i].x; ++j)
            if (!nd[j].id)
                BIT::update(nd[j].y, nd[j].x + nd[j].y);

        if (nd[i].id)
            ans[nd[i].id] = min(ans[nd[i].id], nd[i].x + nd[i].y - BIT::query(nd[i].y));
    }
    
    for (--j; j >= l; --j)
        BIT::remove(nd[j].y);
}

signed main() {
    n = read(), m = read();
    
    for (int i = 1; i <= n; ++i) {
        int x = read(), y = read() + 1;
        nd[0][i] = (Node) {x, y, 0};
        nd[1][i] = (Node) {V - x, y, 0};
        nd[2][i] = (Node) {x, V - y, 0};
        nd[3][i] = (Node) {V - x, V - y, 0};
    }
    
    for (int i = n + 1; i <= n + m; ++i) {
        int op = read(), x = read(), y = read() + 1;
        
        if (op == 1) {
            nd[0][i] = (Node) {x, y, 0};
            nd[1][i] = (Node) {V - x, y, 0};
            nd[2][i] = (Node) {x, V - y, 0};
            nd[3][i] = (Node) {V - x, V - y, 0};
        } else {
            ++cntq;
            nd[0][i] = (Node) {x, y, cntq};
            nd[1][i] = (Node) {V - x, y, cntq};
            nd[2][i] = (Node) {x, V - y, cntq};
            nd[3][i] = (Node) {V - x, V - y, cntq};
        }
    }
    
    memset(ans + 1, inf, sizeof(int) * cntq);
    BIT::prework();
    
    for (int i = 0; i < 4; ++i)
        solve(1, n + m, nd[i]);
    
    for (int i = 1; i <= cntq; ++i)
        printf("%d\n", ans[i]);
    
    return 0;
}

动态二维数点相关

静态二维数点问题常常使用主席树解决,对于动态的二维数点问题,若强制在线,可考虑树套树,否则可以尝试考虑 cdq 分治。

P4390 [BalkanOI2007] Mokia 摩基亚

给出一个 \(W \times W\) 的网格,\(n\) 次操作,每次操作为下面两种操作中的一种:

  • 给某个格子加上 \(x\)
  • 询问一个矩形中的所有数的和。

\(n \leq 10^5\)\(W \leq 10^6\)

先差分将查询转化为若干前缀矩阵查询,然后直接按时间 cdq 分治,排序+树状数组处理左区间修改对右区间查询的贡献即可做到 \(O(n \log^2 n)\)

#include <bits/stdc++.h>
using namespace std;
const int N = 2e6 + 7;

struct Node {
	int x, y, k, t, id;
} nd[N];

int ans[N];

int W, n, cntu, cntq;

template <class T = int>
inline T read() {
	char c = getchar();
	bool sign = c == '-';
	
	while (c < '0' || c > '9')
		c = getchar(), sign |= c == '-';
	
	T x = 0;
	
	while ('0' <= c && c <= '9')
		x = (x << 1) + (x << 3) + (c & 15), c = getchar();
	
	return sign ? (~x + 1) : x;
}

namespace BIT {
int c[N];

inline void update(int x, int k) {
	for (; x <= W; x += x & -x)
		c[x] += k;
}

inline int query(int x) {
	int res = 0;
	
	for (; x; x -= x & -x)
		res += c[x];
	
	return res;
}
} // namespace BIT

void cdq(int l, int r) {
	if (l == r)
		return;
	
	int mid = (l + r) >> 1;
	cdq(l, mid), cdq(mid + 1, r);
	sort(nd + l, nd + mid + 1, [](const Node &a, const Node &b) { return a.x < b.x; });
	sort(nd + mid + 1, nd + r + 1, [](const Node &a, const Node &b) { return a.x < b.x; });
	int j = l;

	for (int i = mid + 1; i <= r; ++i) {
		for (; j <= mid && nd[j].x <= nd[i].x; ++j)
			if (!nd[j].id)
				BIT::update(nd[j].y, nd[j].k);

		if (nd[i].id > 0)
			ans[nd[i].id] += BIT::query(nd[i].y);
		else if (nd[i].id < 0)
			ans[-nd[i].id] -= BIT::query(nd[i].y);
	}

	for (--j; j >= l; --j)
		if (!nd[j].id)
			BIT::update(nd[j].y, -nd[j].k);
}

signed main() {
	int op = read();
	W = read() + 1;
	
	while ((op = read()) != 3) {
		if (op == 1) {
			int x = read() + 1, y = read() + 1, k = read();
			++cntu, nd[++n] = (Node) {x, y, k, cntu, 0};
		} else {
			int x = read() + 1, y = read() + 1, xx = read() + 1, yy = read() + 1;
			++cntq;
			nd[++n] = (Node) {xx, yy, 0, cntu, cntq};
			nd[++n] = (Node) {x - 1, yy, 0, cntu, -cntq};
			nd[++n] = (Node) {xx, y - 1, 0, cntu, -cntq};
			nd[++n] = (Node) {x - 1, y - 1, 0, cntu, cntq};
		}
	}
    
	cdq(1, n);
	
	for (int i = 1; i <= cntq; ++i)
		printf("%d\n", ans[i]);
	
	return 0;
}

CF848C Goodbye Souvenir

给出序列 \(a_{1 \sim n}\)\(m\) 次操作:

  • 单点修改。
  • 查询 \([l, r]\) 内所有数字价值和,其中 \(x\) 的价值定义为 \(x\)\([l, r]\) 内第一次与最后一次出现位置的下标差,相同数字不重复贡献。

\(n, m \leq 10^5\)

\(a_i\) 上一个出现位置为 \(pre_i\) ,则答案可以写作:

\[\sum_{l \leq pre_i < i \leq r} i - pre_i \]

不难发现这是一个矩形查询的形式,而一次修改可以拆分为若干次二维平面上的单点修改。由于不要求强制在线,直接用 cdq 分治解决时间维度的偏序关系即可。

由于 \(i > pre_i\) ,所以矩形查询时没必要差分,只要查 \(pre_i \geq l, i \leq r\) 的点即可。

时间复杂度 \(O(n \log^2 n)\)

#include <bits/stdc++.h>
typedef long long ll;
using namespace std;
const int N = 1e5 + 7;

struct Node {
    int op, x, y, k, t;
} nd[N * 7];

set<int> place[N];

ll ans[N];
int a[N];

int n, m, cnt, cntq;

template <class T = int>
inline T read() {
    char c = getchar();
    bool sign = (c == '-');
    
    while (c < '0' || c > '9')
        c = getchar(), sign |= (c == '-');
    
    T x = 0;
    
    while ('0' <= c && c <= '9')
        x = (x << 1) + (x << 3) + (c & 15), c = getchar();
    
    return sign ? (~x + 1) : x;
}

struct BIT {
    ll c[N];

    inline void update(int x, ll k) {
        for (; x <= n; x += x & -x)
            c[x] += k;
    }

    inline ll query(int x) {
        ll res = 0;

        for (; x; x -= x & -x)
            res += c[x];

        return res;
    }
} bit;

void cdq(int l, int r) {
    if (l == r)
        return;

    int mid = (l + r) >> 1;
    cdq(l, mid), cdq(mid + 1, r);
    sort(nd + l, nd + mid + 1, [](const Node &a, const Node &b) { return a.x > b.x; });
    sort(nd + mid + 1, nd + r + 1, [](const Node &a, const Node &b) { return a.x > b.x; });
    int j = l;

    for (int i = mid + 1; i <= r; ++i) {
        while (j <= mid && nd[j].x >= nd[i].x) {
            if (nd[j].op == 1)
                bit.update(nd[j].y, (nd[j].y - nd[j].x) * nd[j].k);

            ++j;
        }

        if (nd[i].op == 2)
            ans[nd[i].k] += bit.query(nd[i].y);
    }

    for (--j; j >= l; --j)
        if (nd[j].op == 1)
            bit.update(nd[j].y, -(nd[j].y - nd[j].x) * nd[j].k);
}

signed main() {
    n = read(), m = read();

    for (int i = 1; i <= n; ++i) {
        a[i] = read();

        if (!place[a[i]].empty())
            nd[++cnt] = (Node) {1, *place[a[i]].rbegin(), i, 1, 0};

        place[a[i]].emplace(i);
    }

    for (int i = 1; i <= m; ++i) {
        if (read() == 1) {
            int x = read(), k = read();
            auto it = place[a[x]].find(x);

            if (it != place[a[x]].begin())
                nd[++cnt] = (Node) {1, *prev(it), x, -1, i};

            if (next(it) != place[a[x]].end())
                nd[++cnt] = (Node) {1, x, *next(it), -1, i};

            if (it != place[a[x]].begin() && next(it) != place[a[x]].end())
                nd[++cnt] = (Node) {1, *prev(it), *next(it), 1, i};

            place[a[x]].erase(x), it = place[a[x] = k].emplace(x).first;

            if (it != place[k].begin() && next(it) != place[k].end())
                nd[++cnt] = (Node) {1, *prev(it), *next(it), -1, i};

            if (it != place[k].begin())
                nd[++cnt] = (Node) {1, *prev(it), x, 1, i};

            if (next(it) != place[k].end())
                nd[++cnt] = (Node) {1, x, *next(it), 1, i};
        } else {
            int l = read(), r = read();
            nd[++cnt] = (Node) {2, l, r, ++cntq, i};
        }
    }

    cdq(1, cnt);

    for (int i = 1; i <= cntq; ++i)
        printf("%lld\n", ans[i]);

    return 0;
}

优化 1D/1D 动态规划

以二维 LIS 为例,不难列出转移方程:

\[f_i = 1 + \max_{j = 1}^{i - 1} f_j [a_j < a_i] [b_j < b_i] \]

直接转移是 \(O(n^2)\) 的。考虑 cdq 分治优化,假设当前处理的区间为 \([l, r]\) ,流程大致如下:

  • \(l = r\) ,说明 \(f_l\) 已求得,直接返回即可。
  • 递归处理 \([l, mid]\)
  • 处理所有 \([l, mid] \to [mid + 1, r]\) 的转移关系。
  • 递归处理 \([mid + 1, r]\)

注意 DP 的转移是有时序性的,必须按标准顺序处理。

P2487 [SDOI2011] 拦截导弹

\(n\) 个导弹,每个导弹有两个参数 \(h, v\) 。求一个最长的序列 \(a\) ,满足 \(h, v\) 不升,输出其长度。并对于每个导弹,求出其成为最长序列中的一项的概率。

\(n \leq 5 \times 10^4\)

第一问和二维 LIS 是类似的,第二问实际就是包含该导弹的方案数除以总方案数。

一个显然的事实是包含该导弹的方案数为前后最长序列的方案数的乘积,于是跑正反两遍 cdq 即可,时间复杂度 \(O(n \log^2 n)\)

#include <bits/stdc++.h>
using namespace std;
const int N = 5e4 + 7;

struct Node {
    int h, v, id;
    pair<int, double> f;
};

int n;

template <class T = int>
inline T read() {
    char c = getchar();
    bool sign = (c == '-');
    
    while (c < '0' || c > '9')
        c = getchar(), sign |= (c == '-');
    
    T x = 0;
    
    while ('0' <= c && c <= '9')
        x = (x << 1) + (x << 3) + (c & 15), c = getchar();
    
    return sign ? (~x + 1) : x;
}

inline pair<int, double> cmp(pair<int, double> a, pair<int, double> b) {
    if (a.first < b.first)
        swap(a, b);
    else if (b.first == a.first)
        a.second += b.second;

    return a;
}

namespace Pre {
Node nd[N];

namespace BIT {
pair<int, double> c[N];

inline void update(int x, pair<int, double> k) {
    for (; x; x -= x & -x)
        c[x] = cmp(c[x], k);
}

inline void remove(int x) {
    for (; x; x -= x & -x)
        c[x] = make_pair(0, 0);
}

inline pair<int, double> query(int x) {
    pair<int, double> res = make_pair(0, 0);

    for (; x <= n; x += x & -x)
        res = cmp(res, c[x]);

    return res;
}
} // namespace BIT

void cdq(int l, int r) {
    if (l == r) {
        ++nd[l].f.first;
        return;
    }

    int mid = (l + r) >> 1;
    cdq(l, mid);
    sort(nd + l, nd + mid + 1, [](const Node &a, const Node &b) { return a.h > b.h; });
    sort(nd + mid + 1, nd + r + 1, [](const Node &a, const Node &b) { return a.h > b.h; });
    int j = l;

    for (int i = mid + 1; i <= r; ++i) {
        for (; j <= mid && nd[j].h >= nd[i].h; ++j)
            BIT::update(nd[j].v, nd[j].f);

        nd[i].f = cmp(nd[i].f, BIT::query(nd[i].v));
    }

    for (--j; j >= l; --j)
        BIT::remove(nd[j].v);

    sort(nd + mid + 1, nd + r + 1, [](const Node &a, const Node &b) { return a.id < b.id; });
    cdq(mid + 1, r);
}
} // namespace Pre

namespace Suf {
Node nd[N];

namespace BIT {
pair<int, double> c[N];

inline void update(int x, pair<int, double> k) {
    for (; x <= n; x += x & -x)
        c[x] = cmp(c[x], k);
}

inline void remove(int x) {
    for (; x <= n; x += x & -x)
        c[x] = make_pair(0, 0);
}

inline pair<int, double> query(int x) {
    pair<int, double> res = make_pair(0, 0);

    for (; x; x -= x & -x)
        res = cmp(res, c[x]);

    return res;
}
} // namespace BIT

void cdq(int l, int r) {
    if (l == r) {
        ++nd[l].f.first;
        return;
    }

    int mid = (l + r) >> 1;
    cdq(mid + 1, r);
    sort(nd + l, nd + mid + 1, [](const Node &a, const Node &b) { return a.h < b.h; });
    sort(nd + mid + 1, nd + r + 1, [](const Node &a, const Node &b) { return a.h < b.h; });
    int j = mid + 1;

    for (int i = l; i <= mid; ++i) {
        for (; j <= r && nd[j].h <= nd[i].h; ++j)
            BIT::update(nd[j].v, nd[j].f);

        nd[i].f = cmp(nd[i].f, BIT::query(nd[i].v));
    }

    for (--j; j >= mid + 1; --j)
        BIT::remove(nd[j].v);

    sort(nd + l, nd + mid + 1, [](const Node &a, const Node &b) { return a.id < b.id; });
    cdq(l, mid);
}
} // namespace Suf

signed main() {
    n = read();
    vector<int> vec;

    for (int i = 1; i <= n; ++i) {
        Pre::nd[i].h = Suf::nd[i].h = read();
        vec.emplace_back(Pre::nd[i].v = Suf::nd[i].v = read());
        Pre::nd[i].id = Suf::nd[i].id = i;
    }

    sort(vec.begin(), vec.end());
    vec.erase(unique(vec.begin(), vec.end()), vec.end());

    for (int i = 1; i <= n; ++i) {
        Pre::nd[i].v = lower_bound(vec.begin(), vec.end(), Pre::nd[i].v) - vec.begin() + 1;
        Suf::nd[i].v = lower_bound(vec.begin(), vec.end(), Suf::nd[i].v) - vec.begin() + 1;
        Pre::nd[i].f = Suf::nd[i].f = make_pair(0, 1);
    }

    Pre::cdq(1, n), Suf::cdq(1, n);
    sort(Pre::nd + 1, Pre::nd + 1 + n, [](const Node &a, const Node &b) { return a.id < b.id; });
    sort(Suf::nd + 1, Suf::nd + 1 + n, [](const Node &a, const Node &b) { return a.id < b.id; });
    pair<int, double> ans = make_pair(0, 1);

    for (int i = 1; i <= n; ++i)
        ans = cmp(ans, Pre::nd[i].f);

    printf("%d\n", ans.first);

    for (int i = 1; i <= n; ++i) {
        if (Pre::nd[i].f.first + Suf::nd[i].f.first - 1 == ans.first)
            printf("%.5lf ", Pre::nd[i].f.second * Suf::nd[i].f.second / ans.second);
        else
            printf("0.00000 ");
    }

    return 0;
}

P4849 寻找宝藏

在一个四维坐标系中,给定 \(n\) 个点,问有多少种选择点的方案,使得这些点排序后任意坐标单调不降,最大化选择的点权和,并输出方案数。

\(n \leq 8 \times 10^4\)

即四维 LIS ,弱化版:

cdq 套 cdq 套 BIT 即可做到四维偏序型转移,时间复杂度 \(O(n \log^3 n)\)

#include <bits/stdc++.h>
typedef long long ll;
using namespace std;
const int Mod = 998244353;
const int N = 8e4 + 7;

struct Node {
    pair<ll, int> f;

    int a, b, c, d, flag;
    ll k;

    inline bool operator == (const Node &rhs) const {
        return a == rhs.a && b == rhs.b && c == rhs.c && d == rhs.d;
    }
} p[N], nd[N];

int n, m;

inline int add(int x, int y) {
    x += y;
    
    if (x >= Mod)
        x -= Mod;
    
    return x;
}

inline pair<ll, int> cmp(pair<ll, int> a, pair<ll, int> b) {
    if (a.first < b.first)
        swap(a, b);
    else if (a.first == b.first)
        a.second = add(a.second, b.second);

    return a;
}

namespace BIT {
pair<ll, int> c[N];

inline void update(int x, pair<ll, int> k) {
    for (; x <= m; x += x & -x)
        c[x] = cmp(c[x], k);
}

inline pair<ll, int> query(int x) {
    pair<ll, int> res = make_pair(0, 0);

    for (; x; x -= x & -x)
        res = cmp(res, c[x]);

    return res;
}

inline void remove(int x) {
    for (; x <= m; x += x & -x)
        c[x] = make_pair(0, 0);
}
} // namespace BIT

void cdq2(int l, int r) {
    if (l == r)
        return;

    int mid = (l + r) >> 1;
    cdq2(l, mid);

    stable_sort(nd + l, nd + mid + 1, [](const Node &a, const Node &b) {
        return a.c == b.c ? a.d < b.d : a.c < b.c;
    });

    stable_sort(nd + mid + 1, nd + r + 1, [](const Node &a, const Node &b) {
        return a.c == b.c ? a.d < b.d : a.c < b.c;
    });

    int j = l;

    for (int i = mid + 1; i <= r; ++i) {
        for (; j <= mid && nd[j].c <= nd[i].c; ++j)
            if (!nd[j].flag)
                BIT::update(nd[j].d, nd[j].f);

        if (nd[i].flag)
            nd[i].f = cmp(nd[i].f, BIT::query(nd[i].d));
    }

    for (--j; j >= l; --j)
        if (!nd[j].flag)
            BIT::remove(nd[j].d);

    stable_sort(nd + l, nd + r + 1, [](const Node &a, const Node &b) {
        return a.b == b.b ? (a.c == b.c ? a.d < b.d : a.c < b.c) : a.b < b.b;
    });

    cdq2(mid + 1, r);
}

void cdq1(int l, int r) {
    if (l == r) {
        nd[l].f.first += nd[l].k;
        return;
    }

    int mid = (l + r) >> 1;
    cdq1(l, mid);

    for (int i = l; i <= r; ++i)
        nd[i].flag = (i > mid);

    stable_sort(nd + l, nd + r + 1, [](const Node &a, const Node &b) {
        return a.b == b.b ? (a.c == b.c ? a.d < b.d : a.c < b.c) : a.b < b.b;
    });

    cdq2(l, r);

    stable_sort(nd + l, nd + r + 1, [](const Node &a, const Node &b) {
        return a.a == b.a ? (a.b == b.b ? (a.c == b.c ? a.d < b.d : a.c < b.c) : a.b < b.b) : a.a < b.a;
    });

    cdq1(mid + 1, r);
}

signed main() {
    scanf("%d%*d", &n);

    for (int i = 1; i <= n; ++i)
        scanf("%d%d%d%d%lld", &p[i].a, &p[i].b, &p[i].c, &p[i].d, &p[i].k);

    stable_sort(p + 1, p + n + 1, [](const Node &a, const Node &b) {
        return a.a == b.a ? (a.b == b.b ? (a.c == b.c ? a.d < b.d : a.c < b.c) : a.b < b.b) : a.a < b.a;
    });

    nd[m = 1] = p[1];

    for (int i = 2; i <= n; ++i) {
        if (p[i] == nd[m])
            nd[m].k += p[i].k;
        else
            nd[++m] = p[i];
    }

    vector<int> vec;

    for (int i = 1; i <= m; ++i)
        vec.emplace_back(nd[i].d);

    stable_sort(vec.begin(), vec.end()), vec.erase(unique(vec.begin(), vec.end()), vec.end());

    for (int i = 1; i <= m; ++i)
        nd[i].d = lower_bound(vec.begin(), vec.end(), nd[i].d) - vec.begin() + 1, nd[i].f = make_pair(0, 1);

    cdq1(1, m);
    pair<ll, int> ans = make_pair(0, 0);

    for (int i = 1; i <= m; ++i)
        ans = cmp(ans, nd[i].f);

    printf("%lld\n%d", ans.first, ans.second);
    return 0;
}

Safe Partition

给定 \(a_{1 \sim n}\) ,需要将其划分为若干段,每段 \(S\) 均要满足 \(\min_{x \in S} a_x \leq |S| \leq \max_{x \in S} a_x\) ,求划分方案数 \(\bmod 10^9 + 7\)

\(n \leq 5 \times 10^5\)

\(f_i\) 表示 \(1 \sim i\) 的划分方案数,考虑 cdq 分治,每次处理 \(f_{l - 1, mid - 1} \to f_{mid + 1, r}\) 的转移,即 \(\forall i \in [l, mid], j \in [mid + 1, r]\)\([i, j]\) 合法时转移 \(f_{i - 1} \to f_j\)

考虑枚举 \(j\) ,升序枚举 \(j\)\(\min\) 不升,因此只考虑最小值限制时合法的 \(i\)\([l, mid]\) 的一段前缀,并且前缀是不断扩大的,可以维护指针做到单层线性。

接下来考虑 \(\max\) 的限制,不难发现存在一个分界点 \(p\) 使得 \(i < p\) 时区间 \(\max\) 在左边取到,\(i \geq p\) 时区间 \(\max\) 在右边取到。对于后一种情况,右边的最大值是已知的,那么合法的 \(i\) 是一段后缀。对于前者,不用考虑,只要枚举 \(i\) 做最大值在左区间的类似转移即可。

使用前缀和与差分维护区间修改、查询操作,注意区分两边最大值相同的情况(直接加一个下标作为第二关键字即可),时间复杂度 \(O(n \log n)\)

#include <bits/stdc++.h>
using namespace std;
const int Mod = 1e9 + 7;
const int N = 5e5 + 7;

int a[N], f[N], s[N], c[N], mn[N];

int n;

inline int add(int x, int y) {
    x += y;
    
    if (x >= Mod)
        x -= Mod;
    
    return x;
}

inline int dec(int x, int y) {
    x -= y;
    
    if (x < 0)
        x += Mod;
    
    return x;
}

void solve(int l, int r) {
    if (l == r) {
        if (a[l] == 1)
            f[l] = add(f[l], f[l - 1]);

        s[l] = add(s[l - 1], f[l] = add(f[l], c[l] = add(c[l], c[l - 1])));
        return;
    }

    int mid = (l + r) >> 1;
    solve(l, mid);

    mn[mid] = a[mid];

    for (int i = mid - 1; i >= l; --i)
        mn[i] = min(mn[i + 1], a[i]);

    mn[mid + 1] = a[mid + 1];

    for (int i = mid + 1; i <= r; ++i)
        mn[i] = min(mn[i - 1], a[i]);

    for (int i = mid, mx = a[mid], j = mid, k = r + 1; i >= l; mx = max(mx, a[--i])) {
        while (j + 1 <= r && a[j + 1] <= mx)
            ++j;

        while (k - 1 > mid && min(mn[k - 1], mn[i]) <= k - i)
            --k;

        int p = min(j, i + mx - 1);

        if (k <= p)
            c[k] = add(c[k], f[i - 1]), c[p + 1] = dec(c[p + 1], f[i - 1]);
    }

    for (int i = mid + 1, mx = a[mid + 1], j = mid + 1, k = l - 1; i <= r; mx = max(mx, a[++i])) {
        while (j - 1 >= l && a[j - 1] < mx)
            --j;

        while (k + 1 <= mid && min(mn[k + 1], mn[i]) <= i - k)
            ++k;

        int p = max(j, i - mx + 1);

        if (p <= k) {
            if (p >= 2)
                f[i] = add(f[i], dec(s[k - 1], s[p - 2]));
            else
                f[i] = add(f[i], s[k - 1]);
        }
    }

    solve(mid + 1, r);
}

signed main() {
    scanf("%d", &n);

    for (int i = 1; i <= n; ++i)
        scanf("%d", a + i);

    s[0] = f[0] = 1, solve(1, n);
    printf("%d", f[n]);
    return 0;
}

将动态问题转化为静态问题

对于一类带修改与查询且可以离线的问题,若不好用 DS 维护信息,则可以考虑 将所有操作会按照时间 cdq 分治。

假设现在处理的时间区间为 \([l, r]\) ,先递归处理 \([l, mid]\)\([mid + 1, r]\) 的修改-询问关系,再处理左区间修改-右区间询的关系,统计这部分修改对询问的贡献。

如果修改之间相互独立,则三部分顺序无所谓,否则必须按标准顺序处理。

整体二分

一类题目具有如下特征:

  • 答案可以二分求得,但是对于多组询问,如果每次都二分,则每次需要统计所有修改-询问的关系,复杂度难以接受。
  • 允许离线。
  • 修改对判定答案的贡献互相独立,修改之间互相独立。
  • 修改如果对判定答案有贡献,则贡献与判定标准无关。

首先把所有操作按时间顺序存入数组中,然后开始分治。

记函数 solve(l, r, L, R) 表示操作 \([L, R]\) 的答案在 \([l, r]\) 中。

\(l = r\) ,则说明找到答案。否则在每一层分治中,利用数据结构统计当前查询的答案和 \(mid\) 之间的关系,将当前处理的操作序列分为两部分并分别递归处理。

若分治中用线性结构维护,时间复杂度 \(O(n \log V)\)

求解 k 小值

P2617 Dynamic Rankings

给出 \(a_{1 \sim n}\)\(m\) 次操作:

  • 修改 \(a_x\)\(k\)
  • 询问 \(a_{l \sim r}\)\(k\) 小值。

\(n, m \leq 10^5\)

先将初始的 \(a_{1 \sim n}\) 转化为 \(n\) 次修改,后面的修改操作可以视为删掉原来的值再加上新的值。

处理当前层时只要将 \(\leq mid\) 的值做单点加一,然后查询就是查询前缀和,不难用 BIT 维护做到 \(O(n \log n \log V)\)

#include <bits/stdc++.h>
using namespace std;
const int inf = 0x3f3f3f3f;
const int N = 1e5 + 7;

struct Node {
    int op, l, r, k, id;
} nd[N * 3], ndl[N * 3], ndr[N * 3];

int a[N], ans[N];

int n, m, tot, cntq;

template <class T = int>
inline T read() {
    char c = getchar();
    bool sign = (c == '-');
    
    while (c < '0' || c > '9')
        c = getchar(), sign |= (c == '-');
    
    T x = 0;
    
    while ('0' <= c && c <= '9')
        x = (x << 1) + (x << 3) + (c & 15), c = getchar();
    
    return sign ? (~x + 1) : x;
}

inline char readc() {
    char c = getchar();

    while (c != 'Q' && c != 'C')
        c = getchar();

    return c;
}

namespace BIT {
int c[N];

inline void update(int x, int k) {
    for (; x <= n; x += x & -x)
        c[x] += k;
}

inline int query(int x) {
    int res = 0;
    
    for (; x; x -= x & -x)
        res += c[x];
    
    return res;
}
} // namespace BIT

void solve(int l, int r, int L, int R) {
    if (L > R)
        return;
    
    if (l == r) {
        for (int i = L; i <= R; ++i)
            if (nd[i].op == 3)
                ans[nd[i].id] = l;
        
        return;
    }
    
    int mid = (l + r) >> 1, lp = 0, rp = 0;
    
    for (int i = L; i <= R; ++i) {
        if (nd[i].op == 1) {
            if (nd[i].k <= mid)
                ndl[lp++] = nd[i], BIT::update(nd[i].l, 1);
            else
                ndr[rp++] = nd[i];
        } else if (nd[i].op == 2) {
            if (nd[i].k <= mid)
                ndl[lp++] = nd[i], BIT::update(nd[i].l, -1);
            else
                ndr[rp++] = nd[i];
        } else {
            int x = BIT::query(nd[i].r) - BIT::query(nd[i].l - 1);
            
            if (nd[i].k <= x)
                ndl[lp++] = nd[i];
            else
                nd[i].k -= x, ndr[rp++] = nd[i];
        }
    }
    
    for (int i = 0; i < lp; ++i) {
        if (ndl[i].op == 1 && ndl[i].k <= mid)
            BIT::update(ndl[i].l, -1);
        else if (ndl[i].op == 2 && ndl[i].k <= mid)
            BIT::update(ndl[i].l, 1);
    }

    memcpy(nd + L, ndl, sizeof(Node) * lp);
    memcpy(nd + L + lp, ndr, sizeof(Node) * rp);
    solve(l, mid, L, L + lp - 1), solve(mid + 1, r, L + lp, R);
}

signed main() {
    n = read(), m = read();
    
    for (int i = 1; i <= n; ++i)
        nd[++tot] = (Node) {1, i, 0, a[i] = read(), 0};

    for (int i = 1; i <= m; ++i) {
        if (readc() == 'C') {
            int x = read(), k = read();
            nd[++tot] = (Node) {2, x, 0, a[x], 0};
            nd[++tot] = (Node) {1, x, 0, a[x] = k, 0};
        } else {
            int l = read(), r = read(), k = read();
            nd[++tot] = (Node) {3, l, r, k, ++cntq};
        }
    }
    
    solve(-inf, inf, 1, tot);
    
    for (int i = 1; i <= cntq; ++i)
        printf("%d\n", ans[i]);
    
    return 0;
}

P3332 [ZJOI2013] K大数查询

维护 \(n\) 个可重集,初始均为空。\(m\) 次操作:

  • \(k\) 加入到编号在 \([l, r]\) 内的集合中。
  • 查询编号在 \([l, r]\) 内的集合的并集的第 \(k\) 大值。

注意可重集的并是不去除重复元素的。

\(n, m \leq 5 \times 10^4\)

用线段树维护区间加、区间和即可做到 \(O(n \log n \log V)\)

#include <bits/stdc++.h>
typedef long long ll;
using namespace std;
const int N = 5e4 + 7;

struct Node {
    ll k;
    int l, r, id;
} nd[N], ndl[N], ndr[N];

int ans[N];

int n, m, cntq;

template <class T = int>
inline T read() {
    char c = getchar();
    bool sign = (c == '-');
    
    while (c < '0' || c > '9')
        c = getchar(), sign |= (c == '-');
    
    T x = 0;
    
    while ('0' <= c && c <= '9')
        x = (x << 1) + (x << 3) + (c & 15), c = getchar();
    
    return sign ? (~x + 1) : x;
}

namespace SMT {
ll s[N << 2];
int tag[N << 2];

inline int ls(int x) {
    return x << 1;
}

inline int rs(int x) {
    return x << 1 | 1;
}

void update(int x, int nl, int nr, int l, int r, int k) {
    s[x] += 1ll * (min(r, nr) - max(l, nl) + 1) * k;

    if (l <= nl && nr <= r) {
        tag[x] += k;
        return;
    }

    int mid = (nl + nr) >> 1;

    if (l <= mid)
        update(ls(x), nl, mid, l, r, k);

    if (r > mid)
        update(rs(x), mid + 1, nr, l, r, k);
}

ll query(int x, int nl, int nr, int l, int r) {
    if (l <= nl && nr <= r)
        return s[x];

    int mid = (nl + nr) >> 1;
    ll res = 1ll * (min(r, nr) - max(l, nl) + 1) * tag[x];

    if (l <= mid)
        res += query(ls(x), nl, mid, l, r);
    
    if (r > mid)
        res += query(rs(x), mid + 1, nr, l, r);

    return res;
}
} // namespace SMT

void solve(int l, int r, int L, int R) {
    if (L > R)
        return;

    if (l == r) {
        for (int i = L; i <= R; ++i)
            if (nd[i].id)
                ans[nd[i].id] = l;

        return;
    }

    int mid = (l + r) >> 1, lp = 0, rp = 0;

    for (int i = L; i <= R; ++i) {
        if (nd[i].id) {
            ll res = SMT::query(1, 1, n, nd[i].l, nd[i].r);

            if (res < nd[i].k)
                nd[i].k -= res, ndl[lp++] = nd[i];
            else
                ndr[rp++] = nd[i];
        } else {
            if (nd[i].k <= mid)
                ndl[lp++] = nd[i];
            else
                SMT::update(1, 1, n, nd[i].l, nd[i].r, 1), ndr[rp++] = nd[i];
        }
    }

    for (int i = L; i <= R; ++i)
        if (!nd[i].id && nd[i].k > mid)
            SMT::update(1, 1, n, nd[i].l, nd[i].r, -1);

    memcpy(nd + L, ndl, sizeof(Node) * lp);
    memcpy(nd + L + lp, ndr, sizeof(Node) * rp);
    solve(l, mid, L, L + lp - 1), solve(mid + 1, r, L + lp, R);
}

signed main() {
    n = read(), m = read();

    for (int i = 1; i <= m; ++i)
        nd[i].id = (read() == 1 ? 0 : ++cntq), nd[i].l = read(), nd[i].r = read(), nd[i].k = read<ll>();

    solve(-n, n, 1, m);

    for (int i = 1; i <= cntq; ++i)
        printf("%d\n", ans[i]);

    return 0;
}

P1527 [国家集训队] 矩阵乘法

给出一个 \(n \times n\) 的矩阵,\(q\) 次询问一个子矩形的 \(k\) 小值。

\(n \leq 500\)\(q \leq 6 \times 10^4\)

用二维树状数组维护答案与 \(mid\) 的关系即可,时间复杂度 \(O((n^2 + q) \log^2 n \log V)\)

#include <bits/stdc++.h>
using namespace std;
const int inf = 0x3f3f3f3f;
const int N = 5e2 + 7, M = 6e4 + 7;

struct Node {
    int x, y, xx, yy, k, id;
} nd[N * N + M], ndl[N * N + M], ndr[N * N + M];

int ans[M];

int n, m, tot;

template <class T = int>
inline T read() {
    char c = getchar();
    bool sign = (c == '-');
    
    while (c < '0' || c > '9')
        c = getchar(), sign |= (c == '-');
    
    T x = 0;
    
    while ('0' <= c && c <= '9')
        x = (x << 1) + (x << 3) + (c & 15), c = getchar();
    
    return sign ? (~x + 1) : x;
}

namespace BIT {
int c[N][N];

inline void update(int x, int y, int k) {
    for (int i = x; i <= n; i += i & -i)
        for (int j = y; j <= n; j += j & -j)
            c[i][j] += k;
}

inline int ask(int x, int y) {
    int res = 0;

    for (int i = x; i; i -= i & -i)
        for (int j = y; j; j -= j & -j)
            res += c[i][j];

    return res;
}

inline int query(int x, int y, int xx, int yy) {
    return ask(xx, yy) - ask(x - 1, yy) - ask(xx, y - 1) + ask(x - 1, y - 1);
}
} // namespace BIT

void solve(int l, int r, int L, int R) {
    if (L > R)
        return;

    if (l == r) {
        for (int i = L; i <= R; ++i)
            if (nd[i].id)
                ans[nd[i].id] = l;

        return;
    }

    int mid = (l + r) >> 1, lp = 0, rp = 0;

    for (int i = L; i <= R; ++i) {
        if (nd[i].id) {
            int res = BIT::query(nd[i].x, nd[i].y, nd[i].xx, nd[i].yy);

            if (nd[i].k <= res)
                ndl[lp++] = nd[i];
            else
                nd[i].k -= res, ndr[rp++] = nd[i];
        } else {
            if (nd[i].k <= mid)
                BIT::update(nd[i].x, nd[i].y, 1), ndl[lp++] = nd[i];
            else
                ndr[rp++] = nd[i];
        }
    }

    for (int i = L; i <= R; ++i)
        if (!nd[i].id && nd[i].k <= mid)
            BIT::update(nd[i].x, nd[i].y, -1);

    memcpy(nd + L, ndl, sizeof(Node) * lp);
    memcpy(nd + L + lp, ndr, sizeof(Node) * rp);
    solve(l, mid, L, L + lp - 1), solve(mid + 1, r, L + lp, R);
}

signed main() {
    scanf("%d%d", &n, &m);

    for (int i = 1; i <= n; ++i)
        for (int j = 1; j <= n; ++j)
            nd[++tot].x = i, nd[tot].y = j, scanf("%d", &nd[tot].k);

    for (int i = 1; i <= m; ++i)
        nd[++tot].id = i, scanf("%d%d%d%d%d", &nd[tot].x, &nd[tot].y, &nd[tot].xx, &nd[tot].yy, &nd[tot].k);

    solve(-inf, inf, 1, tot);

    for (int i = 1; i <= m; ++i)
        printf("%d\n", ans[i]);

    return 0;
}

构造单调序列

P4331 [BalticOI 2004] Sequence 数字序列

给定一个整数序列 \(a_{1 \sim n}\),求出一个严格递增序列 \(b_{1 \sim n}\),使得 \(\sum_{i = 1}^n |a_i - b_i|\) 最小。

\(n \leq 10^6\)

先用一个经典套路,令 \(a_i \gets a_i - i\) ,最后令 \(b_i \gets b_i + i\) ,这样限制条件就转化为 \(b\) 单调不降。

事实上在满足操作次数最小化的前提下,一定存在一种方案使得最后序列中的每个数都是序列修改前存在的,可以使用数学归纳法证明。

由于要求 \(b\) 单调不降,考虑整体二分。记函数 solve(l, r, L, R) 判定最终序列区间 \([L, R]\) 的值域,此时可行值域为 \([l, r]\)

每轮二分开始时默认将所有数划分到 \([mid + 1, r]\) ,即划分到 \([l, mid]\) 的数设为 \(0\) 个。初始代价设为将序列区间 \([L, R]\) 全部置为 \(mid + 1\) 的操作次数。然后依次枚举 \([L, R]\) 中的数 \(i\) ,并计算将 \([L, i]\) 置为 \(mid\) 、将 \([i + 1, R]\) 置为 \(mid + 1\) 的操作次数之和,如果优于之前的操作次数则更新最少操作次数和要划分到 \([l, mid]\) 的数的个数。

时间复杂度 \(O(n \log V)\)

#include <bits/stdc++.h>
typedef long long ll;
using namespace std;
const int N = 1e6 + 7;

int a[N], b[N];

int n;

template <class T = int>
inline T read() {
    char c = getchar();
    bool sign = (c == '-');
    
    while (c < '0' || c > '9')
        c = getchar(), sign |= (c == '-');
    
    T x = 0;
    
    while ('0' <= c && c <= '9')
        x = (x << 1) + (x << 3) + (c & 15), c = getchar();
    
    return sign ? (~x + 1) : x;
}

void solve(ll l, ll r, int L, int R) {
    if (L > R)
        return;

    if (l == r) {
        fill(b + L, b + R + 1, l);
        return;
    }

    int mid = (l + r) >> 1;
    ll sum = 0;

    for (int i = L; i <= R; ++i)
        sum += abs(a[i] - mid - 1);

    ll mn = sum;
    int mnp = L - 1;

    for (int i = L; i <= R; ++i) {
        sum += abs(a[i] - mid) - abs(a[i] - mid - 1);

        if (sum < mn)
            mn = sum, mnp = i;
    }

    solve(l, mid, L, mnp), solve(mid + 1, r, mnp + 1, R);
}

signed main() {
    scanf("%d", &n);

    for (int i = 1; i <= n; ++i)
        scanf("%d", a + i), a[i] -= i;

    solve(0, 1ll << 31, 1, n);
    ll ans = 0;

    for (int i = 1; i <= n; ++i)
        ans += abs(a[i] - b[i]);

    printf("%lld\n", ans);

    for (int i = 1; i <= n; ++i)
        printf("%d ", b[i] + i);

    return 0;
}

维护不可加贡献

策略为每次处理 \(solve(l, r, L, R)\) 时,先执行 \([l, mid]\) 的修改,将 \([L, R]\) 分为两部分后先不清空递归右半部分,再撤销 \([l, mid]\) 的修改递归左半部分。

CF603E Pastoral Oddities

给定 \(n\) 个点的无向图,依次加入 \(m\) 条无向带权边,每次加入后询问是否存在一个边集,满足每个点的度数均为奇数,若存在则还需最小化边集中的最大边权。

\(n \leq 10^5\)\(m \leq 3 \times 10^5\)

首先有一个结论:存在合法边集当且仅当所有连通块大小均为偶数。

必要性:连通块大小为奇数时若存在方案,则保留合法边集后此连通块度数之和为奇数,矛盾。

充分性:每个联通块内仅保留一棵生成树,然后从叶子开始,一个点与其父亲的连边保留当且仅当这个点与其所有儿子的连边数为偶数,那么就可以构造出来了。

先考虑无修改的情况:连通块大小均为偶数时,再添加一些边后依然满足条件,所以按边权从小到大排序后,有用的边一定是一个前缀,并且具有单调性,于是考虑整体二分。

solve(l, r, L, R) 表示 \([L, R]\) 的答案 \(\in [l, r]\) 。每次分治时钦定编号 \(< L\) 且权值 \(\leq l\) 的边一定被考虑,故需要保证每次分治时这些边已经加入并查集。

每次先加入权值 \(\leq mid\) 且编号 \(< L\) 的必须边,然后依次加入权值 \(\leq mid\) 且未考虑的边,记第一个合法的位置为 \(p\) ,则 \(ans_{p - 1} > mid\)\(ans_p \leq mid\) ,递归分治即可。

时间复杂度 \(O(m \log m \log n)\)

#include <bits/stdc++.h>
using namespace std;
const int N = 1e5 + 7, M = 3e5 + 7;

struct Edge {
    int u, v, w;
} e[M], g[M];

struct DSU {
    int fa[N], siz[N], sta[N];

    int top, odd;

    inline void prework(int n) {
        iota(fa + 1, fa + 1 + n, 1);
        fill(siz + 1, siz + 1 + n, 1);
        odd = n;
    }

    inline int find(int x) {
        while (x != fa[x])
            x = fa[x];

        return x;
    }

    inline void merge(int x, int y) {
        x = find(x), y = find(y);

        if (x == y)
            return;

        if (siz[x] < siz[y])
            swap(x, y);

        if ((siz[x] & 1) && (siz[y] & 1))
            odd -= 2;

        fa[y] = x, siz[x] += siz[y], sta[++top] = y;
    }

    inline void restore(int k) {
        while (top > k) {
            int y = sta[top--], x = fa[y];
            fa[y] = y, siz[x] -= siz[y];

            if ((siz[x] & 1) && (siz[y] & 1))
                odd += 2;
        }
    }
} dsu;

int id[M], ans[M];

int n, m;

void solve(int l, int r, int L, int R) {
    if (L > R)
        return;

    if (l == r) {
        fill(ans + L, ans + R + 1, e[id[l]].w);
        return;
    }

    int mid = (l + r) >> 1, oritop = dsu.top;

    for (int i = l; i <= mid; ++i)
        if (id[i] < L)
            dsu.merge(e[id[i]].u, e[id[i]].v);

    int p = R + 1, pretop = dsu.top;

    for (int i = L; i <= R; ++i) {
        if (e[i].w <= e[id[mid]].w)
            dsu.merge(e[i].u, e[i].v);

        if (!dsu.odd) {
            p = i;
            break;
        }
    }

    dsu.restore(pretop), solve(mid + 1, r, L, p - 1), dsu.restore(oritop);

    for (int i = L; i < p; ++i)
        if (e[i].w <= e[id[l]].w)
            dsu.merge(e[i].u, e[i].v);

    solve(l, mid, p, R), dsu.restore(oritop);
}

signed main() {
    scanf("%d%d", &n, &m);

    for (int i = 1; i <= m; ++i)
        scanf("%d%d%d", &e[i].u, &e[i].v, &e[i].w);

    iota(id + 1, id + 1 + m, 1);

    sort(id + 1, id + 1 + m, [](const int &a, const int &b) {
        return e[a].w < e[b].w;
    });

    e[m + 1].w = -1, id[m + 1] = m + 1;
    dsu.prework(n), solve(1, m + 1, 1, m);

    for (int i = 1; i <= m; ++i)
        printf("%d\n", ans[i]);

    return 0;
}

P3250 [HNOI2016] 网络

给定一棵树,\(m\) 次操作:

  • 向路径集合中加入路径 \(x \to y\) ,权值为 \(k\)
  • 向路径集合中删除第 \(x\) 条路径。
  • 求路径集合中所有不经过 \(u\) 的路径的权值最大值。

\(n \leq 10^5\)\(m \leq 2 \times 10^5\)

考虑整体二分,如果某个询问点被所有大于 \(mid\) 的路径所经过,那么答案 \(\leq mid\) ,否则答案 \(> mid\)

查询经过一个点的路径条数用树上差分即可,时间复杂度 \(O(n \log n \log V)\)

#include <bits/stdc++.h>
using namespace std;
const int N = 1e5 + 7, M = 2e5 + 7, LOGN = 17;

struct Graph {
    vector<int> e[N];
    
    inline void insert(int u, int v) {
        e[u].emplace_back(v);
    }
} G;

struct Node {
    int op, x, y, k, id;
} nd[M], ndl[M], ndr[M];

int fa[N][LOGN];
int dep[N], in[N], out[N], ans[M];

int n, m, cntq, dfstime;

template <class T = int>
inline T read() {
    char c = getchar();
    bool sign = (c == '-');
    
    while (c < '0' || c > '9')
        c = getchar(), sign |= (c == '-');
    
    T x = 0;
    
    while ('0' <= c && c <= '9')
        x = (x << 1) + (x << 3) + (c & 15), c = getchar();
    
    return sign ? (~x + 1) : x;
}

void dfs(int u, int f) {
    fa[u][0] = f, dep[u] = dep[f] + 1, in[u] = ++dfstime;

    for (int i = 1; i < LOGN; ++i)
        fa[u][i] = fa[fa[u][i - 1]][i - 1];

    for (int v : G.e[u])
        if (v != f)
            dfs(v, u);

    out[u] = dfstime;
}

inline int LCA(int x, int y) {
    if (dep[x] < dep[y])
        swap(x, y);

    for (int i = 0, h = dep[x] - dep[y]; h; ++i, h >>= 1)
        if (h & 1)
            x = fa[x][i];

    if (x == y)
        return x;

    for (int i = LOGN - 1; ~i; --i)
        if (fa[x][i] != fa[y][i])
            x = fa[x][i], y = fa[y][i];

    return fa[x][0];
}

namespace BIT {
int c[N];

inline void update(int x, int k) {
    for (; x <= n; x += x & -x)
        c[x] += k;
}

inline int ask(int x) {
    int res = 0;

    for (; x; x -= x & -x)
        res += c[x];

    return res;
}

inline int query(int l, int r) {
    return ask(r) - ask(l - 1);
}
} // namespace BIT

inline void update(int x, int y, int k) {
    BIT::update(in[x], k), BIT::update(in[y], k);
    int lca = LCA(x, y);
    BIT::update(in[lca], -k);

    if (fa[lca][0])
        BIT::update(in[fa[lca][0]], -k);
}

void solve(int l, int r, int L, int R) {
    if (L > R)
        return;

    if (l == r) {
        for (int i = L; i <= R; ++i)
            if (nd[i].op == 2)
                ans[nd[i].id] = l;

        return;
    }

    int mid = (l + r) >> 1, lp = 0, rp = 0;

    for (int i = L, sum = 0; i <= R; ++i)
        if (nd[i].op == 2) {
            if (BIT::query(in[nd[i].x], out[nd[i].x]) == sum)
                ndl[lp++] = nd[i];
            else
                ndr[rp++] = nd[i];
        } else {
            if (nd[i].k <= mid)
                ndl[lp++] = nd[i];
            else
                ndr[rp++] = nd[i], sum += nd[i].op, update(nd[i].x, nd[i].y, nd[i].op);
        }

    for (int i = L; i <= R; ++i)
        if (nd[i].op != 2 && nd[i].k > mid)
            update(nd[i].x, nd[i].y, -nd[i].op);
    
    memcpy(nd + L, ndl, sizeof(Node) * lp);
    memcpy(nd + L + lp, ndr, sizeof(Node) * rp);
    solve(l, mid, L, L + lp - 1), solve(mid + 1, r, L + lp, R);
}

signed main() {
    n = read(), m = read();

    for (int i = 1; i < n; ++i) {
        int u = read(), v = read();
        G.insert(u, v), G.insert(v, u);
    }

    dfs(1, 0);

    for (int i = 1; i <= m; ++i) {
        nd[i].op = read();

        if (!nd[i].op)
            nd[i].op = 1, nd[i].x = read(), nd[i].y = read(), nd[i].k = read();
        else if (nd[i].op == 1)
            nd[i] = nd[read()], nd[i].op = -1;
        else
            nd[i].x = read(), nd[i].id = ++cntq;
    }

    solve(-1, 1e9, 1, m);

    for (int i = 1; i <= cntq; ++i)
        printf("%d\n", ans[i]);

    return 0;
}

点分治

点分治适合处理大规模的树上路径信息统计问题。

考虑统计一个子树内的路径信息,将路径分为经过根节点的路径和不经过根节点的路径。

对于不经过根节点的路径,递归处理即可。

对于经过根节点的路径,将其视作从子树内一个点到根节点再到子树内另一个点,于是先求出所有点到根的路径信息,再整合即可。

直接这么做是 \(O(n^2)\) 的。但是可以发现,算完一个点为路径 LCA 的答案后,其各个子树互不影响。

考虑每次计算一个子树内的答案时,选取重心为根划分子树。由于每次划分子树大小至少缩减一半,因此时间复杂度降为 \(O(n \log n)\)

P3806 【模板】点分治1

给定一棵树,\(m\) 次询问树上距离为 \(k\) 的点对是否存在。

\(n \leq 10^4\)\(m \leq 100\)

用双指针整合信息即可。

#include <bits/stdc++.h>
using namespace std;
const int N = 1e4 + 7;

struct Graph {
    vector<pair<int, int> > e[N];
    
    inline void insert(const int u, const int v, const int w) {
        e[u].emplace_back(v, w);
    }
} G;

vector<int> vec;

int siz[N], mxsiz[N], dis[N], top[N], qry[N];
bool vis[N], ans[N];

int n, m, root;

void getroot(int u, int f, int Siz) {
    siz[u] = 1, mxsiz[u] = 0;

    for (auto it : G.e[u]) {
        int v = it.first;

        if (v == f || vis[v])
            continue;

        getroot(v, u, Siz), siz[u] += siz[v];
        mxsiz[u] = max(mxsiz[u], siz[v]);
    }

    mxsiz[u] = max(mxsiz[u], Siz - siz[u]);

    if (!root || mxsiz[u] < mxsiz[root])
        root = u;
}

int getsiz(int u, int f) {
    int siz = 1;

    for (auto it : G.e[u]) {
        int v = it.first;

        if (v != f && !vis[v])
            siz += getsiz(v, u);
    }

    return siz;
}

void getdis(int u, int f) {
    vec.emplace_back(u);

    for (auto it : G.e[u]) {
        int v = it.first, w = it.second;

        if (v != f && !vis[v])
            top[v] = top[u], dis[v] = dis[u] + w, getdis(v, u);
    }
}

inline void calc(int u) {
    vec.clear(), vec.emplace_back(u);
    top[u] = u, dis[u] = 0;

    for (auto it : G.e[u]) {
        int v = it.first, w = it.second;

        if (!vis[v])
            top[v] = v, dis[v] = w, getdis(v, u);
    }

    sort(vec.begin(), vec.end(), [](const int &x, const int &y) { return dis[x] < dis[y]; });

    for (int i = 1; i <= m; ++i) {
        if (ans[i])
            continue;

        for (auto itl = vec.begin(), itr = prev(vec.end()); itl != itr;) {
            if (dis[*itl] + dis[*itr] < qry[i])
                ++itl;
            else if (dis[*itl] + dis[*itr] > qry[i])
                --itr;
            else if (top[*itl] == top[*itr]) {
                if (dis[*itl] == dis[*next(itl)])
                    ++itl;
                else
                    --itr;
            } else {
                ans[i] = true;
                break;
            }
        }
    }
}

void solve(int u) {
    vis[u] = true, calc(u);

    for (auto it : G.e[u]) {
        int v = it.first;

        if (!vis[v])
            root = 0, getroot(v, 0, getsiz(v, u)), solve(root);
    }
}

signed main() {
    scanf("%d%d", &n, &m);

    for (int i = 1; i < n; ++i) {
        int u, v, w;
        scanf("%d%d%d", &u, &v, &w);
        G.insert(u, v, w), G.insert(v, u, w);
    }

    for (int i = 1; i <= m; ++i)
        scanf("%d", qry + i), ans[i] = !qry[i];

    root = 0, getroot(1, 0, n), solve(root);
    
    for (int i = 1; i <= m; ++i)
        puts(ans[i] ? "AYE" : "NAY");

    return 0;
}

P6626 [省选联考 2020 B 卷] 消息传递

给定一棵 \(n\) 个节点的树,\(m\) 次询问和一个点 \(x\) 距离为 \(k\) 的点的数量。

\(1 \leq n, m \leq 10^5\)

指定一个点为根,那么对于任意一个非根的节点 \(x\) ,与它距离为 \(k\) 的点无非会有两种情况:子树内或子树外。

对于在子树外的节点,记 \(t_i\) 表示与根距离为 \(i\) 的点的个数。如果一个子树外的点 \(y\)\(x\) 的距离为 \(k\) ,那么它们一定满足 \(dep_x + dep_y = k\) 。所以可以直接把对应的桶 \(t_{k - dep_x}\)计入答案。

子树内的结点分治处理即可,时间复杂度 \(O(n \log n)\)

#include <bits/stdc++.h>
using namespace std;
const int N = 1e5 + 7;

struct Graph {
    vector<int> e[N];
    
    inline void clear(int n) {
        for (int i = 1; i <= n; ++i)
            e[i].clear();
    }
    
    inline void insert(int u, int v) {
        e[u].emplace_back(v);
    }
} G;

vector<pair<int, int> > qry[N];
vector<pair<int, int> > tmp;

int siz[N], mxsiz[N], dep[N], buc[N], ans[N];
bool vis[N];

int n, m, root, mxdep;

inline void clear() {
    G.clear(n);
    memset(vis + 1, false, sizeof(bool) * n);
    memset(ans + 1, 0, sizeof(int) * n);
    
    for (int i = 1; i <= n; ++i)
        qry[i].clear();
}

void getroot(int u, int f, int Siz) {
    siz[u] = 1, mxsiz[u] = 0;
    
    for (int v : G.e[u])
        if (!vis[v] && v != f)
            getroot(v, u, Siz), siz[u] += siz[v], mxsiz[u] = max(mxsiz[u], siz[v]);
    
    mxsiz[u] = max(mxsiz[u], Siz - mxsiz[u]);
    
    if (!root || mxsiz[u] < mxsiz[root])
        root = u;
}

int getsiz(int u, int f) {
    int siz = 1;

    for (int v : G.e[u])
        if (!vis[v] && v != f)
            siz += getsiz(v, u);

    return siz;
}

void dfs(int u, int f) {
    ++buc[dep[u]], mxdep = max(mxdep, dep[u]);
    
    for (auto it : qry[u])
        if (it.first >= dep[u])
            tmp.emplace_back(it.first - dep[u], it.second);
    
    for (int v : G.e[u])
        if (v != f && !vis[v])
            dep[v] = dep[u] + 1, dfs(v, u);
}

inline void calc(int u) {
    tmp.clear(), mxdep = dep[u] = 0, dfs(u, 0);
    
    for (auto it : tmp)
        ans[it.second] += buc[it.first];
    
    memset(buc, 0, sizeof(int) * (mxdep + 1));
    
    for (int v : G.e[u])
        if (!vis[v]) {
            tmp.clear(), mxdep = dep[v] = 1, dfs(v, u);
            
            for (auto it : tmp)
                ans[it.second] -= buc[it.first];
            
            memset(buc, 0, sizeof(int) * (mxdep + 1));
        }
}

void solve(int u) {
    vis[u] = true, calc(u);
    
    for (int v : G.e[u])
        if (!vis[v])
            root = 0, getroot(v, 0, getsiz(v, u)), solve(root);
}

signed main() {
    int T;
    scanf("%d", &T);
    
    while (T--) {
        scanf("%d%d", &n, &m);
        clear();
        
        for (int i = 1; i < n; ++i) {
            int u, v;
            scanf("%d%d", &u, &v);
            G.insert(u, v), G.insert(v, u);
        }
        
        for (int i = 1; i <= m; ++i) {
            int x, k;
            scanf("%d%d", &x, &k);
            qry[x].emplace_back(k, i);
        }
        
        root = 0, getroot(1, 0, n), solve(root);
        
        for (int i = 1; i <= m; ++i)
            printf("%d\n", ans[i]);
    }
    
    return 0;
}

P2664 树上游戏

给出一棵树,每个点有颜色,对每个点 \(x\) 求所有点到 \(x\) 路径上颜色数量的和。

\(n \leq 10^5\)

看到树上路径,考虑点分治。设 \(cnt_i\) 表示当前统计 \(u\) 子树内所有到 \(u\) 的路径有多少条路径出现颜色 \(i\) 。这个是好处理的,只要到 \(u\) 的链上第一次出现颜色 \(i\) 时令 \(cnt_i \leftarrow cnt_i + siz_v\) 即可。

然后剩下的就是对于 \(v\) 子树外的贡献减去 \(v\) 子树内的贡献即可,容斥上应该有一些细节需要处理。

#include <bits/stdc++.h>
typedef long long ll;
using namespace std;
const int N = 1e5 + 7;

struct Graph {
    vector<int> e[N];
    
    inline void insert(int u, int v) {
        e[u].emplace_back(v);
    }
} G;

ll ans[N], Sum;
int col[N], siz[N], mxsiz[N], sta[N], chaincol[N], cnt[N];
bool vis[N];

int n, root, outsiz;

void getroot(int u, int f, const int Siz) {
    siz[u] = 1, mxsiz[u] = 0;

    for (int v : G.e[u])
        if (!vis[v] && v != f)
            getroot(v, u, Siz), siz[u] += siz[v], mxsiz[u] = max(mxsiz[u], siz[v]);

    mxsiz[u] = max(mxsiz[u], Siz - siz[u]);

    if (!root || mxsiz[u] < mxsiz[root])
        root = u;
}

int getsiz(int u, int f) {
    siz[u] = 1;

    for (int v : G.e[u])
        if (!vis[v] && v != f)
            siz[u] += getsiz(v, u);

    return siz[u];
}

void addcol(int u, int f) {
    if (!chaincol[col[u]])
        cnt[col[u]] += siz[u], Sum += siz[u];

    ++chaincol[col[u]];

    for (int v : G.e[u])
        if (!vis[v] && v != f)
            addcol(v, u);

    --chaincol[col[u]];
}

void delcol(int u, int f) {
    if (!chaincol[col[u]])
        cnt[col[u]] -= siz[u], Sum -= siz[u];
    
    ++chaincol[col[u]];

    for (int v : G.e[u])
        if (!vis[v] && v != f)
            delcol(v, u);

    --chaincol[col[u]];
}

void dfs(int u, int f, int num, ll sum) {
    if (!chaincol[col[u]])
        ++num, sum += cnt[col[u]];

    ++chaincol[col[u]], ans[u] += Sum - sum + num * outsiz;

    for (int v : G.e[u])
        if (!vis[v] && v != f)
            dfs(v, u, num, sum);

    --chaincol[col[u]];
}

void calc(int u) {
    Sum = 0, addcol(u, 0), ++chaincol[col[u]];

    for (int v : G.e[u])
        if (!vis[v]) {
            delcol(v, u), cnt[col[u]] -= siz[v], Sum -= siz[v];
            outsiz = siz[u] - siz[v], dfs(v, u, 0, 0);
            Sum += siz[v], cnt[col[u]] += siz[v], addcol(v, u);
        }

    ans[u] += Sum - cnt[col[u]] + siz[u];
    --chaincol[col[u]], delcol(u, 0);
}

void solve(int u) {
    getsiz(u, 0), vis[u] = true, calc(u);

    for (int v : G.e[u])
        if (!vis[v])
            root = 0, getroot(v, 0, siz[v]), solve(root);
}

signed main() {
    scanf("%d", &n);

    for (int i = 1; i <= n; ++i)
        scanf("%d", col + i);

    for (int i = 1; i < n; ++i) {
        int u, v;
        scanf("%d%d", &u, &v);
        G.insert(u, v), G.insert(v, u);
    }

    root = 0, getroot(1, 0, n), solve(root);

    for (int i = 1; i <= n; ++i)
        printf("%lld\n", ans[i]);

    return 0;
}

CF150E Freezing with Style

给定一棵树,边带边权。找到一条经过的边数 \(\in [L, R]\) 的树上路径,使得经过的边权中位数尽量大,输出这个中位数即可。

注意这里定义长度为 \(m\) 的序列的中位数为该序列第 \(\lceil \frac{m + 1}{2} \rceil\) 大的数的值。

\(n \leq 10^5\)

首先套路地二分答案,将问题转化成点权为 \(\pm 1\) ,求是否存在长度 \(\in [L, R]\) 且经过的点权和非负的路径。

考虑点分治,统计信息直接滑动窗口即可。

注意由于需要预处理,因此需要将子树大小排序后统计答案,这样才能保证统计单个根的复杂度为 \(O(\sum siz)\)

时间复杂度 \(O(n \log^2 n)\) ,提前存储分治树可以减小常数。

#include <bits/stdc++.h>
using namespace std;
const int N = 1e5 + 7;

struct Graph {
    vector<pair<int, int> > e[N];
    
    inline void insert(int u, int v, int w) {
        e[u].emplace_back(v, w);
    }
} G;

vector<tuple<int, int, int> > son[N];
vector<pair<int, int> > vec;
vector<int> son2[N], res;
deque<int> q;
tuple<int, int, int> ans;

int siz[N], mxsiz[N], dis[N];
bool vis[N];

int n, L, R, root, lim;

int getsiz(int u, int f) {
    siz[u] = 1;

    for (auto it : G.e[u]) {
        int v = it.first;

        if (!vis[v] && v != f)
            siz[u] += getsiz(v, u);
    }

    return siz[u];
}

void getroot(int u, int f, int Siz) {
    siz[u] = 1, mxsiz[u] = 0;

    for (auto it : G.e[u]) {
        int v = it.first;

        if (!vis[v] && v != f)
            getroot(v, u, Siz), siz[u] += siz[v], mxsiz[u] = max(mxsiz[u], siz[v]);
    }

    mxsiz[u] = max(mxsiz[u], Siz - siz[u]);

    if (!root || mxsiz[u] < mxsiz[root])
        root = u;
}

void build(int u) {
    vis[u] = true;

    for (auto it : G.e[u]) {
        int v = it.first, w = it.second;

        if (!vis[v])
            son[u].emplace_back(getsiz(v, u), v, w);
    }

    sort(son[u].begin(), son[u].end());

    for (auto it : son[u]) {
        int v = get<1>(it);
        root = 0, getroot(v, u, siz[v]), son2[u].emplace_back(root), build(root);
    }
}

void calc(int u, int f, int d) {
    if (d == res.size()) {
        while (!q.empty() && q.front() > R - d)
            q.pop_front();

        if (0 <= L - d && L - d < vec.size()) {
            while (!q.empty() && vec[q.back()] < vec[L - d])
                q.pop_back();

            q.emplace_back(L - d);
        }

        res.emplace_back(q.empty() ? -1 : q.front());
    }

    if (~res[d])
        ans = max(ans, make_tuple(dis[u] + vec[res[d]].first, vec[res[d]].second, u));

    for (auto it : G.e[u]) {
        int v = it.first, w = (it.second >= lim ? 1 : -1);
        
        if (!vis[v] && v != f)
            dis[v] = dis[u] + w, calc(v, u, d + 1);
    }
}

void addin(int u, int f, int d) {
    if (d < vec.size())
        vec[d] = max(vec[d], make_pair(dis[u], u));
    else
        vec.emplace_back(dis[u], u);

    for (auto it : G.e[u]) {
        int v = it.first, w = (it.second >= lim ? 1 : -1);
        
        if (!vis[v] && v != f)
            dis[v] = dis[u] + w, addin(v, u, d + 1);
    }
}

void solve(int u) {
    vis[u] = true, vec = {make_pair(0, u)};

    for (auto it : son[u]) {
        int v = get<1>(it), w = get<2>(it);
        q.clear();

        for (int i = min((int)vec.size() - 1, R); i >= L; --i) {
            while (!q.empty() && vec[q.back()] < vec[i])
                q.pop_back();

            q.emplace_back(i);
        }

        res = {q.empty() ? -1 : q.front()}, dis[v] = (w >= lim ? 1 : -1), calc(v, u, 1), addin(v, u, 1);
    }

    for (int v : son2[u])
        solve(v);
}

inline bool check(int mid) {
    ans = make_tuple(-1, -1, -1), lim = mid;
    memset(vis + 1, false, sizeof(bool) * n);
    return solve(root), ~get<0>(ans);
}

signed main() {
    scanf("%d%d%d", &n, &L, &R);
    vector<int> vec;

    for (int i = 1; i < n; ++i) {
        int u, v, w;
        scanf("%d%d%d", &u, &v, &w);
        G.insert(u, v, w), G.insert(v, u, w);
        vec.emplace_back(w);
    }

    sort(vec.begin(), vec.end()), vec.erase(unique(vec.begin(), vec.end()), vec.end());
    root = 0, getroot(1, 0, n), build(root);
    memset(vis + 1, false, sizeof(bool) * n);
    root = 0, getroot(1, 0, n);
    int l = 1, r = vec.size() - 1, res = 0;

    while (l <= r) {
        int mid = (l + r) >> 1;

        if (check(vec[mid]))
            res = mid, l = mid + 1;
        else
            r = mid - 1;
    }

    check(vec[res]);
    printf("%d %d", get<1>(ans), get<2>(ans));
    return 0;
}

建立分治结构

通常用于一类强制在线问题,若每次询问都分治,则时间复杂度无法接受。

考虑在分治结构上存储信息,从而能够快速查询。

P11685 [Algo Beat Contest 001 G] Great DS Homework

给定一个长 \(2n - 1\) 的表达式,形如 \(a_1 \space op_2 \space a_2 \space op_3 \space a_3 \cdots op_n \space a_n\),其中 \(a_i \in \{ 0, 1 \}\)\(op_i \in \{ \operatorname{or}, \operatorname{and}, \operatorname{xor} \}\) ,运算符不分优先级。

定义一个表达式的子表达式为它的一个满足端点均为数字的子区间。

\(m\) 次修改一个位置( \(a_i, op_i\) ),每次修改完求出所有子表达式的值的和。

\(n, m \leq 10^6\)

考虑基于中点的序列分治,每层考虑跨过中点的答案。

建立线段树结构,对于每个点维护:

  • \(nxt_i\)\(i\) 进入区间所得结果
  • \(L_i\)\(i\) 进入区间后该区间有多少前缀表达式值为 \(1\)
  • \(R_i\) :该区间有多少后缀表达式值为 \(i\)

pushup 不难合并信息,时间复杂度 \(O(n \log n)\)

#include <bits/stdc++.h>
typedef long long ll;
using namespace std;
const int N = 1e6 + 7;

struct Node {
    int k;
    char op;
} a[N];

char str[N << 1];

int n, m;

inline int calc(int a, char op, int b) {
    if (op == '|')
        return a | b;
    else if (op == '&')
        return a & b;
    else
        return a ^ b;
}

namespace SMT {
struct Node {
    int nxt[2], L[2], R[2];
    ll ans;

    inline friend Node operator + (const Node &a, const Node &b) {
        Node c;
        c.ans = a.ans + b.ans;

        for (int i = 0; i <= 1; ++i) {
            c.nxt[i] = b.nxt[a.nxt[i]];
            c.L[i] = a.L[i] + b.L[a.nxt[i]];
            c.R[i] = b.R[i] + (b.nxt[0] == i ? a.R[0] : 0) + (b.nxt[1] == i ? a.R[1] : 0);
            c.ans += 1ll * a.R[i] * b.L[i];
        }

        return c;
    }
} nd[N << 2];

inline int ls(int x) {
    return x << 1;
}

inline int rs(int x) {
    return x << 1 | 1;
}

void build(int x, int l, int r) {
    if (l == r) {
        for (int i = 0; i <= 1; ++i)
            nd[x].nxt[i] = nd[x].L[i] = calc(i, a[l].op, a[l].k), nd[x].R[i] = (a[l].k == i);

        nd[x].ans = a[l].k;
        return;
    }

    int mid = (l + r) >> 1;
    build(ls(x), l, mid), build(rs(x), mid + 1, r);
    nd[x] = nd[ls(x)] + nd[rs(x)];
}

void update(int x, int nl, int nr, int p) {
    if (l == r) {
        for (int i = 0; i <= 1; ++i)
            nd[x].nxt[i] = nd[x].L[i] = calc(i, a[l].op, a[l].k), nd[x].R[i] = (a[l].k == i);

        nd[x].ans = a[l].k;
        return;
    }

    int mid = (l + r) >> 1;

    if (p <= mid)
        update(ls(x), l, mid, p);
    else
        update(rs(x), mid + 1, r, p);

    nd[x] = nd[ls(x)] + nd[rs(x)];
}
} // namespace SMT

signed main() {
    ios::sync_with_stdio(0), cin.tie(0), cout.tie(0);
    cin >> n >> m >> (str + 1);
    str[0] = '|';

    for (int i = 0; i < n * 2; i += 2)
        a[i / 2 + 1] = (Node) {str[i + 1] & 15, str[i]};

    SMT::build(1, 1, n);

    while (m--) {
        int x;
        cin >> x;
        cin >> a[x].op >> a[x].k;
        SMT::update(1, 1, n, x);
        printf("%lld\n", SMT::nd[1].ans);
    }

    return 0;
}

点分树

点分树,即动态点分治,是通过更改原树形态使树的层数变为稳定 \(\log n\) 的一种重构树,常用于与树的形态无关的带修改问题。

性质:

  • 高度只有 \(\log n\) 级别。
  • 对于任意两点 \(u, v\) ,点分树上的 LCA 一定在原树中 \(u \to v\) 的路径上,即 \(\mathrm{dist}(u, v) = \mathrm{dist}(u, lca) + \mathrm{dist}(lca, v)\)

考虑通过点分治每次找重心的方式对原树进行重构。

将每次找到的重心与上一层的重心缔结父子关系,这样树高即为 \(\log n\)

如此,很多暴力都可以有正确的正确的时间复杂度。

一个比较常见的套路是这样的:

  • 进行一次点分治,求出每个点在点分树上的父节点。
  • 对于每个点,开一个数据结构 \(T_1\) 存储点分树子树的贡献,再开一个数据结构 \(T_2\) 存储点分树父亲的贡献用来容斥。
  • \(x\) 进行修改时,从 \(x\) 开始不断跳点分树的父亲一直到根,每次对经过的节点的 \(T_1, T_2\) 修改它的贡献。
  • \(x\) 进行查询时,从 \(x\) 开始不断跳点分树的父亲一直到根,每次把 \(T_1\) 的贡献添加进答案,把 \(T_2\) 的贡献从答案删去。

P6329 【模板】点分树 | 震波

给定一棵 \(n\) 个点的树,\(m\) 次操作:

  • 0 x k :求所有与 \(x\) 距离 \(\leq k\) 的所有点的点权和
  • 1 x y :修改 \(x\) 的点权为 \(y\)

\(n, m \leq 10^5\)

要求 \(\sum_{dis(x, y) \leq k} a_y\) 考虑枚举 \(x, y\) 在点分树上的 LCA \(z\) ,显然 \(z\) 的数量是 \(\log n\) 级别的,故答案即为:

\[\sum_{LCA(x, y) = z} a_y \times [dis(z, y) \leq k - dis(x, z)] \]

注意到满足 \(LCA(x, y) = z\) 的点即为 \(z\) 的所有子树去掉 \(x\) 方向子树的所有点,那么显然我们可以用 \(z\) 子树中满足条件的点权和减去 \(x\) 子树中满足条件的点权和。

对于每个 \(x\) 建一棵动态开点线段树,下标为 \(i\) 的位置维护 \(x\) 子树内所有 \(dis(x, y) = i\)\(a_y\) 和,那么统计答案时区间查询即可。

考虑对每个点再建一棵动态开点线段树,线段树上下标为 \(i\) 的位置维护 \(x\) 子树内到 \(fa_x\) 距离为 \(i\) 的点权和,再统计答案即可。

#include <bits/stdc++.h>
using namespace std;
const int N = 1e5 + 7, LOGN = 17, S = N << 5;

struct Graph {
    vector<int> e[N];
    
    inline void insert(int u, int v) {
        e[u].emplace_back(v);
    }
} G;

int fa[N][LOGN];
int a[N], dep[N], siz[N], mxsiz[N], nfa[N];
bool vis[N];

int n, m, root;

void dfs(int u, int f) {
    fa[u][0] = f, dep[u] = dep[f] + 1;

    for (int i = 1; i < LOGN; ++i)
        fa[u][i] = fa[fa[u][i - 1]][i - 1];

    for (int v : G.e[u])
        if (v != f)
            dfs(v, u);
}

inline int LCA(int x, int y) {
    if (dep[x] < dep[y])
        swap(x, y);

    for (int i = 0, h = dep[x] - dep[y]; h; ++i, h >>= 1)
        if (h & 1)
            x = fa[x][i];

    if (x == y)
        return x;

    for (int i = LOGN - 1; ~i; --i)
        if (fa[x][i] != fa[y][i])
            x = fa[x][i], y = fa[y][i];

    return fa[x][0];
}

inline int dist(int x, int y) {
    int lca = LCA(x, y);
    return dep[x] + dep[y] - dep[lca] * 2;
}

int getsiz(int u, int f) {
    int siz = 1;

    for (int v : G.e[u])
        if (!vis[v] && v != f)
            siz += getsiz(v, u);

    return siz;
}

void getroot(int u, int f, const int Siz) {
    siz[u] = 1, mxsiz[u] = 0;

    for (int v : G.e[u])
        if (!vis[v] && v != f)
            getroot(v, u, Siz), siz[u] += siz[v], mxsiz[u] = max(mxsiz[u], siz[v]);

    mxsiz[u] = max(mxsiz[u], Siz - siz[u]);

    if (!root || mxsiz[u] < mxsiz[root])
        root = u;
}

void build(int u) {
    vis[u] = true;

    for (int v : G.e[u])
        if (!vis[v])
            root = 0, getroot(v, u, getsiz(v, u)), nfa[root] = u, build(root);
}

struct SMT {
    int lc[S], rc[S], s[S];
    int rt[N];

    int tot;

    void update(int &x, int nl, int nr, int pos, int k) {
        if (!x)
            x = ++tot;

        s[x] += k;

        if (nl == nr)
            return;

        int mid = (nl + nr) >> 1;

        if (pos <= mid)
            update(lc[x], nl, mid, pos, k);
        else
            update(rc[x], mid + 1, nr, pos, k);
    }

    int query(int x, int nl, int nr, int pos) {
        if (!x)
            return 0;

        if (nl == nr)
            return s[x];

        int mid = (nl + nr) >> 1;

        return pos <= mid ? query(lc[x], nl, mid, pos) : s[lc[x]] + query(rc[x], mid + 1, nr, pos);
    }
} A, B;

inline void update(int x, int k) {
    for (int cur = x; cur; cur = nfa[cur]) {
        A.update(A.rt[cur], 0, n - 1, dist(cur, x), k);

        if (nfa[cur])
            B.update(B.rt[cur], 0, n - 1, dist(nfa[cur], x), k);
    }
}

inline int query(int x, int k) {
    int ans = 0;

    for (int cur = x, pre = 0; cur; pre = cur, cur = nfa[cur]) {
        if (k - dist(cur, x) < 0)
            continue;

        ans += A.query(A.rt[cur], 0, n - 1, k - dist(cur, x));

        if (pre)
            ans -= B.query(B.rt[pre], 0, n - 1, k - dist(cur, x));
    }

    return ans;
}

signed main() {
    scanf("%d%d", &n, &m);

    for (int i = 1; i <= n; ++i)
        scanf("%d", a + i);

    for (int i = 1; i < n; ++i) {
        int u, v;
        scanf("%d%d", &u, &v);
        G.insert(u, v), G.insert(v, u);
    }

    dfs(1, 0), root = 0, getroot(1, 0, n), build(root);

    for (int i = 1; i <= n; ++i)
        update(i, a[i]);

    int lstans = 0;

    while (m--) {
        int op, x, k;
        scanf("%d%d%d", &op, &x, &k);
        x ^= lstans, k ^= lstans;

        if (op)
            update(x, k - a[x]), a[x] = k;
        else
            printf("%d\n", lstans = query(x, k));
    }

    return 0;
}

P3241 [HNOI2015] 开店

给定一棵 \(n\) 个点的树,\(q\) 次询问点权在 \([l, r]\) 内的所有点到某个点 \(u\) 的距离之和,强制在线。

\(n \leq 1.5 \times 10^5, q \leq 2 \times 10^5\)

首先考虑没有 \([l, r]\) 限制时的做法,记:

\[f_1(x) = \sum_{y \in subtree(x)} dis(x, y) \\ f_2(x) = \sum_{y \in subtree(x)} dis(x, fa_y) \\ g(x) = \sum_{y \in subtree(x)} 1 \]

询问时先令答案为 \(f_1(x)\) ,之后不断在点分树上跳父亲,记当前点为 \(x\) ,则答案要加上:

\[f_1(nfa_x) - f_2(x) + (g(nfa_x) - g(x)) \times dist(x, nfa_x) \]

现在有了 \([l, r]\) 的限制,只要加上一维用 vector 存储,做一个前缀和,每次二分即可。

#include <bits/stdc++.h>
typedef long long ll;
using namespace std;
const ll inf = 1e18;
const int N = 1.5e5 + 7, LOGN = 19;

struct Graph {
    vector<pair<int, int> > e[N];

    inline void insert(int u, int v, int w) {
        e[u].emplace_back(v, w);
    }
} G;

vector<pair<int, ll> > vec[2][N];

ll dis[N];
int fa[N][LOGN];
int a[N], dep[N], siz[N], mxsiz[N], nfa[N];
bool vis[N];

int n, q, A, root;

void dfs(int u, int f) {
    fa[u][0] = f, dep[u] = dep[f] + 1;

    for (int i = 1; i < LOGN; ++i)
        fa[u][i] = fa[fa[u][i - 1]][i - 1];

    for (auto it : G.e[u]) {
        int v = it.first, w = it.second;

        if (v != f)
            dis[v] = dis[u] + w, dfs(v, u);
    }
}

inline int LCA(int x, int y) {
    if (dep[x] < dep[y])
        swap(x, y);

    for (int i = 0, h = dep[x] - dep[y]; h; ++i, h >>= 1)
        if (h & 1)
            x = fa[x][i];

    if (x == y)
        return x;

    for (int i = LOGN - 1; ~i; --i)
        if (fa[x][i] != fa[y][i])
            x = fa[x][i], y = fa[y][i];

    return fa[x][0];
}

inline ll dist(int x, int y) {
    int lca = LCA(x, y);
    return dis[x] + dis[y] - dis[lca] * 2;
}

int getsiz(int u, int f) {
    int siz = 1;

    for (auto it : G.e[u]) {
        int v = it.first;

        if (!vis[v] && v != f)
            siz += getsiz(v, u);
    }

    return siz;
}

void getroot(int u, int f, const int Siz) {
    siz[u] = 1, mxsiz[u] = 0;

    for (auto it : G.e[u]) {
        int v = it.first;

        if (!vis[v] && v != f)
            getroot(v, u, Siz), siz[u] += siz[v], mxsiz[u] = max(mxsiz[u], siz[v]);
    }

    mxsiz[u] = max(mxsiz[u], Siz - siz[u]);

    if (!root || mxsiz[u] < mxsiz[root])
        root = u;
}

void build(int u) {
    vis[u] = true;

    for (auto it : G.e[u]) {
        int v = it.first;

        if (!vis[v])
            root = 0, getroot(v, u, getsiz(v, u)), nfa[root] = u, build(root);
    }
}

inline ll query(int op, int x, int l, int r, ll &siz) {
    auto pl = lower_bound(vec[op][x].begin(), vec[op][x].end(), make_pair(l, 0ll)),
        pr = upper_bound(vec[op][x].begin(), vec[op][x].end(), make_pair(r, inf));
    siz = pr - pl;
    ll res = 0;

    if (pr != vec[op][x].begin())
        res += prev(pr)->second;

    if (pl != vec[op][x].begin())
        res -= prev(pl)->second;

    return res;
}

signed main() {
    scanf("%d%d%d", &n, &q, &A);

    for (int i = 1; i <= n; ++i)
        scanf("%d", a + i);

    for (int i = 1; i < n; ++i) {
        int u, v, w;
        scanf("%d%d%d", &u, &v, &w);
        G.insert(u, v, w), G.insert(v, u, w);
    }

    dfs(1, 0), root = 0, getroot(1, 0, n), build(root);

    for (int i = 1; i <= n; ++i)
        for (int cur = i; cur; cur = nfa[cur]) {
            vec[0][cur].emplace_back(a[i], dist(i, cur));

            if (nfa[cur])
                vec[1][cur].emplace_back(a[i], dist(i, nfa[cur]));
        }

    for (int i = 0; i <= 1; ++i)
        for (int j = 1; j <= n; ++j) {
            sort(vec[i][j].begin(), vec[i][j].end());

            for (int k = 1; k < vec[i][j].size(); ++k)
                vec[i][j][k].second += vec[i][j][k - 1].second;
        }

    ll lstans = 0;

    while (q--) {
        int x;
        ll l, r;
        scanf("%d%lld%lld", &x, &l, &r);
        l = (l + lstans) % A, r = (r + lstans) % A;

        if (l > r)
            swap(l, r);

        ll sizx, sizf;
        lstans = query(0, x, l, r, sizx);

        for (int cur = x; nfa[cur]; cur = nfa[cur]) {
            lstans += query(0, nfa[cur], l, r, sizf) - query(1, cur, l, r, sizx);
            lstans += dist(x, nfa[cur]) * (sizf - sizx);
        }
        
        printf("%lld\n", lstans);
    }

    return 0;
}

P3345 [ZJOI2015] 幻想乡战略游戏

维护一颗带点权、边权树(树上点的度数不超过 \(20\) )。现有若干次修改点权的操作,每次操作结束后您需要选出一个核心点 \(x\) 使得 \(\sum_{i = 1}^n dist(x, i) \times a_i\) 最小,求其最小值。

\(n, m \leq 10^5\)

不难发现 \(x\) 就是带权重心。因为最优决策儿子如果存在那么只会存在一个,又因为树上点的度数不超过 \(20\) ,于是可以在点分树上暴力跳儿子,时间复杂度 \(O(n \log^2 n \times d)\)

#include <bits/stdc++.h>
typedef long long ll;
using namespace std;
const int N = 1e5 + 7, LOGN = 17;

struct Graph1 {
    vector<pair<int, int> > e[N];
    
    inline void insert(int u, int v, int w) {
        e[u].emplace_back(v, w);
    }
} G;

struct Graph2 {
    vector<int> e[N];
    
    inline void insert(int u, int v) {
        e[u].emplace_back(v);
    }
} nG;

ll dis[N], sum[N], s1[N], s2[N];
int fa[N][LOGN];
int dep[N], siz[N], mxsiz[N], nfa[N], ori[N];
bool vis[N]; 

int n, m, root;

void dfs1(int u, int f) {
    fa[u][0] = f, dep[u] = dep[f] + 1;

    for (int i = 1; i < LOGN; ++i)
        fa[u][i] = fa[fa[u][i - 1]][i - 1];

    for (auto it : G.e[u]) {
        int v = it.first, w = it.second;

        if (v != f)
            dis[v] = dis[u] + w, dfs1(v, u);
    }
}

inline int LCA(int x, int y) {
    if (dep[x] < dep[y])
        swap(x, y);

    for (int i = 0, h = dep[x] - dep[y]; h; ++i, h >>= 1)
        if (h & 1)
            x = fa[x][i];

    if (x == y)
        return x;

    for (int i = LOGN - 1; ~i; --i)
        if (fa[x][i] != fa[y][i])
            x = fa[x][i], y = fa[y][i];

    return fa[x][0];
}

inline ll dist(int x, int y) {
    int lca = LCA(x, y);
    return dis[x] + dis[y] - dis[lca] * 2;
}

int getsiz(int u, int f) {
    int siz = 1;

    for (auto it : G.e[u]) {
        int v = it.first;

        if (!vis[v] && v != f)
            siz += getsiz(v, u);
    }

    return siz;
}

void getroot(int u, int f, const int Siz) {
    siz[u] = 1, mxsiz[u] = 0;

    for (auto it : G.e[u]) {
        int v = it.first;

        if (!vis[v] && v != f)
            getroot(v, u, Siz), siz[u] += siz[v], mxsiz[u] = max(mxsiz[u], siz[v]);
    }

    mxsiz[u] = max(mxsiz[u], Siz - siz[u]);

    if (!root || mxsiz[u] < mxsiz[root])
        root = u;
}

void build(int u) {
    vis[u] = true;

    for (auto it : G.e[u]) {
        int v = it.first;

        if (vis[v])
            continue;

        root = 0, getroot(v, u, getsiz(v, u)), ori[root] = v;
        nG.insert(nfa[root] = u, root), build(root);
    }
}

inline void update(int x, int k) {
    for (int cur = x; cur; cur = nfa[cur]) {
        sum[cur] += k, s1[cur] += dist(cur, x) * k;

        if (nfa[cur])
            s2[cur] += dist(nfa[cur], x) * k;
    }
}

inline ll query(int x) {
    ll res = 0;

    for (int cur = x, pre = 0; cur; pre = cur, cur = nfa[cur]) {
        res += s1[cur];

        if (pre)
            res += dist(cur, x) * (sum[cur] - sum[pre]) - s2[pre];
    }

    return res;
}

ll dfs2(int u) {
    ll res = query(u);

    for (int v : nG.e[u])
        if (query(ori[v]) < res)
            return dfs2(v);

    return res;
}

signed main() {
    scanf("%d%d", &n, &m);

    for (int i = 1; i < n; ++i) {
        int u, v, w;
        scanf("%d%d%d", &u, &v, &w);
        G.insert(u, v, w), G.insert(v, u, w);
    }

    dfs1(1, 0), root = 0, getroot(1, 0, n);
    int rt = root;
    build(root);

    while (m--) {
        int x, k;
        scanf("%d%d", &x, &k);
        update(x, k);
        printf("%lld\n", dfs2(rt));
    }

    return 0;
}

P5311 [Ynoi2011] 成都七中

给定一颗树,树上每个节点有一种颜色。\(m\) 次查询,每次查询给出 \(l, r, x\) ,求保留树上编号在 \([l, r]\) 内的点,\(x\) 所在联通块中颜色种类数。

\(n, m \leq 10^5\)

先建出点分树,对于一次查询,在点分树上 \(x\) 的祖先中找到深度最小的点 \(pa\) 且满足 \(x\) 只经过编号 \([l,r]\) 内的点在原树上能到达 \(pa\)

记每个点 \(i\) 到点分树祖先的路径上所经过的节点编号最小/大值为 \(d_{min}(i,j)\)\(d_{max}(i,j)\) ,则求 \(pa\) 直接暴力跳即可。

可以发现 \(x\) 只经过编号 \([l,r]\) 内的点所在的连通块被完全包含在了 \(pa\) 在点分树上的子树中。把本次询问放到 \(pa\) 节点处,最后再统一离线处理。

枚举虚树上的点 \(rt\),处理该节点处的询问时,对于任意一个 \((l,r,x)\),满足 \(l\leqslant d_{min}(i,rt)\)\(d_{max}(i,rt)\leqslant r\)\(i\) 即为与 \(x\) 在同一连通块内的点,不难离线扫描线处理。

时空复杂度 \(O(n \log^2 n + m \log n)\)

#include <bits/stdc++.h>
using namespace std;
const int N = 1e5 + 7;

struct Graph {
    vector<int> e[N];
    
    inline void insert(int u, int v) {
        e[u].emplace_back(v);
    }
} G;

struct Node {
    int l, r, id;

    inline bool operator < (const Node &rhs) const {
        return l == rhs.l ? id > rhs.id : l > rhs.l;
    }
};

vector<Node> nfa[N], nd[N];

int col[N], siz[N], mxsiz[N], lst[N], ans[N];
bool vis[N];

int n, m, root;

void getroot(int u, int f, int Siz) {
    siz[u] = 1, mxsiz[u] = 0;

    for (int v : G.e[u])
        if (!vis[v] && v != f)
            getroot(v, u, Siz), siz[u] += siz[v], mxsiz[u] = max(mxsiz[u], siz[v]);

    mxsiz[u] = max(mxsiz[u], Siz - siz[u]);

    if (!root || mxsiz[u] < mxsiz[root])
        root = u;
}

void dfs(int u, int f, int mx, int mn, int rt) {
    siz[u] = 1;
    nfa[u].emplace_back((Node){mn, mx, rt});
    nd[rt].emplace_back((Node){mn, mx, col[u]});

    for (int v : G.e[u])
        if (!vis[v] && v != f)
            dfs(v, u, max(mx, v), min(mn, v), rt), siz[u] += siz[v];
}

void build(int u) {
    vis[u] = true, dfs(u, 0, u, u, u);

    for (int v : G.e[u])
        if (!vis[v])
            root = 0, getroot(v, u, siz[v]), build(root);
}

namespace BIT {
int c[N];

inline void update(int x, int k) {
    for (; x <= n; x += x & -x)
        c[x] += k;
}

inline int query(int x) {
    int res = 0;

    for (; x; x -= x & -x)
        res += c[x];

    return res;
}
} // namespace BIT

signed main() {
    scanf("%d%d", &n, &m);
    vector<int> vec;

    for (int i = 1; i <= n; ++i)
        scanf("%d", col + i), vec.emplace_back(col[i]);

    sort(vec.begin(), vec.end()), vec.erase(unique(vec.begin(), vec.end()), vec.end());

    for (int i = 1; i <= n; ++i)
        col[i] = lower_bound(vec.begin(), vec.end(), col[i]) - vec.begin() + 1;

    for (int i = 1; i < n; ++i) {
        int u, v;
        scanf("%d%d", &u, &v);
        G.insert(u, v), G.insert(v, u);
    }

    root = 0, getroot(1, 0, n), build(root);

    for (int i = 1; i <= m; ++i) {
        int l, r, x;
        scanf("%d%d%d", &l, &r, &x);

        for (Node it : nfa[x])
            if (l <= it.l && it.r <= r) {
                nd[it.id].emplace_back((Node){l, r, -i});
                break;
            }
    }

    for (int i = 1; i <= n; ++i) {
        sort(nd[i].begin(), nd[i].end());

        for (Node it : nd[i]) {
            if (it.id < 0)
                ans[-it.id] = BIT::query(it.r);
            else if (!lst[it.id] || it.r < lst[it.id]) {
                if (lst[it.id])
                    BIT::update(lst[it.id], -1);

                BIT::update(lst[it.id] = it.r, 1);
            }
        }
        
        for (Node it : nd[i])
            if (it.id > 0 && lst[it.id])
                BIT::update(lst[it.id], -1), lst[it.id] = 0;
    }

    for (int i = 1; i <= m; ++i)
        printf("%d\n", ans[i]);

    return 0;
}

P2664 树上游戏

给出一棵树,每个点有一个颜色,对于所有 \(i\) ,求 \(\sum_{j = 1}^n s(i, j)\) ,其中 \(s(i, j)\) 表示 \(i \to j\) 路径上的颜色数量。

\(n \leq 10^5\)

树上路径相关问题考虑点分治。设 \(cnt_i\) 表示当前统计 \(u\) 子树内所有到 \(u\) 的路径有多少条路径出现颜色 \(i\) 。这个是好处理的,只要到 \(u\) 的链上第一次出现颜色 \(i\) 时令 \(cnt_i \leftarrow cnt_i + siz_v\) 即可。

然后剩下的就是对于 \(v\) 子树外的贡献减去 \(v\) 子树内的贡献即可,容斥上有一些细节需要处理。

#include <bits/stdc++.h>
typedef long long ll;
using namespace std;
const int N = 1e5 + 7;

struct Graph {
    vector<int> e[N];
    
    inline void insert(int u, int v) {
        e[u].emplace_back(v);
    }
} G;

ll ans[N], Sum;
int col[N], siz[N], mxsiz[N], sta[N], chaincol[N], cnt[N];
bool vis[N];

int n, root, outsiz;

void getroot(int u, int f, const int Siz) {
    siz[u] = 1, mxsiz[u] = 0;

    for (int v : G.e[u])
        if (!vis[v] && v != f)
            getroot(v, u, Siz), siz[u] += siz[v], mxsiz[u] = max(mxsiz[u], siz[v]);

    mxsiz[u] = max(mxsiz[u], Siz - siz[u]);

    if (!root || mxsiz[u] < mxsiz[root])
        root = u;
}

int getsiz(int u, int f) {
    siz[u] = 1;

    for (int v : G.e[u])
        if (!vis[v] && v != f)
            siz[u] += getsiz(v, u);

    return siz[u];
}

void addcol(int u, int f) {
    if (!chaincol[col[u]])
        cnt[col[u]] += siz[u], Sum += siz[u];

    ++chaincol[col[u]];

    for (int v : G.e[u])
        if (!vis[v] && v != f)
            addcol(v, u);

    --chaincol[col[u]];
}

void delcol(int u, int f) {
    if (!chaincol[col[u]])
        cnt[col[u]] -= siz[u], Sum -= siz[u];
    
    ++chaincol[col[u]];

    for (int v : G.e[u])
        if (!vis[v] && v != f)
            delcol(v, u);

    --chaincol[col[u]];
}

void dfs(int u, int f, int num, ll sum) {
    if (!chaincol[col[u]])
        ++num, sum += cnt[col[u]];

    ++chaincol[col[u]], ans[u] += Sum - sum + num * outsiz;

    for (int v : G.e[u])
        if (!vis[v] && v != f)
            dfs(v, u, num, sum);

    --chaincol[col[u]];
}

void calc(int u) {
    Sum = 0, addcol(u, 0), ++chaincol[col[u]];

    for (int v : G.e[u])
        if (!vis[v]) {
            delcol(v, u), cnt[col[u]] -= siz[v], Sum -= siz[v];
            outsiz = siz[u] - siz[v], dfs(v, u, 0, 0);
            Sum += siz[v], cnt[col[u]] += siz[v], addcol(v, u);
        }

    ans[u] += Sum - cnt[col[u]] + siz[u];
    --chaincol[col[u]], delcol(u, 0);
}

void solve(int u) {
    getsiz(u, 0), vis[u] = true, calc(u);

    for (int v : G.e[u])
        if (!vis[v])
            root = 0, getroot(v, 0, siz[v]), solve(root);
}

signed main() {
    scanf("%d", &n);

    for (int i = 1; i <= n; ++i)
        scanf("%d", col + i);

    for (int i = 1; i < n; ++i) {
        int u, v;
        scanf("%d%d", &u, &v);
        G.insert(u, v), G.insert(v, u);
    }

    root = 0, getroot(1, 0, n), solve(root);

    for (int i = 1; i <= n; ++i)
        printf("%lld\n", ans[i]);

    return 0;
}

P5912 [POI2004] JAS / [AGC009D] Uninity

求深度最浅的点分树的深度。

\(n \leq 10^5\)

\(d_i\) 表示 \(i\) 在点分树上的高度,一组 \(d_i\) 合法当且仅当对于任意 \(d_u = d_v (u \neq v)\) 均满足 \(u \to v\) 路径上存在一点 \(w\) 满足 \(d_w > d_u\)

考虑从下到上贪心,给每个点确定不和下方冲突的尽可能小的标号,即可找到最大标号的最小值。

\(d_u\) 表示 \(u\) 的标号,\(f_u\) 表示 \(u\) 子树内的最大标号,\(g_u\) 表示 \(u\) 子树内目前不满足条件的 \(d\) (即对于 \(u\) 子树内的点 \(v\)\(fa_u \to v\) 路径上不存在 \(w\) 满足 \(d_w > d_v\)\(d_v\) 的集合)。

\(u\) 为叶子时,有 \(f_u = d_u = 0, g_u = \{ 1 \}\)

否则记 \(S_u\) 表示 \(u\) 子树内目前不满足条件 \(d\) 的集合,\(T_u\) 表示 \(u\) 存在不同子树具有相同 \(d\)\(d\) 的集合,即:

\[\begin{aligned} S_u &= \bigcup_{v \in son(u)} g_v \\ T_u &= \bigcup_{v, w \in son(u)} g_v \cap g_w \end{aligned} \]

考虑合法的 \(d_u\) 需要满足的条件,即 \(d_u > \max (T_u)\)\(d_u \notin S_u\) 。从而得到:

\[f_u = \max(\max_{v \in son(u)} f_v, d_u) \]

此时子树内 \(d_v < d_u\)\(d_v\) 均合法,因此:

\[g_u = \{ d_u \} \cup \{ x \mid x \in S_u, x \geq d_u \} \]

由于 \(d \leq \log n\) ,使用一些位运算技巧可以做到线性。

#include <bits/stdc++.h>
using namespace std;
const int N = 1e5 + 7;

struct Graph {
    vector<int> e[N];
    
    inline void insert(int u, int v) {
        e[u].emplace_back(v);
    }
} G;

int f[N], g[N], d[N];

int n;

void dfs(int u, int fa) {
    if (!(G.e[u].size() - (fa ? 1 : 0))) {
        f[u] = d[u] = 0, g[u] = 1;
        return;
    }

    int s = 0, t = 0;

    for (int v : G.e[u])
        if (v != fa)
            dfs(v, u), f[u] = max(f[u], f[v]), t |= s & g[v], s |= g[v];

    f[u] = max(f[u], d[u] = __builtin_ctz(~((1 << (__lg(t) + 1)) - 1) & ~s));
    g[u] = (s & ~((1 << d[u]) - 1)) | (1 << d[u]);
}

signed main() {
    scanf("%d", &n);

    for (int i = 1; i < n; ++i) {
        int u, v;
        scanf("%d%d", &u, &v);
        G.insert(u, v), G.insert(v, u);
    }

    dfs(1, 0);
    printf("%d", f[1]);
    return 0;
}

CF776F Sherlock's bet to Moriarty

给定一个正 \(n\) 边形,用 \(m\) 条对角线将其分开,保证对角线两两无交,记划分出的区域为 \(R_{1 \sim m + 1}\)

令某区域的端点为 \(a_{1 \sim |R_i|}\) ,定义一个区域的权值为 \(f(R_i) = \sum_{j = 1}^{|R_i|} 2^{a_j}\) 。将所有区域按重要度从小到大排序并编号,即 \(f(R_1) < f(R_2) < \cdots < f(R_{m + 1})\)

现在需要给每个区域染上 \(1 \sim 20\) 之间的数字作为颜色,满足:对于任意两个颜色相同的区域 \(R_i, R_j (i \neq j)\) ,均满足它们之间的任何一条简单路径上都均存在至少一个区域 \(R_k\) 满足其颜色小于 \(i, j\) 的颜色。

构造一组方案。

\(n \leq 10^5\)\(m \leq n - 3\)

考虑将整个多边形剖分的结构建树,问题转化为给整棵树染色,使得任意两个同色点之间存在颜色更小的点。

先考虑建树,设所有对角线为 \((a_i, b_i)\) ,其中钦定 \(a_i < b_i\) 。由于对角线两两不交,因此考虑类似括号序列的处理方法。

\(2m\) 个点排序,有相同点时钦定 \(b < a\) 。设当前点为 \(i\)

  • \(i\) 是某个 \(a_j\) ,则将其入栈。
  • 否则 \(i\) 是某个 \(b_j\) ,且当前栈顶即为其对应的 \(a_j\) ,那么所有在 \([a_j, b_j]\) 范围内的点都会被拿出来成为一个区域删点可以用双向链表或 set 维护。

找到所有区域后按字典序排序,后枚举每一条对角线,它是某两个区域的公共边,找到这两个区域并连边,即可得到最终的剖分树。

再考虑给整棵树染色,直接套用 P5912 [POI2004] JAS 的做法即可。

时间复杂度 \(O(n \log n)\)

#include <bits/stdc++.h>
using namespace std;
const int N = 1e5 + 7;

struct Graph {
    vector<int> e[N];
    
    inline void insert(int u, int v) {
        e[u].emplace_back(v);
    }
} G;

pair<int, int> p[N << 1];
vector<int> a[N];

int sta[N], f[N], g[N], d[N];

int n, m, tot;

void dfs(int u, int fa) {
    if (!(G.e[u].size() - (fa ? 1 : 0))) {
        f[u] = d[u] = 0, g[u] = 1;
        return;
    }

    int s = 0, t = 0;

    for (int v : G.e[u])
        if (v != fa)
            dfs(v, u), f[u] = max(f[u], f[v]), t |= s & g[v], s |= g[v];

    f[u] = max(f[u], d[u] = __builtin_ctz(~((1 << (__lg(t) + 1)) - 1) & ~s));
    g[u] = (s & ~((1 << d[u]) - 1)) | (1 << d[u]);
}

signed main() {
    scanf("%d%d", &n, &m);

    for (int i = 1; i <= m; ++i) {
        int x, y;
        scanf("%d%d", &x, &y);

        if (x > y)
            swap(x, y);

        p[i] = make_pair(x, 1), p[m + i] = make_pair(y, 0);
    }

    sort(p + 1, p + m * 2 + 1);
    set<int> st;

    for (int i = 1; i <= n; ++i)
        st.emplace(i);

    for (int i = 1, top = 0; i <= m * 2; ++i) {
        if (p[i].second) {
            sta[++top] = p[i].first;
            continue;
        }

        int l = sta[top--], r = p[i].first;
        a[++tot] = {l};

        for (;;) {
            auto it = st.upper_bound(l);

            if (*it == r)
                break;

            a[tot].emplace_back(*it), st.erase(it);
        }

        a[tot].emplace_back(r), reverse(a[tot].begin(), a[tot].end());
    }

    a[++m] = vector<int>(st.begin(), st.end()), reverse(a[m].begin(), a[m].end());
    sort(a + 1, a + m + 1);
    map<pair<int, int>, int> mp;

    for (int i = 1; i <= m; ++i)
        for (int j = 0; j < a[i].size(); ++j) {
            int x = a[i][j], y = (j ? a[i][j - 1] : a[i].back());

            if (x > y)
                swap(x, y);

            if (x == 1 && y == n && x + 1 == y)
                continue;

            auto it = mp.find(make_pair(x, y));

            if (it == mp.end())
                mp[make_pair(x, y)] = i;
            else
                G.insert(it->second, i), G.insert(i, it->second);
        }

    dfs(1, 0);

    for (int i = 1; i <= m; ++i)
        printf("%d ", 20 - d[i]);

    return 0;
}

QOJ7855. 不跳棋

给定一棵 \(n\) 个点的树,初始每个点上都有一个棋子。

\(n - 2\) 次操作,每次拿走树上的一个棋子,每次操作后求最小的两个不同棋子的距离以及最小距离的点对数量。

\(n \leq 5 \times 10^5\) ,强制在线

考虑点分治,对每个分治中心维护经过它的最短路径与达到最小路径的路径数量,然后在一个桶上统一计算答案。

考虑直接求出距离分治中心最近的两个点形成路径,并统计路径条数。如果两个点来自同一 棵子树,全局的答案一定会比这个长度小,所以不合法的路径不会影响答案。

对每一个分治中心开一个大小为连通块内最大深度的桶并维护两个指针即可统计答案,时间复杂度 \(O(n \log n)\)

#include <bits/stdc++.h>
typedef long long ll;
using namespace std;
const int N = 5e5 + 7;

struct Graph {
    vector<int> e[N];
    
    inline void insert(int u, int v) {
        e[u].emplace_back(v);
    }
} G;

vector<int> num[N], dis[N];

ll cnt[N];
int siz[N], mxsiz[N], nfa[N], ndep[N], fir[N], sec[N];
bool vis[N];

int n, tp, root, ans;

void getroot(int u, int f, int Siz) {
    siz[u] = 1, mxsiz[u] = 0;

    for (int v : G.e[u])
        if (!vis[v] && v != f)
            getroot(v, u, Siz), siz[u] += siz[v], mxsiz[u] = max(mxsiz[u], siz[v]);

    mxsiz[u] = max(mxsiz[u], Siz - siz[u]);

    if (!root || mxsiz[u] < mxsiz[root])
        root = u;
}

void dfs(int u, int f, int d, int rt) {
    siz[u] = 1;

    if (dis[u][d] == num[rt].size())
        num[rt].emplace_back(1);
    else
        ++num[rt][dis[u][d]];

    for (int v : G.e[u])
        if (!vis[v] && v != f)
            dis[v].emplace_back(dis[u][d] + 1), dfs(v, u, d, rt), siz[u] += siz[v];
}

inline void calc(int u, int op) {
    if (fir[u] == num[u].size() || sec[u] == num[u].size())
        return;

    cnt[fir[u] + sec[u]] += op * (fir[u] == sec[u] ? 1ll * num[u][fir[u]] * (num[u][fir[u]] - 1) / 2 :
        1ll * num[u][fir[u]] * num[u][sec[u]]);
}

inline void maintain(int u) {
    while (fir[u] < num[u].size() && !num[u][fir[u]])
        ++fir[u];

    while (sec[u] < num[u].size() && !(num[u][sec[u]] - (sec[u] == fir[u])))
        ++sec[u];

    calc(u, 1);
}

void build(int u) {
    vis[u] = true, dis[u].emplace_back(0), dfs(u, 0, ndep[u], u), maintain(u);

    for (int v : G.e[u])
        if (!vis[v])
            root = 0, getroot(v, u, siz[v]), ndep[root] = ndep[nfa[root] = u] + 1, build(root);
}

inline void update(int x) {
    for (int u = x, d = ndep[x]; u; u = nfa[u], --d)
        calc(u, -1), --num[u][dis[x][d]], maintain(u);

    while (ans <= n && !cnt[ans])
        ++ans;
}

signed main() {
    scanf("%d%d", &n, &tp);

    for (int i = 1; i < n; ++i) {
        int u, v;
        scanf("%d%d", &u, &v);
        G.insert(u, v), G.insert(v, u);
    }

    getroot(1, 0, n), build(root), ans = 1;
    ll lstans = 0;

    for (int i = 1; i <= n - 2; ++i) {
        ll x;
        scanf("%lld", &x);

        if (tp)
            x ^= lstans;

        update(x);
        printf("%d %lld\n", ans, lstans = cnt[ans]);
    }

    return 0;
}
posted @ 2025-01-04 14:58  wshcl  阅读(101)  评论(0)    收藏  举报