ABC273 - Ex 题解

ABC273 Ex - Inv(0,1)ving Insert(1,0)n 题解

abc273_h

引入

首先题目描述的东西就是 \(\operatorname{Stern-Brocot\ Tree}\),不会也没有关系,你只需要知道以下几个性质(作者自己写的,建议无论大佬您会不会,都看下,因为后面可能会用到有关概念):

  1. 树上任意一个二元组,均有两元素互质
  2. 如果把树上每个二元组理解成一个既约分数,分子为第一个元素且分母为第二个元素(特别的,\(\frac{1}{0}\) 为正无穷),以下若无特殊说明,均按照此方式理解
  3. 按照中序遍历遍历整棵树,先访问到的元素一定小于后访问到的元素

(如果您想要证明,还是去看上面的链接吧)。

约定

  1. 把每个节点上的分数和一个线段对应,即 \((\frac{a}{b},\frac{c}{d})\) 与分数 \(\frac{a+c}{b+d}\) 对应(意义是该节点子树内的所有节点都在这个值域里)。线段表示和分数表示均指这个节点。
  2. 根据不合法元素(\(\gcd\) 不为 \(1\))的位置,可以把序列分成不交的若干段,每一段都可以单独处理,因此,以下讨论的序列全部元素合法。

正文

首先我们考虑,对于一个序列 \(p_{1},p_2,\dots,p_n\)\(p_i\) 是一个既约分数)。原问题求的问题记为 \(p\) 的所有子区间中节点在 \(\operatorname{Stern-Brocot\ Tree}\) 上生成的虚树点权和(其中一个点的点权是他的在原树上深度减去他虚树上的父亲在原树上的深度)。

注意到 \(\frac{0}{1},\frac{1}{0}\) 并不会对虚树点权造成贡献,这部分细节读者可以自行思考。

枚举每个区间,暴力建出虚树是不可接受的。但是我们会发现一个事情,对于任意子区间生成的虚树,它上面的所有节点一定都在整个区间生成的虚树上出现过

这启发我们直接对整个区间生成虚树,并统计每个点的点权贡献的次数

为了方便,我们将该序列排序,并记录每个元素在原序列中的位置。(\(pos_i\) 表示排序后排名为 \(i\) 的位置在原序列中的下标)。以后说的序列都指的是新序列。那么每个 \(\operatorname{Stern-Brocot\ Tree}\) 上的节点对应的一定是新序列的一段区间(可以为空)。

建立虚树并计算点权

假设我们正在某个节点 \((\frac{a}{b},\frac{c}{d})\) 上,当前序列的子区间是 \([l,r](l\le r)\),我们一个分界点 \(lm\),以及一个分界点 \(rm\),表示 \([l,lm]\) 中的元素都小于 \(\frac{a+c}{b+d}\),且 \([rm,r]\) 中元素都大于 \(\frac{a+c}{b+d}\)(显然 \((lm,rm)\) 中的元素,如果他们存在的话,都是 \(\frac{a+c}{b+d}\))。

我们对于 \([l,lm],[rm,r]\) 分别考虑,因为两者基本相同,所以我们这里以 \([l,lm]\) 为例。

我们计算 \([l,lm]\) 能够一直往左跳的深度,记为 \(k\),显然 \(k\) 仅与 \(lm\) 位置上的元素有关,所以可以根据不等式:

\[p_{lm}< \frac{a+kc}{b+kd} \]

\(O(1)\) 计算出最大的 \(k\)

解释:上面的不等式右边代表的是目标区间的右端点。

复杂度分析:首先给出结论:节点个数的复杂度上限是 \(\Theta(\sum_{i=1}^{n}\log\max\{a_i,b_i\})\)。我们正向不容易分析复杂度,于是可以反向考虑。我们一次上跳从 \((\frac{a}{b},\frac{c}{d}),(p_i=\frac{a+c}{b+d})\) 这个点往上跳,假设当前点是其父亲的右儿子,会让 \(a\leftarrow a-kc,b\leftarrow b-kd\),相当于 \((a+c)\leftarrow a\bmod c+c\)\(b+d\) 同理),类似求 \(\gcd\) 的复杂度分析,最多会跳 \(\Theta(\log\max\{a_i,b_i\})\) 次。

于是,点权就是从父亲到他跳过的点数

求出每个点的贡献次数

\(S(n)=\sum_{i=1}^{n}i=\frac{n\times(n+1)}{2}\)

我们假设能更新到当前节点的元素 \(pos\) 排序后是 \(x_1,x_2,\dots,x_m\),如果我们额外令 \(x_0=0,x_{m+1}=n+1\),那么方案数就是 \(S(n)-\sum_{i=1}^{m+1}S(x_{i}-x_{i-1}-1)\),意思即使全部区间减去不包含该序列中任何一个点的区间。

考虑如果我们用 std::set 来维护这些下标集合,那么我们可以做到快速加、删元素(单次 \(O(\log n)\))并维护我们需要的“贡献次数”。

每次我们会将这个集合中间的(\([lm+1,rm-1]\))删去,并把其他的分成两份,一份是 \([l,lm]\),另一份是 \([rm,r]\),我们可以通过启发式分裂(自己 yy 的名字)做到 \(O(n\log^2 n)\)

自此,本题全部解决。

贴上代码:

