根号分治及例题

原文 暴力美学——浅谈根号分治

其实以下内容就是自己总结一下上文中的例题。

概述

根号分治是一种十分暴力且巧妙的策略,需要我们观察大数据和小数据的特点,通常他们都有各自的解法,但是每个解法又不能适应所有数据,于是我们就想到把数据按照某个界限分类。

(干说还是说不清,直接看下面 4 到题吧)

CF 1207F

题目大意

对于一个序列 \(A\) 给定如下两种操作:

  • 1 x y 表示 \(A_x\) 自增 \(y\)
  • 2 x y 计算输出所有模 \(x\)\(y\) 的下标上的数的和。

保证所有操作合法且 \(x < 5 \times 10^5\)

解题思路

我们可以很容易的想到这道题的暴力做法:既对于每个询问,我们从位置 \(y\) 开始每次眺 \(x\) 个数求和。对于这个做法,我们每次询问都要跳 \(\lfloor {5 \times 10^5 \over x} \rfloor\) 次,我们发现随着 \(x\) 的增大,这个次数会越来越小。我们暂且将这个暴力视为对于大数据的解法。

那对于小数据我们该如何去解决呢?由于数据小,我们干脆对于每次修改都维护所有小数据的答案,等询问到小数据的时候就直接输出:

{
    a[x] += y;
    // L 是大数据与小数据之间的界限,我们等下来确定。
    for (int i = 1; i < L; i++) {
        ans[i][x % i] += y;
    }
}

我们现在已经分别确定了两种数据的解法,这时候就应该来确定划分两种数据的合适界限了。界限的确定通常由时间复杂度来确定。若我们将界限设为 \(L\),则对于所有修改有一个 \(O(L)\) 的时间复杂度,对于大数据每次询问还有一个 \(O({5 \times 10^5 \over L})\) 的时间复杂度,对于 \(q\) 次操作,时间复杂度就是 \(O(q(L + {5 \times 10 ^ 5 \over L}))\),由基本不等式可以确定当 \(L = \sqrt{5 \times 10^5}\) 时时间复杂度是最优的,而且此时给于小数据开的额外的空间也合适。

至此两种都可以称得上是暴力的解法(一种直接模拟,一种直接预处理)拼成了这道 2100 分的题:

constexpr int Q = 5e5, X = 5e5, Y = 1e3, L = 710; // sqrt(Q)
 
int a[X + 5];
int ans[L][L];
 
void solve()
{
    int q = 0;
    std::cin >> q;
    while (q--) {
        int op = 0, x = 0, y = 0;
        std::cin >> op >> x >> y;
        if (op == 1) {
            a[x] += y;
            for (int i = 1; i < L; i++) {
                ans[i][x % i] += y;
            }
        }
        else {
            int sum = 0;
            if (x < L) {
                sum = ans[x][y];
            }
            else {
                for (int i = y; i <= X; i += x) {
                    sum += a[i];
                }
            }
            std::cout << sum << '\n';
        }
    }
    return;
}

CF 710 D

题目大意

给出两个递增的等差序列(以首项和公差的形式给出),问在 \([l, r]\) 区间中,有多少整数同时是两个等差序列中的数。

解题思路

首先,公差越大时,包含在 \([l, r]\) 中的数越少,具体的,当公差为 \(K\) 时,这个等差序列包含在区间 \([l, r]\) 中的数大约有 \((r - l) \over K\) 个。所以当公差大于某个阈值 \(L\) 时,我们就暴力枚举公差较大的那个序列在给定区间中的值,然后检查这个值是否也在另一个等差序列中。

{
    // k1 k2 是公差,且 k1 是较大的那个
    for (i64 i = (l - a1 + k1 - 1) / k1 * k1 + a1; i <= r; i += k1) {
        if ((i - a2) % k2 == 0) {
            sum++;
        }
    }
}

这个暴力的时间复杂度是 \(O(r - l) \over L\)

