HDU 6967. I love data structure题解

HDU 6967. I love data structure

题意:

维护一个向量序列\(\{(a_n,b_n)\}\),支持以下操作:

  • 操作1. 1 tag l r x
    若 tag \(=0\),则给所有\(i\in[l,r]\)上的\(a_i\)增加\(x\)
    若 tag \(=1\),则给所有\(i\in[l,r]\)上的\(b_i\)增加\(x\)
  • 操作2. 2 l r
    将所有\(i\in[l,r]\)上的\(a_i\),变成\(3a_i+2b_i\)
    将所有\(i\in[l,r]\)上的\(b_i\),变成\(3a_i-2b_i\)
    (两个赋值同时进行,右侧的\(a_i,b_i\)均为修改前的值)
  • 操作3. 3 l r
    将所有\(i\in[l,r]\)上的\(a_i\),变成\(b_i\)
    将所有\(i\in[l,r]\)上的\(b_i\),变成\(a_i\)
    (两个赋值同时进行,即交换\(a_i\)与\(b_i\))
  • 操作4. 4 l r
    查询\(\sum\limits_{i=l}^ra_ib_i\)

分析:

显然是用线段树进行维护。

上面的修改操作可以归结为矩阵乘法和向量加法,所以需要矩阵乘法标记mul_tag(记作矩阵\(A\))和向量加法标记add_tag(记作向量\(B\))。

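考虑修改标记如何下传:
设节点上已有的标记为\((A,B)\)(表示变换\(x\mapsto Ax+B\)),新来的修改为\((A_1,B_1)\),整体效果需要将\(Ax+B\)变成\(A_1(Ax+B)+B_1=(A_1A)x+(A_1B+B_1)\),即\(A\leftarrow A_1A,\ B\leftarrow A_1B+B_1\)。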

对于询问,考虑要维护哪些量:

\( A=\begin{bmatrix}c_{11}&c_{12}\\c_{21}&c_{22}\end{bmatrix},B=\begin{bmatrix}d_1\\d_2\end{bmatrix},x=\begin{bmatrix}a\\b\end{bmatrix}\)

由于询问的是\(\sum ab\),肯定要维护\(\sum ab\)
考虑经过\(A\)、\(B\)的修改(即\(x\mapsto Ax+B\))后\(\sum ab\)会变成什么。显然会变成

\[\sum(c_{11}a+c_{12}b+d_1)(c_{21}a+c_{22}b+d_2) \]

尝试展开,除了只与区间长度有关的常数项\(\sum d_1d_2\)之外,可以预见有5项需要维护,分别为\(\sum a^2,\sum b^2,\sum ab, \sum a, \sum b\)

于是需要分别考虑经过\(A\)\(B\)的修改后\(\sum a^2,\sum b^2, \sum a, \sum b\)会变成什么。

发现这样考虑不仅计算量很大,编码也过于繁琐,为了简单,不妨将两个标记合并成一个标记

\[\begin{aligned} C=\begin{bmatrix} c_{11}&c_{12}&d_1\\ c_{21}&c_{22}&d_2\\ 0&0&1 \end{bmatrix} \end{aligned} \]

维护的向量序列变成一个三维向量

\[\begin{aligned} x=\begin{bmatrix} a\\ b\\ 1 \end{bmatrix} \end{aligned} \]

既然询问的是\(\sum ab\),干脆直接维护

\[\begin{aligned} xx^T=\begin{bmatrix} a^2&ab&a\\ ab&b^2&b\\ a&b&1 \end{bmatrix} \end{aligned} \]

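经过\(C\)的变换后,单个元素的\(xx^T\)会变成

\[\begin{aligned} (Cx)(Cx)^T=C(xx^T)C^T \end{aligned} \]

由于该变换对求和是线性的,区间和同样满足\(\sum (Cx)(Cx)^T=C\left(\sum xx^T\right)C^T\),所以线段树节点只需维护\(\sum xx^T\),打标记时左乘\(C\)、右乘\(C^T\)即可,答案就是该矩阵第1行第2列的\(\sum ab\)。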

这样在代码上写起来就舒服了(代价是常数会非常大,得花点心思卡常,尤其是矩阵乘法,或者可以吸臭氧,即开启 O3 优化)

代码:

一共两份代码,

第一份不合并标记,写起来有点蛋疼,且在hdu上实测开不开臭氧几乎没区别,都是1800ms。

第二份合并标记,在hdu上开了臭氧2800ms,不开臭氧4400ms(TLE的边缘)。

第一份:

#include <cstdio>
#include <cstring>
#include <vector>
#define now nodes[rt]
#define ls nodes[rt << 1]
#define rs nodes[rt << 1 | 1]
using namespace std;
// 64-bit integer used for intermediate products before reduction modulo `mod`.
using Lint = long long;
// Modulus applied to every stored value: 10^9 + 7.
constexpr int mod = 1000000007;
// Capacity of the input arrays (2 * 10^5 + 10); the tree allocates four times
// this many nodes.
constexpr int maxn = 200010;
// 2x2 integer matrix; a[i][j] is the entry in row i, column j.
// Deliberately a plain aggregate: the rest of the file initialises it
// positionally in row-major order, e.g. {c11, c12, c21, c22}.
struct Matrix {
    int a[2][2];
};
// 2-component integer column vector; a[0] pairs with the sequence's `a`
// values and a[1] with its `b` values.
// Deliberately a plain aggregate so it can be written as {d1, d2}.
struct Vector {
    int a[2];
};
// Component-wise sum of two vectors, reduced modulo `mod`.
// The addition is carried out in int, so both operands are expected to hold
// residues in [0, mod); larger values could overflow before the reduction.
Vector operator+(const Vector& a, const Vector& b) {
    Vector sum;
    for (int i = 0; i < 2; ++i) {
        sum.a[i] = (a.a[i] + b.a[i]) % mod;
    }
    return sum;
}
// Matrix product a * b of two 2x2 matrices, each entry reduced modulo `mod`.
// Products are widened to Lint before being summed so that two terms of up to
// (mod - 1)^2 fit without overflow.
Matrix operator*(const Matrix& a, const Matrix& b) {
    Matrix prod;
    for (int row = 0; row < 2; ++row) {
        for (int col = 0; col < 2; ++col) {
            const Lint acc = static_cast<Lint>(a.a[row][0]) * b.a[0][col] +
                             static_cast<Lint>(a.a[row][1]) * b.a[1][col];
            prod.a[row][col] = acc % mod;
        }
    }
    return prod;
}
// Matrix-vector product a * b (2x2 times 2x1), each entry reduced modulo `mod`.
// Products are widened to Lint before being summed to avoid int overflow.
Vector operator*(const Matrix& a, const Vector& b) {
    Vector out;
    for (int row = 0; row < 2; ++row) {
        const Lint acc = static_cast<Lint>(a.a[row][0]) * b.a[0] +
                         static_cast<Lint>(a.a[row][1]) * b.a[1];
        out.a[row] = acc % mod;
    }
    return out;
}
// Operation 2 as a matrix: (a, b) -> (3a + 2b, 3a - 2b).
// The coefficient -2 is stored as its residue mod - 2.
const Matrix op2 = {{{3, 2}, {3, mod - 2}}};
// Operation 3 as a matrix: (a, b) -> (b, a).
const Matrix op3 = {{{0, 1}, {1, 0}}};
// One segment-tree node covering the index range [l, r].
struct Node {
    // Pending affine map x -> mul_tag * x + add_tag that has already been
    // applied to this node's sums but not yet to its children.
    Matrix mul_tag;
    Vector add_tag;
    // Sums over the node's range, each kept modulo `mod`:
    // a_2 = sum of a_i^2, b_2 = sum of b_i^2, ab = sum of a_i * b_i,
    // a = sum of a_i, b = sum of b_i.
    int a_2, b_2, ab, a, b;
    // Inclusive bounds of the range this node represents (set in build).
    int l, r;
};
// Lazy segment tree over a sequence of pairs (a_i, b_i).
// Each node stores the five range sums needed to answer sum(a_i * b_i) under
// affine updates, plus a pending (matrix, vector) tag for its children.
// Relies on the file-level macros now / ls / rs, which expand to nodes[rt] and
// its two children.
struct SegTree {
    // Node pool; node rt has children 2*rt and 2*rt+1.
    Node nodes[maxn << 2];
    // Recomputes node rt's sums from its two children.
    void pushup(int rt) {
        now.a_2 = (ls.a_2 + rs.a_2) % mod;
        now.b_2 = (ls.b_2 + rs.b_2) % mod;
        now.ab = (ls.ab + rs.ab) % mod;
        now.a = (ls.a + rs.a) % mod;
        now.b = (ls.b + rs.b) % mod;
    }
    // Applies the linear map (a, b) -> (c11*a + c12*b, c21*a + c22*b) to every
    // element summarised by node s, updating both its tag and its sums.
    void update_node(const Matrix& op, Node& s) {
        // Compose with the pending tag: op * (A x + B) = (op*A) x + (op*B).
        s.mul_tag = op * s.mul_tag;
        s.add_tag = op * s.add_tag;
        // Snapshot the old sums; every new sum is a function of the old ones.
        int a_2 = s.a_2, b_2 = s.b_2, ab = s.ab, a = s.a, b = s.b;
        int c11 = op.a[0][0], c12 = op.a[0][1];
        int c21 = op.a[1][0], c22 = op.a[1][1];
        // sum (c11 a + c12 b)^2 = c11^2 sum a^2 + 2 c11 c12 sum ab + c12^2 sum b^2.
        // Each coefficient is reduced before multiplying by a sum so that the
        // three Lint terms can be added without overflow.
        s.a_2 = ((Lint)c11 * c11 % mod * a_2 + 2LL * c11 * c12 % mod * ab +
                 (Lint)c12 * c12 % mod * b_2) %
                mod;
        // Same expansion for the second component, using row 2 of op.
        s.b_2 = ((Lint)c21 * c21 % mod * a_2 + 2LL * c21 * c22 % mod * ab +
                 (Lint)c22 * c22 % mod * b_2) %
                mod;
        // sum (c11 a + c12 b)(c21 a + c22 b)
        //   = c11 c21 sum a^2 + (c11 c22 + c12 c21) sum ab + c12 c22 sum b^2.
        s.ab = ((Lint)c11 * c21 % mod * a_2 +
                ((Lint)c11 * c22 + (Lint)c12 * c21) % mod * ab +
                (Lint)c12 * c22 % mod * b_2) %
               mod;
        // The first-order sums transform linearly.
        s.a = ((Lint)c11 * a + (Lint)c12 * b) % mod;
        s.b = ((Lint)c21 * a + (Lint)c22 * b) % mod;
    }
    // Applies the translation (a, b) -> (a + d1, b + d2) to every element
    // summarised by node s, updating both its tag and its sums.
    void update_node(const Vector& op, Node& s) {
        s.add_tag = op + s.add_tag;
        // Snapshot the old sums; every new sum is a function of the old ones.
        int a_2 = s.a_2, b_2 = s.b_2, ab = s.ab, a = s.a, b = s.b;
        int d1 = op.a[0], d2 = op.a[1];
        // Number of elements in the node's range; multiplies the constant terms.
        int len = s.r - s.l + 1;
        // sum (a + d1)^2 = sum a^2 + 2 d1 sum a + len * d1^2.
        s.a_2 = (a_2 + 2LL * a * d1 + (Lint)d1 * d1 % mod * len) % mod;
        s.b_2 = (b_2 + 2LL * b * d2 + (Lint)d2 * d2 % mod * len) % mod;
        // sum (a + d1)(b + d2) = sum ab + d2 sum a + d1 sum b + len * d1 * d2.
        s.ab = (ab + (Lint)d2 * a + (Lint)d1 * b + (Lint)d1 * d2 % mod * len) %
               mod;
        s.a = (a + (Lint)d1 * len) % mod;
        s.b = (b + (Lint)d2 * len) % mod;
    }
    // Pushes node rt's pending tag to both children and resets it to the
    // identity map. The matrix part must be applied before the vector part,
    // because the tag represents x -> mul_tag * x + add_tag.
    void pushdown(int rt) {
        update_node(now.mul_tag, ls);
        update_node(now.mul_tag, rs);
        now.mul_tag = {1, 0, 0, 1};
        update_node(now.add_tag, ls);
        update_node(now.add_tag, rs);
        now.add_tag = {0, 0};
    }
    // Builds the subtree rooted at rt for the range [l, r] from the 1-based
    // arrays a and b. Leaf values are stored as read (not reduced), so the
    // inputs are expected to already lie in [0, mod).
    void build(int rt, int l, int r, int* a, int* b) {
        now.mul_tag = {1, 0, 0, 1};
        now.add_tag = {0, 0};
        now.l = l, now.r = r;
        if (l == r) {
            now.a = a[l], now.b = b[l];
            now.a_2 = (Lint)a[l] * a[l] % mod;
            now.b_2 = (Lint)b[l] * b[l] % mod;
            now.ab = (Lint)a[l] * b[l] % mod;
            return;
        }
        // Parsed as (l + r) >> 1: midpoint of the range.
        int mid = l + r >> 1;
        build(rt << 1, l, mid, a, b);
        build(rt << 1 | 1, mid + 1, r, a, b);
        pushup(rt);
    }
    // Applies op (a Matrix or a Vector, selecting the matching update_node
    // overload) to every index in [L, R] within the subtree rooted at rt.
    template <class T>
    void update(int rt, int L, int R, const T& op) {
        int l = now.l, r = now.r;
        // Node fully inside the query range: update it lazily and stop.
        if (L <= l && r <= R) {
            update_node(op, now);
            return;
        }
        pushdown(rt);
        int mid = l + r >> 1;
        if (L <= mid) update(rt << 1, L, R, op);
        if (R > mid) update(rt << 1 | 1, L, R, op);
        pushup(rt);
    }
    // Returns sum of a_i * b_i over [L, R] modulo `mod` within the subtree
    // rooted at rt.
    int query(int rt, int L, int R) {
        int l = now.l, r = now.r;
        if (L <= l && r <= R) return now.ab;
        // Children must see the pending tag before their sums are read.
        pushdown(rt);
        int res = 0;
        int mid = l + r >> 1;
        if (L <= mid) res = (res + query(rt << 1, L, R)) % mod;
        if (R > mid) res = (res + query(rt << 1 | 1, L, R)) % mod;
        return res;
    }
} seg;
int a[maxn], b[maxn];
int main() {
    int n;
    scanf("%d", &n);
    for (int i = 1; i <= n; i++) {
        scanf("%d%d", a + i, b + i);
    }
    seg.build(1, 1, n, a, b);
    int m;
    scanf("%d", &m);
    while (m--) {
        int op, tag, l, r;
        scanf("%d", &op);
        if (op == 1) {
            int x;
            scanf("%d%d%d%d", &tag, &l, &r, &x);
            seg.update(1, l, r, tag ? (Vector){0, x} : (Vector){x, 0});
        } else if (op == 2) {
            scanf("%d%d", &l, &r);
            seg.update(1, l, r, op2);
        } else if (op == 3) {
            scanf("%d%d", &l, &r);
            seg.update(1, l, r, op3);
        } else {
            scanf("%d%d", &l, &r);
            int res = seg.query(1, l, r);
            printf("%d\n", res);
        }
    }
    return 0;
}

第二份:

#include <cstdio>
#include <cstring>
#include <vector>
#define now nodes[rt]
#define ls nodes[rt << 1]
#define rs nodes[rt << 1 | 1]
using namespace std;
// 64-bit integer type used for all matrix entries and intermediate products.
using Lint = long long;
// Modulus applied to every stored value: 10^9 + 7 (kept as Lint so that
// expressions involving it are evaluated in 64 bits).
constexpr Lint mod = 1000000007;
// Capacity of the input arrays (2 * 10^5 + 10); the tree allocates four times
// this many nodes.
constexpr int maxn = 200010;
struct Matrix {
    Lint a[3][3];
    Matrix operator*(const Matrix& rhs) const {
        Matrix res = {0, 0, 0, 0, 0, 0, 0, 0, 0};
        for (int i = 0; i < 3; i++)
            for (int j = 0; j < 3; j++) {
                for (int k = 0; k < 3; k++)
                    res.a[i][j] += a[i][k] * rhs.a[k][j];
                res.a[i][j] %= mod;
            }
        return res;
    }
    Matrix operator+(const Matrix& rhs) const {
        Matrix res = {0, 0, 0, 0, 0, 0, 0, 0, 0};
        for (int i = 0; i < 3; i++)
            for (int j = 0; j < 3; j++)
                res.a[i][j] = (a[i][j] + rhs.a[i][j]) % mod;
        return res;
    }
    Matrix T() const {
        Matrix res;
        for (int i = 0; i < 3; i++)
            for (int j = 0; j < 3; j++) res.a[i][j] = a[j][i];
        return res;
    }
};
// Identity transform; also the "no pending update" value of a node's tag.
const Matrix E = {{{1, 0, 0}, {0, 1, 0}, {0, 0, 1}}};
// Operation 2 in homogeneous form: (a, b, 1) -> (3a + 2b, 3a - 2b, 1).
// The coefficient -2 is stored as its residue mod - 2.
const Matrix op2 = {{{3, 2, 0}, {3, mod - 2, 0}, {0, 0, 1}}};
// Operation 3 in homogeneous form: (a, b, 1) -> (b, a, 1).
const Matrix op3 = {{{0, 1, 0}, {1, 0, 0}, {0, 0, 1}}};
// One segment-tree node.
// tag: pending 3x3 transform already applied to this node's `num` but not yet
//      to its children (E when nothing is pending).
// num: sum over the node's range of x * x^T with x = (a, b, 1)^T, modulo
//      `mod`; entry [0][1] is therefore the sum of a_i * b_i.
struct Node {
    Matrix tag, num;
};
struct SegTree {
    Node nodes[maxn << 2];
    void pushup(int rt) { now.num = ls.num + rs.num; }
    void update_node(int rt, const Matrix& op) {
        now.tag = op * now.tag;
        now.num = op * now.num * op.T();
    }
    void pushdown(int rt) {
        update_node(rt << 1, now.tag);
        update_node(rt << 1 | 1, now.tag);
        now.tag = E;
    }
    void build(int rt, int l, int r, Lint* a, Lint* b) {
        now.tag = E;
        if (l == r) {
            now.num.a[0][0] = a[l] * a[l] % mod;
            now.num.a[1][1] = b[l] * b[l] % mod;
            now.num.a[2][2] = 1;
            now.num.a[0][1] = now.num.a[1][0] = a[l] * b[l] % mod;
            now.num.a[0][2] = now.num.a[2][0] = a[l];
            now.num.a[1][2] = now.num.a[2][1] = b[l];
            return;
        }
        int mid = l + r >> 1;
        build(rt << 1, l, mid, a, b);
        build(rt << 1 | 1, mid + 1, r, a, b);
        pushup(rt);
    }
    void update(int rt, int l, int r, int L, int R, const Matrix& op) {
        if (L <= l && r <= R) {
            update_node(rt, op);
            return;
        }
        pushdown(rt);
        int mid = l + r >> 1;
        if (L <= mid) update(rt << 1, l, mid, L, R, op);
        if (R > mid) update(rt << 1 | 1, mid + 1, r, L, R, op);
        pushup(rt);
    }
    Lint query(int rt, int l, int r, int L, int R) {
        if (L <= l && r <= R) return now.num.a[0][1];
        pushdown(rt);
        Lint res = 0;
        int mid = l + r >> 1;
        if (L <= mid) res = (res + query(rt << 1, l, mid, L, R)) % mod;
        if (R > mid) res = (res + query(rt << 1 | 1, mid + 1, r, L, R)) % mod;
        return res;
    }
} seg;
Lint a[maxn], b[maxn];
int main() {
    int n;
    scanf("%d", &n);
    for (int i = 1; i <= n; i++) {
        scanf("%lld%lld", a + i, b + i);
    }
    seg.build(1, 1, n, a, b);
    int m;
    scanf("%d", &m);
    while (m--) {
        int op, tag, l, r;
        scanf("%d", &op);
        if (op == 1) {
            int x;
            scanf("%d%d%d%d", &tag, &l, &r, &x);
            Matrix op = tag ? (Matrix){1, 0, 0, 0, 1, x, 0, 0, 1}
                            : (Matrix){1, 0, x, 0, 1, 0, 0, 0, 1};
            seg.update(1, 1, n, l, r, op);
        } else if (op == 2) {
            scanf("%d%d", &l, &r);
            seg.update(1, 1, n, l, r, op2);
        } else if (op == 3) {
            scanf("%d%d", &l, &r);
            seg.update(1, 1, n, l, r, op3);
        } else {
            scanf("%d%d", &l, &r);
            Lint res = seg.query(1, 1, n, l, r);
            printf("%lld\n", res);
        }
    }
    return 0;
}
posted @ 2021-07-26 19:01  聆竹听风  阅读(244)  评论(0)    收藏  举报