#include <bits/stdc++.h>
using namespace std;
template <typename F> inline void read(F &n) {
	F w = 1; n = 0;
	char ch = getchar();
	while (!isdigit(ch) && ch != EOF) { if(ch == '-') w = -1; ch = getchar(); } 
	while (isdigit(ch) && ch != EOF) { n = (n << 1) + (n << 3) + (ch & 15); ch = getchar(); }
	n *= w;
}
template<typename T, typename ...L> inline void read(T &n,L &...l) { read(n); read(l...); }
typedef long long ll;

const int N = 1e5 + 5, mo = 998244353;
int n, id[N], tot = 0;
struct fraction {
    ll x, y;
    fraction() { x = 0, y = 1; }
    fraction(ll x, ll y) : x(x), y(y) {}
    friend bool operator<(const fraction &x, const fraction &y) {
        return (__int128)x.x * y.y < (__int128)y.x * x.y;
    }
    friend bool operator==(const fraction &x, const fraction &y) {
        return x.x == y.x && x.y == y.y;
    }
} p[N];
int gcd(int a, int b) { return b == 0 ? a : gcd(b, a % b); }
ll ans = 0, all;
int L, R;
ll S(ll x) { return x * (x + 1) / 2 % mo; }
void pls(ll &x, ll y) { x = x + y >= mo ? x + y - mo : x + y; }
void mns(ll &x, ll y) { x = x < y ? x - y + mo : x - y; }

int sL, sR; //for SET
struct SET {
    ll cnt;
    set<int> s;
    SET() { cnt = 0; s.clear(); }
    void insert(const int &x) {
        auto it = s.insert(x).first;
        int l = sL, r = sR;
        if (it != s.begin()) l = *prev(it);
        ++it; if (it != s.end()) r = *it;
        pls(cnt, S(r - l - 1));
        mns(cnt, S(x - l - 1));
        mns(cnt, S(r - x - 1));
    }
    void erase(const int &x) {
        auto it = s.find(x);
        int l = sL, r = sR;
        if (it != s.begin()) l = *prev(it);
        ++it; if (it != s.end()) r = *it;
        pls(cnt, S(x - l - 1));
        pls(cnt, S(r - x - 1));
        mns(cnt, S(r - l - 1));
        s.erase(x);
    }
};

int findL(int A, int B, const fraction &x) {
    int mid, ans = A - 1;
    while (A <= B) {
        mid = (A + B) >> 1;
        if (p[id[mid]] < x) ans = mid, A = mid + 1;
        else B = mid - 1;
    }
    return ans;
}
int findR(int A, int B, const fraction &x) {
    int mid, ans = B + 1;
    while (A <= B) {
        mid = (A + B) >> 1;
        if (x < p[id[mid]]) ans = mid, B = mid - 1;
        else A = mid + 1;
    }
    return ans;
}
void calc(int l, int r, SET &s, fraction x, fraction y, ll d) {
    if (l > r) return;
    pls(ans, s.cnt * (d % mo) % mo);
    fraction mf(x.x + y.x, x.y + y.y);
    int lm = findL(l, r, mf), rm = findR(l, r, mf);
    for (int i = lm + 1; i < rm; ++i) s.erase(id[i]);
    SET t;
    if (lm - l < r - rm) {
        for (int i = l; i <= lm; ++i) t.insert(id[i]), s.erase(id[i]);
        swap(s, t);
    } else {
        for (int i = rm; i <= r; ++i) t.insert(id[i]), s.erase(id[i]);
    }
    ll down;
    if (l <= lm) {
        down = (y.x * p[id[lm]].y - y.y * p[id[lm]].x - 1) / (x.y * p[id[lm]].x - x.x * p[id[lm]].y);
        assert(down > 0);
        calc(l, lm, s, x, fraction(x.x * down + y.x, x.y * down + y.y), down);
    }
    if (rm <= r) {
        down = (p[id[rm]].x * x.y - p[id[rm]].y * x.x - 1) / (p[id[rm]].y * y.x - p[id[rm]].x * y.y);
        assert(down > 0);
        calc(rm, r, t, fraction(x.x + y.x * down, x.y + y.y * down), y, down);
    }
}

void solve(int L, int R) {
    ::L = L, ::R = R;
    all = S(R - L + 1);
    tot = 0;
    sL = L - 1, sR = R + 1;
    SET st;
    for (int i = L; i <= R; ++i) {
        if (min(p[i].x, p[i].y) < 1) continue;
        id[++tot] = i;
        st.insert(i);
    }
    sort(id + 1, id + 1 + tot, [&](int i, int j) {
        return p[i] < p[j];
    });
    calc(1, tot, st, fraction(0, 1), fraction(1, 0), 1);
}

int main() {
    read(n);
    for (int i = 1; i <= n; ++i) {
        read(p[i].x, p[i].y);
    }
    int las = 0;
    for (int i = 1; i <= n; ++i) {
        if (max(p[i].x, p[i].y) < 1 || gcd(p[i].x, p[i].y) != 1) {
            solve(las + 1, i - 1);
            las = i;
        }
    }
    if (las < n) solve(las + 1, n);
    printf("%lld\n", ans);
    return 0;
}
posted @ 2022-10-18 20:14  lazytag  阅读(76)  评论(6)    收藏  举报