现在只需要考虑公差小于阈值 \(L\) 时应给如何做。此时我们考虑当有一对整数 \(<i_1, i_2>\) 使得 \(a_1 + i_1k_1 = a_2 + i_2k_2\)\(k_2 \leq k_1\)) 时,这对整数有什么性质。从等式入手,我们发现可以在等式两边同时加上公差的最小公倍数 \(lcm(k_1, k_2) = {k_1k_2 \over gcd(k_1, k_2)}\),这样一来,在保持等式成立的同时,我们获得了一个新的整数对 \(<i_1 + {k2 \over gcd(k_1, k_2)}, i_2 + {k1 \over gcd(k_1, k_2)}>\),而且 \({k2 \over gcd(k_1, k_2)} \leq k1\)。这就意味着,只要我们在第一个等差序列(公差较大)中连续取 \(k2 \over gcd(k_1, k_2)\) 个数并判断是否是两个等差序列公共的,然后我们就可以根据已经知道的满足条件的数,判断所给区间有多少个满足条件的数了。

{
    i64 lcm = k1 * k2 / std::__gcd(k1, k2);
    for (i64 i = a1; i < std::min(r + 1, a1 + lcm); i += k1) {
        if ((i - a2) % k2) {
            continue;
        }

        i64 L = std::max(0ll, (l - i + lcm - 1) / lcm) * lcm + i;
        sum += (r - L) / lcm + 1;
    }
}

这种做法的复杂度我们取 \(O(L)\),因为遍历的数的个数 \(k2 \over gcd(k_1, k_2)\) 是不大于 \(k_1\) 的,而 \(k_1\) 是不大于阈值 \(L\) 的。

综合两种做法,时间复杂度就是 \(O(\max(L, {r - l \over L}))\),当 \(L = \sqrt{r - l}\) 时最优是 \(O(\sqrt{r - l})\)

这样我们通过根号分治以比较优秀的时间复杂度通过了这道 2500 的数论:

i64 solve()
{
    i64 k1 = 0, a1 = 0, k2 = 0, a2 = 0, l = 0, r = 0;
    std::cin >> k1 >> a1 >> k2 >> a2 >> l >> r;
    l = std::max({ l, a1, a2 });
    if (l > r) {
        return 0;
    }

    if (k1 < k2) {
        std::swap(k1, k2);
        std::swap(a1, a2);
    }
    
    i64 sum = 0;
    if (k1 * k1 >= (r - l + 1)) {
        for (i64 i = (l - a1 + k1 - 1) / k1 * k1 + a1; i <= r; i += k1) {
            if ((i - a2) % k2 == 0) {
                sum++;
            }
        }
    }
    else {
        i64 lcm = k1 * k2 / std::__gcd(k1, k2);
        for (i64 i = a1; i < std::min(r + 1, a1 + lcm); i += k1) {
            if ((i - a2) % k2) {
                continue;
            }

            i64 L = std::max(0ll, (l - i + lcm - 1) / lcm) * lcm + i;
            sum += (r - L) / lcm + 1;
        }
    }

    return sum;
}

ARC 052 D

题目大意

给定两个正整数 \(K\)\(M\)\(1 \leq K, M \leq 10^{10}\)),问在区间 \([1, M]\) 中有多少数 \(N\) 满足 \(N\equiv f(N)(modK)\),其中 \(f(N)\) 的值是数 \(N\) 各个位上的数之和,比如 \(f(142) = 1 + 4 + 2 = 7\)

\(N\) 的值域很大,但是 \(f(N)\) 呢?可以发现 \(f(N) \in [1, 90]\),于是满足条件的 \(N\) 一定有 \(N \% K < 90\),这个是对于任何数据都成立的。

有了上述性质,对于大于阈值 \(L\) 的暴力求法就呼之欲出了,既遍历模数在 \([0, 90)\) 的所有数然后判断是否满足条件就好了。

{
    // parse(n) 就是 f(n)
    for (int st = 1; st <= 90; st++) {
        for (i64 n = st; n <= m; n += k) {
            ans += (n - parse(n)) % k == 0;
        }
    }
}

\(len\)\(M\) 的十进制位数,则时间复杂度是 \(O(len{M \over L})\)

对于小于阈值的解法,就涉及到了数位DP,这里不做讨论只给出代码:

