求函数

求函数

题目描述

牛可乐有 $n$ 个一次函数,第 $i$ 个函数为 $f_i(x) = k_i \times x + b_i$。

牛可乐有 $m$ 次操作,每次操作为以下二者其一:

  • $1$ $i$ $k$ $b$ 将 $f_i(x)$ 修改为 $f_i(x) = k \times x + b$。
  • $2$ $l$ $r$ 求 $f_r\left(f_{r-1}\left(\cdots \left(f_{l+1}\left(f_l(1)\right)\right) \cdots \right)\right)$。

牛可乐当然(bu)会做啦,他想考考你——

答案对 $10^9 + 7$ 取模。

输入描述:

第一行,两个正整数 $n, m$。

第二行,$n$ 个整数 $k_1, k_2, \dots, k_n$ 。

第三行,$n$ 个整数 $b_1, b_2, \dots, b_n$。

接下来 $m$ 行,每行一个操作,格式见上。

保证 $1 \leq n, m \leq 2 \times 10^5$,$0 \leq k_i, b_i < 10^9 + 7$。

输出描述:

对于每个求值操作,输出一行一个整数,表示答案。

示例1

输入

2 3
1 1
1 0
1 2 114514 1919810
2 1 2
2 1 1

输出

2148838
2

说明

初始 $f_1(x) = x + 1$,$f_2(x) = x$

修改后 $f_2(x) = 114514x + 1919810$

查询时 $f_1(1) = 2$,$f_2(f_1(1)) = 2148838$

 

解题思路

  由于每个函数都是线性形式,即 $f_i(x) = k_i \times x + b_i$,我们可以展开复合函数 $f_r\left(f_{r-1}\left(\cdots \left(f_{l+1}\left(f_l(1)\right)\right) \cdots \right)\right)$,得到如下结果:

\begin{aligned}
&k_r k_{r-1} \cdots k_l \cdot 1 \\
&+ k_r k_{r-1} \cdots k_{l+1} \cdot b_l \\
&+ k_r k_{r-1} \cdots k_{l+2} \cdot b_{l+1} \\
&+ \cdots \\
&+ k_r k_{r-1} \cdots k_{l+i+1} \cdot b_{l+i} \\
&+ \cdots \\
&+ b_r \\
=& \prod_{i=l}^{r} k_i + \sum_{i=l}^{r} \left( b_i \cdot \prod_{j=i+1}^{r} k_j \right)
\end{aligned}

  由于涉及到区间查询,因此我们考虑使用线段树来维护上述结果。注意到整个结果可以分为两部分:左项是 $\prod\limits_{i=l}^{r} k_i$,右项是和 $\sum\limits_{i=l}^{r} \left( b_i \cdot \prod\limits_{j=i+1}^{r} k_j \right)$ 构成。对于左项,显然我们只需要维护区间 $[l,r]$ 内的 $k_i$ 的乘积,即 $p_{l \sim r} = \prod\limits_{i=l}^{r} k_i$。而对于右项,我们可以先考虑直接维护 $s_{l \sim r} = \sum\limits_{i=l}^{r} \left( b_i \cdot \prod\limits_{j=i+1}^{r} k_j \right)$。关键的问题是,当两个区间 $[l, m]$ 与 $[m+1, r]$ 要合并成一个区间 $[l,r]$ 时,怎么得到该区间的 $s_{l \sim r}$?

  将右项继续拆开,可以得到:

\begin{aligned}
&\sum_{i=l}^{r} \left( b_i \cdot \prod_{j=i+1}^{r} k_j \right) \\
= &\sum_{i=l}^{m} \left( b_i \cdot \prod_{j=i+1}^{r} k_j \right) + \sum_{i=m+1}^{r} \left( b_i \cdot \prod_{j=i+1}^{r} k_j \right) \\
= &\sum_{i=l}^{m} \left( b_i \cdot \prod_{j=i+1}^{m} k_j \cdot \prod_{j=m+1}^{r} k_j \right) + \sum_{i=m+1}^{r} \left( b_i \cdot \prod_{j=i+1}^{r} k_j \right) \\
= &\prod_{j=m+1}^{r} k_j \cdot \sum_{i=l}^{m} \left( b_i \cdot \prod_{j=i+1}^{m} k_j \right) + \sum_{i=m+1}^{r} \left( b_i \cdot \prod_{j=i+1}^{r} k_j \right)
\end{aligned}

  将上述表达式代入 $s_{l \sim m}$、$s_{m+1 \sim r}$ 和 $p_{m+1 \sim r}$,得到 $s_{l \sim r} = p_{m+1 \sim r} \cdot s_{l \sim m} + s_{m+1 \sim r}$。因此,我们可以利用线段树来维护并计算区间 $[l, r]$ 内的结果。具体来说,维护每个区间的乘积 $p_{l \sim r}$ 与 $s_{l \sim r}$,并通过合并操作来得到最终结果。

  AC 代码如下,时间复杂度为 $O(m \log{n})$:

#include <bits/stdc++.h>
using namespace std;

typedef long long LL;

const int N = 2e5 + 5, mod = 1e9 + 7;

int a[N], b[N];
struct Node {
    int l, r, p, s;
}tr[N * 4];

Node pushup(Node l, Node r) {
    int p = 1ll * l.p * r.p % mod;
    int s = (1ll * l.s * r.p % mod + r.s) % mod;
    return {l.l, r.r, p, s};
}

void build(int u, int l, int r) {
    tr[u] = {l, r};
    if (l == r) {
        tr[u].p = a[l];
        tr[u].s = b[l];
    }
    else {
        int mid = l + r >> 1;
        build(u << 1, l, mid);
        build(u << 1 | 1, mid + 1, r);
        tr[u] = pushup(tr[u << 1], tr[u << 1 | 1]);
    }
}

void modify(int u, int x, int k, int b) {
    if (tr[u].l == tr[u].r) {
        tr[u].p = k;
        tr[u].s = b;
    }
    else {
        int mid = tr[u].l + tr[u].r >> 1;
        if (x <= mid) modify(u << 1, x, k, b);
        else modify(u << 1 | 1, x, k, b);
        tr[u] = pushup(tr[u << 1], tr[u << 1 | 1]);
    }
}

Node query(int u, int l, int r) {
    if (tr[u].l >= l && tr[u].r <= r) return tr[u];
    int mid = tr[u].l + tr[u].r >> 1;
    if (r <= mid) return query(u << 1, l, r);
    if (l >= mid + 1) return query(u << 1 | 1, l, r);
    return pushup(query(u << 1, l, r), query(u << 1 | 1, l, r));
}

int main() {
    ios::sync_with_stdio(false);
    cin.tie(nullptr);
    int n, m;
    cin >> n >> m;
    for (int i = 1; i <= n; i++) {
        cin >> a[i];
    }
    for (int i = 1; i <= n; i++) {
        cin >> b[i];
    }
    build(1, 1, n);
    while (m--) {
        int op;
        cin >> op;
        if (op == 1) {
            int x, k, b;
            cin >> x >> k >> b;
            modify(1, x, k, b);
        }
        else {
            int l, r;
            cin >> l >> r;
            Node t = query(1, l, r);
            cout << (t.p + t.s) % mod << '\n';
        }
    }
    
    return 0;
}

 

参考资料

  【题解】2020牛客寒假算法基础集训营第二场:https://ac.nowcoder.com/discuss/364961

posted @ 2025-10-23 21:29  onlyblues  阅读(8)  评论(0)    收藏  举报
Web Analytics