[ZJOI2018] 胖 题解

前言

题目链接:洛谷

题意简述

给定一个包含 \(n\) 个节点的链状无向图,编号为 \(1\sim n\),节点 \(i\)\(i+1\) 之间的边权为 \(w_i\)。此外,有 \(m\) 次互相独立的询问,每个询问添加若干条从 \(0\) 号点到某些点的双向边,每条边具有给定的长度。

对于每次询问求,使用 Bellman-Ford 算法计算过程中,每轮的 \(t\) 之和,其中 \(t\) 是本轮更新的节点数量。

\(n, m \leq 2\times10^5\)

题目分析

解决一维问题的一个重要的 observation 是找到连续一段区间。这里我们不难发现,每条额外新增的边能够更新到的点是连续的一段区间。那么我们情不自禁地转换计数视角,要求每轮更新结点数量之和,就是每个点被更新次数之和,就是每条新增边更新的结点个数之和。问题转化为求出每一条新增边更新到点的区间,进一步,我们只需要快速求出区间的左右端点。左右本质相同,以下仅考虑求出区间左端点。

不难想到二分出这个左端点,考虑 check 一个 \(L\) 是否合法。从新增边连向的点 \(p\) 更新到 \(L\) 需要 \(d=p-L\) 轮。在这 \(d\) 轮中,\([L-d,L+d)\) 中连着新增边的点也会更新到 \(L\),想要 \(p\) 顺利更新到 \(L\),就需要保证 \([L-d,L+d)\) 更新后,\(p\) 依旧可以进行松弛。对 \(w\) 做前缀和为 \(l\),令 \(e_i\) 表示连向点 \(i\) 的新增边长度(没有则为 \(+\infty\)),那么上述条件可以表示为 \(\forall i\in[L-d,L],l_L-l_i+e_i\gt l_p-l_L+e_p\)\(\forall i\in[L,L+d),l_i-l_L+e_i\gt l_p-l_L+e_p\),发现只需要维护 \(e_i-l_i\)\(e_i+l_i\) 的区间最小值即可快速判断。使用 ST 表即可。需要注意到的是,如果两个点在同一时刻尝试松弛某一个点,且尝试松弛的距离相等,那么只有左边的可以更新到它。这个小小特判即可。

ST 表预处理 \(\mathcal{O}(n\log n)\),所以先要离散化,令 \(c\) 为新增边的条数,做到 \(\mathcal{O}(c\log c)\) 预处理和 \(\mathcal{O}(c\log n \log c)\) 的时间复杂度。

代码

#define NDEBUG

#include <cstdio>
#include <iostream>
#include <cassert>
#include <algorithm>
using namespace std;

const int N = 200010;

int n, m;

using lint = long long;

const int lgN = __lg(N) + 1;

struct ST {
    lint st[lgN][N];
    inline void set(int p, lint v) {
        st[0][p] = v;
    }
    void init(int n) {
        for (int k = 1; k < lgN; ++k)
            for (int i = 1; i + (1 << k) - 1 <= n; ++i)
                st[k][i] = min(st[k - 1][i], st[k - 1][i + (1 << (k - 1))]);
    }
    lint query(int l, int r) {
        assert(l <= r);
        int p = __lg(r - l + 1);
        return min(st[p][l], st[p][r - (1 << p) + 1]);
    }
} st[2];

lint pre[N];
int a[N], l[N];
int tmp[N], rev[N];

int main() {
    scanf("%d%d", &n, &m);
    
    for (int i = 1, w; i < n; ++i) {
        scanf("%d", &w);
        pre[i + 1] = pre[i] + w;
    }
    for (int k; m--; ) {
        scanf("%d", &k);
        for (int j = 1; j <= k; ++j) {
            scanf("%d%d", &a[j], &l[j]);
            tmp[j] = a[j], rev[a[j]] = l[j];
        }
        sort(tmp + 1, tmp + k + 1);
        for (int i = 1; i <= k; ++i) {
            a[i] = tmp[i];
            l[i] = rev[tmp[i]];
            st[0].set(i, l[i] - pre[a[i]]);
            st[1].set(i, l[i] + pre[a[i]]);
        }
        st[0].init(k), st[1].init(k);
        
        auto queryL = [&] (int i) -> int {
            int x = a[i];
            
            auto check = [=] (int mid) -> bool {
                lint dis = l[i] + pre[x] - pre[mid];
                int dt = x - mid;
                int tl = lower_bound(tmp + 1, tmp + k + 1, max(1, mid - dt)) - tmp;
                int tr = upper_bound(tmp + 1, tmp + k + 1, mid) - tmp - 1;
                if (tl <= tr && st[0].query(tl, tr) + pre[mid] <= dis)
                    return false;
                tl = tr + (!tr || tmp[tr] != mid), tr = i - 1;
                assert(tl == lower_bound(tmp + 1, tmp + k + 1, mid) - tmp);
                assert(tr == upper_bound(tmp + 1, tmp + k + 1, x - 1) - tmp - 1);
                if (tl <= tr && st[1].query(tl, tr) - pre[mid] <= dis)
                    return false;
                return true;
            };
            
            int L = 1, R = x - 1, ans = x, mid;
            while (L <= R) {
                mid = (L + R) >> 1;
                if (check(mid))
                    ans = mid, R = mid - 1;
                else
                    L = mid + 1;
            }
            return ans;
        };
        
        auto queryR = [&] (int i) -> int {
            int x = a[i];
            
            auto check = [=] (int mid) -> bool {
                lint dis = l[i] + pre[mid] - pre[x];
                int dt = mid - x;
                int tl = lower_bound(tmp + 1, tmp + k + 1, mid) - tmp;
                int tr = upper_bound(tmp + 1, tmp + k + 1, min(mid + dt - 1, n)) - tmp - 1;
                int ttt = tr;
                // don't include mid + dt
                if (tl <= tr && st[1].query(tl, tr) - pre[mid] <= dis)
                    return false;
                tr = tl - (tl > k || tmp[tl] != mid), tl = i + 1;
                assert(tl == lower_bound(tmp + 1, tmp + k + 1, x + 1) - tmp);
                assert(tr == upper_bound(tmp + 1, tmp + k + 1, mid) - tmp - 1);
                if (tl <= tr && st[0].query(tl, tr) + pre[mid] <= dis)
                    return false;
                if (mid + dt <= n) {
                    if (ttt + 1 <= k && tmp[ttt + 1] <= mid + dt) ++ttt;
                    assert(ttt == upper_bound(tmp + 1, tmp + k + 1, min(mid + dt, n)) - tmp - 1);
                    if (tmp[ttt] == mid + dt) {
                        if (st[1].st[0][ttt] - pre[mid] < dis)
                            return false;
                    }
                }
                return true;
            };
            
            int L = x + 1, R = n, ans = x, mid;
            while (L <= R) {
                mid = (L + R) >> 1;
                if (check(mid))
                    ans = mid, L = mid + 1;
                else
                    R = mid - 1;
            }
            return ans;
        };
        
        lint ans = 0;
        for (int i = 1; i <= k; ++i) {
            int L = queryL(i), R = queryR(i);
            ans += R - L + 1;
        }
        printf("%lld\n", ans);
    }
    return 0;
}
posted @ 2025-03-06 12:45  XuYueming  阅读(29)  评论(0)    收藏  举报