{
    i64 p9[] = { 0, 9, 99, 999, 9999, 99999, 999999, 9999999, 99999999, 999999999, 9999999999 };
    std::vector<int> dig;
    i64 len = m;
    while (len) {
        dig.push_back(len % 10);
        len /= 10;
    }
    len = dig.size();
    for (int i = 0; i < len; i++) {
        p9[i] %= k;
    }
    
    // f[i][j][0] 的含义是考虑到第 i 位(个位是第 0 位),(n - parse(n)) % k 的值是 j 且不考虑是否超出范围的 n 的数量
    // f[i][j][1] 的含义是考虑到第 i 位(个位是第 0 位),(n - parse(n)) % k 的值是 j 且考虑是否超出范围的 n 的数量
    f[0][0][0] = 10;
    f[0][0][1] = dig[0] + 1;
    for (int i = 1; i < len; i++) {
        for (int j = 0; j < k; j++) {
            if (f[i - 1][j][0] == 0) {
                continue;
            }
            for (int d = 0; d < 10; d++) {
                f[i][(j + d * p9[i]) % k][0] += f[i - 1][j][0];
            }
            for (int d = 0; d < dig[i]; d++) {
                f[i][(j + d * p9[i]) % k][1] += f[i - 1][j][0];
            }
            f[i][(j + dig[i] * p9[i]) % k][1] += f[i - 1][j][1];
        }
    }
    // 减掉 n = 0 对答案的贡献
    ans = f[len - 1][0][1] - 1;
}

时间复杂度是 \(lenL\)

综合两种解法,时间复杂度就是 \(O(len\max(L, {M \over L}))\),取 \(L = \sqrt{M}\) 是时间复杂度最优为 \(O(len\sqrt{M})\)

i64 f[11][L][2];

int parse(i64 x) {
    int res = 0;
    while (x) {
        res += x % 10;
        x /= 10;
    }
    return res;
}

void solve()
{
    i64 k = 0ll, m = 0ll;
    std::cin >> k >> m;
    i64 ans = 0;
    if (k > L  || k * k >= m) {
        for (int st = 1; st <= 90; st++) {
            for (i64 n = st; n <= m; n += k) {
                ans += (n - parse(n)) % k == 0;
            }
        }
    }
    else {
        i64 p9[] = { 0, 9, 99, 999, 9999, 99999, 999999, 9999999, 99999999, 999999999, 9999999999 };
        std::vector<int> dig;
        i64 len = m;
        while (len) {
            dig.push_back(len % 10);
            len /= 10;
        }
        len = dig.size();
        for (int i = 0; i < len; i++) {
            p9[i] %= k;
        }
        
        f[0][0][0] = 10;
        f[0][0][1] = dig[0] + 1;
        for (int i = 1; i < len; i++) {
            for (int j = 0; j < k; j++) {
                if (f[i - 1][j][0] == 0) {
                    continue;
                }
                for (int d = 0; d < 10; d++) {
                    f[i][(j + d * p9[i]) % k][0] += f[i - 1][j][0];
                }
                for (int d = 0; d < dig[i]; d++) {
                    f[i][(j + d * p9[i]) % k][1] += f[i - 1][j][0];
                }
                f[i][(j + dig[i] * p9[i]) % k][1] += f[i - 1][j][1];
            }
        }
        ans = f[len - 1][0][1] - 1;
    }
    std::cout << ans << '\n';
}

以上三道(都可以认为是数论的题吧)的根号分治解法有相似之处,都是观察到某种暴力解法所要的次数与数据规模呈反比,于是考虑设定阈值让大于这个阈值的用这个暴力解法,对于小于这个阈值的,我们另找解法,可能也是暴力,也可能需要用数据结构或者其他算法解决。

而且第三道题少见的数据范围 \(10^{10}\) 次方似乎也在暗示用根号分治。

Luogu P5901 [IOI 2009] Regions

题目大意

在一棵有 \(N\) 个节点的树上染 \(R\) 种颜色,给出 \(Q\) 组询问 \(r_1, r_2\),问有多少对有序点对 \(<u, v>\) 使得 \(u\) 的颜色是 \(r_1\)\(v\) 的颜色是 \(r_2\)\(u\)\(v\) 的祖先。询问强制在线。

\(N,Q \leq 2 \times 10^5, R \leq 2.5 \times 10^4\)

解题思路

用根号分治解决图上的问题一般是基于这个性质:一张有 \(M\) 条边的无向图中度数大于 \(K\) 的点的个数最多为 \(2M \over K\)(分母是这张图中所有点的度数和)。也就是说大于某个度数的点的个数与度数的大小呈反比。

我们规定节点数大于阈值 \(L\) 的颜色为主色,其他的颜色就是次色。于是根据上面的性质,主色最多有 \(2N \over L\) 个。我们可以试着把涉及到主色的询问提前预处理出来,具体分为以下两种情况:

  1. 询问 \(r_1, r_2\)\(r_1\) 是主色
  2. 询问 \(r_1, r_2\)\(r_2\) 是主色,且 \(r_1\) 是次色(\(r_1\) 是主色的情况包含在第一种情况,要在第二种情况里算一遍也是可以的,只不过没必要)

