P7334 [JRKSJ R1] 吊打 题解

题目分析

其实这道题只要联想到线段树就很简单了,注意到开根号和平方互为逆运算,所以我们不妨记录一个数据,让其表示最后需要平方的次数(平方次数-开根号次数),这很好维护,对于区间平方操作就是区间内都 \(+1\),开方就是 \(-1\),但是也有特殊情况。

在没有平方时开方就需要特殊讨论,不计入上述“开根号次数”,因为“下取整”这一操作影响到了后面的操作。

所以我们在实现线段树的时候再维护一个信息,记录那些“没有平方时的开方”。

最后把整颗树遍历一遍,得到每个叶子节点的信息后,从叶子节点向上传,记录总和即可。

开根号很好搞,模拟就行,最多开个 \(5\) 次就完事了。

需要注意的是,对于平方操作不能进行模拟,不然会有 \(O(nm)\) 的复杂度。

注意到,由费马小定理 \(a^{p-1}\equiv 1 \pmod p\) 可得:

\[\begin{aligned} a^k &\equiv a^{p-1-(p-1)+k} \pmod p & \\ &\equiv a^{p-1}\cdot a^{k-(p-1)} \pmod p & \\ &\equiv a^{k-(p-1)} \pmod p & \\ \end{aligned} \]

这样的操作进行 \(n\) 次就有 \(a^k \equiv a^{k-n(p-1)} \equiv a^{k \mod (p-1)} \pmod p\)

所以我们只要对指数取模即可,假设题中对一个数进行了 \(c\) 次平方操作,那么最后计算 \(a^{2^c \mod (p-1)}\) 即可。

每次修改均摊复杂度 \(O(n\log n)\),最后查询复杂度为 \(O(n)\)\(O(n(\log n + \log p + \log m))\) 之间(前者是全部操作都是开根号,后者是全部操作都是平方)。

最后跑了 \(3.09\) 秒,加点小优化可以跑 \(1.91\) 秒,时间很充裕。

代码实现

#include <algorithm>
#include <cstdio>
#include <cmath>
#include <iostream>
using namespace std;
#define ll long long
#define N 200010
ll szmax(ll num1, ll num2) { return num1 < num2 ? num2 : num1; }

int a[N];

ll qpow(ll base, ll exp, ll mod = 998244353)
{
    if (base == 1)
        return 1;
    ll res = 1;
    while (exp)
    {
        if (exp & 1)
            res = res * base % mod;
        base = base * base % mod;
        exp >>= 1;
    }
    return res;
}
ll cal(ll num, ll ti)
{
    if (num == 1)
        return 1;
    while (num != 1 && ti--)
        num = sqrt(num);
    return num;
}

class SegTre
{
private:
    struct node
    {
        int laz, l, r, cnt_sq;
        ll val;
    } tre[N << 2];

#define Lz(x) tre[x].laz
#define N(x) tre[x].cnt_sq
#define V(x) tre[x].val
#define ls(x) (x << 1)
#define rs(x) (x << 1 | 1)
#define L(x) tre[x].l
#define R(x) tre[x].r
#define mid ((l + r) >> 1)

    void push_down(int x)
    {
        if (N(x))
        {
            if (Lz(ls(x)))
                N(ls(x)) += szmax(0, N(x) - Lz(ls(x))), Lz(ls(x)) = szmax(0, Lz(ls(x)) - N(x));
            else
                N(ls(x)) += N(x);
            if (Lz(rs(x)))
                N(rs(x)) += szmax(0, N(x) - Lz(rs(x))), Lz(rs(x)) = szmax(0, Lz(rs(x)) - N(x));
            else
                N(rs(x)) += N(x);
            N(x) = 0;
        }
        if (Lz(x))
            Lz(ls(x)) += Lz(x), Lz(rs(x)) += Lz(x), Lz(x) = 0;
    }

public:
    void build(int x, int l, int r)
    {
        L(x) = l, R(x) = r;
        if (l == r)
        {
            V(x) = a[l];
            return;
        }
        build(ls(x), l, mid);
        build(rs(x), mid + 1, r);
    }
    void change_sqrt(int x, int l, int r)
    {
        if (l <= L(x) && R(x) <= r)
        {
            if (!Lz(x))
                ++N(x);
            else
                --Lz(x);
            return;
        }
        push_down(x);
        int mids = (L(x) + R(x)) >> 1;
        if (l <= mids)
            change_sqrt(ls(x), l, r);
        if (r > mids)
            change_sqrt(rs(x), l, r);
    }
    void change_pow(int x, int l, int r)
    {
        if (l <= L(x) && R(x) <= r)
        {
            ++Lz(x);
            return;
        }
        push_down(x);
        int mids = (L(x) + R(x)) >> 1;
        if (l <= mids)
            change_pow(ls(x), l, r);
        if (r > mids)
            change_pow(rs(x), l, r);
    }
    ll ask(int x)
    {
        if (L(x) == R(x))
        {
            if (N(x))
                V(x) = cal(V(x), N(x));
            if (Lz(x) > 0)
                return qpow(V(x), qpow(2, Lz(x), 998244352));
            return V(x);
        }
        push_down(x);
        return (ask(ls(x)) + ask(rs(x))) % 998244353;
    }
} TRE;
int main()
{
    ios::sync_with_stdio(0), cin.tie(0), cout.tie(0);
    int n, m;
    cin >> n >> m;
    for (int i = 1; i <= n; ++i)
        cin >> a[i];
    TRE.build(1, 1, n);
    while (m--)
    {
        int op, l, r;
        cin >> op >> l >> r;
        if (op == 1)
            TRE.change_sqrt(1, l, r);
        else
            TRE.change_pow(1, l, r);
    }
    cout << TRE.ask(1) << '\n';
    return 0;
}
posted @ 2025-10-24 20:32  azaa414  阅读(1)  评论(0)    收藏  举报