在第一种情况中颜色是主色的节点是当作祖先的,于是我们只用在 dfs 的过程中记录从根节点到当前节点的最短路径中各种主色节点各有多少个,然后再加到预处理的答案中去就好了。

第二种情况中颜色是主色的节点是当作子节点的,于是对于一个点,我们关心以这个点为根的子树里各种主色节点各有多少个,这个东西我们可以用差值来维护。具体的,我们对所有主色维护一个 cnt 表示 dfs 到当前节点的过程中各种主色节点一共经过了多少个(注意这与第一种情况中我们所要记录的从根节点到当前节点的最短路径中各种主色节点的数量有所不同)。于是对于任意一个节点,在遍历到这个节点时我们先减一遍 cnt (此时不包含子树的信息),等遍要从这个节点回溯回去时,我们有加一遍 cnt(此时包含子树的信息)。这样一减一加剩下的就是子树的信息。

然后对于所询问的两种颜色都是次色的情况,直接在 dfs 序中遍历这两种颜色的点就好了,遍历的次数是 \(2L\),时间复杂度就是 \(O(\max(N{N \over L},QL))\),取 \(L = \sqrt{N}\) 就够了。

具体实现直接看代码:

std::vector<int> g[N + 5];
int col[N + 5], num[R + 5], h[R + 5], cnt;

unsigned uBig[L][R + 5], vBig[R + 5][L];

int above[L]; // 统计祖先中的大点
int below[L]; // 利用差值统计子树中的大点
int siz[N + 5], dfn[N + 5], tim;
int fst[R + 5], jump[N + 5]; // 每种颜色第一次出现的位置 下一个同色的节点的位置
void dfs(int cur) {
    dfn[++tim] = cur;
    siz[cur] = 1;
    if (h[col[cur]]) {
        above[h[col[cur]]]++;
        below[h[col[cur]]]++;
    }
    else {
        for (int i = 1; i <= cnt; i++) {
            vBig[col[cur]][i] -= below[i];
        }
    }
    for (int i = 1; i <= cnt; i++) {
        uBig[i][col[cur]] += above[i];
    }

    for (int i = 0; i < g[cur].size(); i++) {
        dfs(g[cur][i]);
        siz[cur] += siz[g[cur][i]];
    }

    if (h[col[cur]]) {
        above[h[col[cur]]]--;
    }
    else {
        for (int i = 1; i <= cnt; i++) {
            vBig[col[cur]][i] += below[i];
        }
    }

    return;
}

void solve()
{
    int n = 0, r = 0, q = 0;
    std::cin >> n >> r >> q;
    std::cin >> col[1];
    for (int i = 2; i <= n; i++) {
        int fa = 0;
        std::cin >> fa >> col[i];
        g[fa].push_back(i);
        num[col[i]]++;
    }
    for (int i = 1; i <= r; i++) {
        if (num[i] * num[i] >= n) {
            h[i] = ++cnt;
        }
    }
    
    dfs(1);
    for (int i = 1; i <= r; i++) {
        fst[i] = Inf;
    }
    for (int i = n; i; i--) {
        jump[i] = fst[col[dfn[i]]];
        fst[col[dfn[i]]] = i;
    }

    while (q--) {
        int u = 0, v = 0;
        std::cin >> u >> v;
        unsigned ans = 0;
        if (h[u]) {
            ans = uBig[h[u]][v];
        }
        else if (h[v]) {
            ans = vBig[u][h[v]];
        }
        else {
            int add = 0;
            std::stack<int> R;
            int pu = fst[u];
            int pv = fst[v];
            while (pv != Inf) {
                while (pv > pu) {
                    // dfn[pv] 在 dfn[pu] 的子树内的话
                    if (pv < pu + siz[dfn[pu]]) {
                        R.push(pu + siz[dfn[pu]]);
                        add++; // 满足条件的祖先加一
                    }
                    pu = jump[pu];
                }
                ans += add;
                pv = jump[pv];
                // 跳出子树满足条件的子树就减
                while (not R.empty() && R.top() <= pv) {
                    R.pop();
                    add--;
                }
            }
        }
        std::cout << ans << std::endl;
    }
    return;
}

posted @ 2025-04-07 17:42  Young_Cloud  阅读(65)  评论(0)    收藏  举报