HDU 6967. I love data structure题解
HDU 6967. I love data structure
题意:
维护一个向量序列\(\{(a_n,b_n)\}\),支持以下操作:
- 操作1.
1 tag l r x
若tag
为\(0\),则给所有\(i\in[l,r]\)上的\(a_i\)增加\(x\)
若tag
为\(1\),则给所有\(i\in[l,r]\)上的\(b_i\)增加\(x\) - 操作2.
2 l r
将所有\(i\in[l,r]\)上的\(a_i\),变成\(3a_i+2b_i\)
将所有\(i\in[l,r]\)上的\(b_i\),变成\(3a_i-2b_i\) - 操作3.
3 l r
将所有\(i\in[l,r]\)上的\(a_i\),变成\(b_i\)
将所有\(i\in[l,r]\)上的\(b_i\),变成\(a_i\) - 操作4.
4 l r
查询\(\sum\limits_{i=l}^ra_ib_i\)
分析:
显然是用线段树进行维护。
上面的修改操作可以归结为矩阵乘法和向量加法,所以需要矩阵乘法标记mul_tag
(记作矩阵\(A\))和向量加法标记add_tag
(记作向量\(B\))。
考虑修改标记如何下传:
整体效果需要将\(Ax+B\)变成\(A_1(Ax+B)+B_1\)
对于询问,考虑要维护哪些量:
设\( A=\begin{bmatrix}c_{11}&c_{12}\\c_{21}&c_{22}\end{bmatrix},B=\begin{bmatrix}d_1\\d_2\end{bmatrix},x=\begin{bmatrix}a\\b\end{bmatrix}\)
由于询问的是\(\sum ab\),肯定要维护\(\sum ab\)
考虑经过\(A\)和\(B\)的修改后\(\sum ab\)会变成什么。显然会变成
尝试展开,可以预见有5项需要维护,分别为\(\sum a^2,\sum b^2,\sum ab, \sum a, \sum b\)
于是需要分别考虑经过\(A\)和\(B\)的修改后\(\sum a^2,\sum b^2, \sum a, \sum b\)会变成什么。
发现这样考虑不仅计算量很大,编码也过于繁琐,为了简单,不妨将两个标记合并成一个标记
维护的向量序列变成一个三维向量
既然询问的是\(\sum ab\),干脆直接维护
经过\(C\)的变换后会变成什么
这样在代码上写起来就舒服了(代价是常数会非常大,得花点心思卡常,尤其是矩阵乘法,或者可以吸臭氧)
代码:
一共两份代码,
第一份不合并标记,写起来有点蛋疼,且在hdu上实测开不开臭氧几乎没区别,都是1800ms。
第二份合并标记,在hdu上开了臭氧2800ms,不开臭氧4400ms(TLE的边缘)。
第一份:
#include <cstdio>
#include <cstring>
#include <vector>
#define now nodes[rt]
#define ls nodes[rt << 1]
#define rs nodes[rt << 1 | 1]
using namespace std;
typedef long long Lint;
const int mod = 1e9 + 7;
const int maxn = 2e5 + 10;
struct Matrix {
int a[2][2];
};
struct Vector {
int a[2];
};
Vector operator+(const Vector& a, const Vector& b) {
Vector res;
res.a[0] = (a.a[0] + b.a[0]) % mod;
res.a[1] = (a.a[1] + b.a[1]) % mod;
return res;
}
Matrix operator*(const Matrix& a, const Matrix& b) {
Matrix res;
res.a[0][0] =
((Lint)a.a[0][0] * b.a[0][0] + (Lint)a.a[0][1] * b.a[1][0]) % mod;
res.a[0][1] =
((Lint)a.a[0][0] * b.a[0][1] + (Lint)a.a[0][1] * b.a[1][1]) % mod;
res.a[1][0] =
((Lint)a.a[1][0] * b.a[0][0] + (Lint)a.a[1][1] * b.a[1][0]) % mod;
res.a[1][1] =
((Lint)a.a[1][0] * b.a[0][1] + (Lint)a.a[1][1] * b.a[1][1]) % mod;
return res;
}
Vector operator*(const Matrix& a, const Vector& b) {
Vector res;
res.a[0] = ((Lint)a.a[0][0] * b.a[0] + (Lint)a.a[0][1] * b.a[1]) % mod;
res.a[1] = ((Lint)a.a[1][0] * b.a[0] + (Lint)a.a[1][1] * b.a[1]) % mod;
return res;
}
const Matrix op2 = {3, 2, 3, mod - 2};
const Matrix op3 = {0, 1, 1, 0};
struct Node {
Matrix mul_tag;
Vector add_tag;
int a_2, b_2, ab, a, b;
int l, r;
};
struct SegTree {
Node nodes[maxn << 2];
void pushup(int rt) {
now.a_2 = (ls.a_2 + rs.a_2) % mod;
now.b_2 = (ls.b_2 + rs.b_2) % mod;
now.ab = (ls.ab + rs.ab) % mod;
now.a = (ls.a + rs.a) % mod;
now.b = (ls.b + rs.b) % mod;
}
void update_node(const Matrix& op, Node& s) {
s.mul_tag = op * s.mul_tag;
s.add_tag = op * s.add_tag;
int a_2 = s.a_2, b_2 = s.b_2, ab = s.ab, a = s.a, b = s.b;
int c11 = op.a[0][0], c12 = op.a[0][1];
int c21 = op.a[1][0], c22 = op.a[1][1];
s.a_2 = ((Lint)c11 * c11 % mod * a_2 + 2LL * c11 * c12 % mod * ab +
(Lint)c12 * c12 % mod * b_2) %
mod;
s.b_2 = ((Lint)c21 * c21 % mod * a_2 + 2LL * c21 * c22 % mod * ab +
(Lint)c22 * c22 % mod * b_2) %
mod;
s.ab = ((Lint)c11 * c21 % mod * a_2 +
((Lint)c11 * c22 + (Lint)c12 * c21) % mod * ab +
(Lint)c12 * c22 % mod * b_2) %
mod;
s.a = ((Lint)c11 * a + (Lint)c12 * b) % mod;
s.b = ((Lint)c21 * a + (Lint)c22 * b) % mod;
}
void update_node(const Vector& op, Node& s) {
s.add_tag = op + s.add_tag;
int a_2 = s.a_2, b_2 = s.b_2, ab = s.ab, a = s.a, b = s.b;
int d1 = op.a[0], d2 = op.a[1];
int len = s.r - s.l + 1;
s.a_2 = (a_2 + 2LL * a * d1 + (Lint)d1 * d1 % mod * len) % mod;
s.b_2 = (b_2 + 2LL * b * d2 + (Lint)d2 * d2 % mod * len) % mod;
s.ab = (ab + (Lint)d2 * a + (Lint)d1 * b + (Lint)d1 * d2 % mod * len) %
mod;
s.a = (a + (Lint)d1 * len) % mod;
s.b = (b + (Lint)d2 * len) % mod;
}
void pushdown(int rt) {
update_node(now.mul_tag, ls);
update_node(now.mul_tag, rs);
now.mul_tag = {1, 0, 0, 1};
update_node(now.add_tag, ls);
update_node(now.add_tag, rs);
now.add_tag = {0, 0};
}
void build(int rt, int l, int r, int* a, int* b) {
now.mul_tag = {1, 0, 0, 1};
now.add_tag = {0, 0};
now.l = l, now.r = r;
if (l == r) {
now.a = a[l], now.b = b[l];
now.a_2 = (Lint)a[l] * a[l] % mod;
now.b_2 = (Lint)b[l] * b[l] % mod;
now.ab = (Lint)a[l] * b[l] % mod;
return;
}
int mid = l + r >> 1;
build(rt << 1, l, mid, a, b);
build(rt << 1 | 1, mid + 1, r, a, b);
pushup(rt);
}
template <class T>
void update(int rt, int L, int R, const T& op) {
int l = now.l, r = now.r;
if (L <= l && r <= R) {
update_node(op, now);
return;
}
pushdown(rt);
int mid = l + r >> 1;
if (L <= mid) update(rt << 1, L, R, op);
if (R > mid) update(rt << 1 | 1, L, R, op);
pushup(rt);
}
int query(int rt, int L, int R) {
int l = now.l, r = now.r;
if (L <= l && r <= R) return now.ab;
pushdown(rt);
int res = 0;
int mid = l + r >> 1;
if (L <= mid) res = (res + query(rt << 1, L, R)) % mod;
if (R > mid) res = (res + query(rt << 1 | 1, L, R)) % mod;
return res;
}
} seg;
int a[maxn], b[maxn];
int main() {
int n;
scanf("%d", &n);
for (int i = 1; i <= n; i++) {
scanf("%d%d", a + i, b + i);
}
seg.build(1, 1, n, a, b);
int m;
scanf("%d", &m);
while (m--) {
int op, tag, l, r;
scanf("%d", &op);
if (op == 1) {
int x;
scanf("%d%d%d%d", &tag, &l, &r, &x);
seg.update(1, l, r, tag ? (Vector){0, x} : (Vector){x, 0});
} else if (op == 2) {
scanf("%d%d", &l, &r);
seg.update(1, l, r, op2);
} else if (op == 3) {
scanf("%d%d", &l, &r);
seg.update(1, l, r, op3);
} else {
scanf("%d%d", &l, &r);
int res = seg.query(1, l, r);
printf("%d\n", res);
}
}
return 0;
}
第二份:
#include <cstdio>
#include <cstring>
#include <vector>
#define now nodes[rt]
#define ls nodes[rt << 1]
#define rs nodes[rt << 1 | 1]
using namespace std;
typedef long long Lint;
const Lint mod = 1e9 + 7;
const int maxn = 2e5 + 10;
struct Matrix {
Lint a[3][3];
Matrix operator*(const Matrix& rhs) const {
Matrix res = {0, 0, 0, 0, 0, 0, 0, 0, 0};
for (int i = 0; i < 3; i++)
for (int j = 0; j < 3; j++) {
for (int k = 0; k < 3; k++)
res.a[i][j] += a[i][k] * rhs.a[k][j];
res.a[i][j] %= mod;
}
return res;
}
Matrix operator+(const Matrix& rhs) const {
Matrix res = {0, 0, 0, 0, 0, 0, 0, 0, 0};
for (int i = 0; i < 3; i++)
for (int j = 0; j < 3; j++)
res.a[i][j] = (a[i][j] + rhs.a[i][j]) % mod;
return res;
}
Matrix T() const {
Matrix res;
for (int i = 0; i < 3; i++)
for (int j = 0; j < 3; j++) res.a[i][j] = a[j][i];
return res;
}
};
const Matrix E = {1, 0, 0, 0, 1, 0, 0, 0, 1};
const Matrix op2 = {3, 2, 0, 3, mod - 2, 0, 0, 0, 1};
const Matrix op3 = {0, 1, 0, 1, 0, 0, 0, 0, 1};
struct Node {
Matrix tag, num;
};
struct SegTree {
Node nodes[maxn << 2];
void pushup(int rt) { now.num = ls.num + rs.num; }
void update_node(int rt, const Matrix& op) {
now.tag = op * now.tag;
now.num = op * now.num * op.T();
}
void pushdown(int rt) {
update_node(rt << 1, now.tag);
update_node(rt << 1 | 1, now.tag);
now.tag = E;
}
void build(int rt, int l, int r, Lint* a, Lint* b) {
now.tag = E;
if (l == r) {
now.num.a[0][0] = a[l] * a[l] % mod;
now.num.a[1][1] = b[l] * b[l] % mod;
now.num.a[2][2] = 1;
now.num.a[0][1] = now.num.a[1][0] = a[l] * b[l] % mod;
now.num.a[0][2] = now.num.a[2][0] = a[l];
now.num.a[1][2] = now.num.a[2][1] = b[l];
return;
}
int mid = l + r >> 1;
build(rt << 1, l, mid, a, b);
build(rt << 1 | 1, mid + 1, r, a, b);
pushup(rt);
}
void update(int rt, int l, int r, int L, int R, const Matrix& op) {
if (L <= l && r <= R) {
update_node(rt, op);
return;
}
pushdown(rt);
int mid = l + r >> 1;
if (L <= mid) update(rt << 1, l, mid, L, R, op);
if (R > mid) update(rt << 1 | 1, mid + 1, r, L, R, op);
pushup(rt);
}
Lint query(int rt, int l, int r, int L, int R) {
if (L <= l && r <= R) return now.num.a[0][1];
pushdown(rt);
Lint res = 0;
int mid = l + r >> 1;
if (L <= mid) res = (res + query(rt << 1, l, mid, L, R)) % mod;
if (R > mid) res = (res + query(rt << 1 | 1, mid + 1, r, L, R)) % mod;
return res;
}
} seg;
Lint a[maxn], b[maxn];
int main() {
int n;
scanf("%d", &n);
for (int i = 1; i <= n; i++) {
scanf("%lld%lld", a + i, b + i);
}
seg.build(1, 1, n, a, b);
int m;
scanf("%d", &m);
while (m--) {
int op, tag, l, r;
scanf("%d", &op);
if (op == 1) {
int x;
scanf("%d%d%d%d", &tag, &l, &r, &x);
Matrix op = tag ? (Matrix){1, 0, 0, 0, 1, x, 0, 0, 1}
: (Matrix){1, 0, x, 0, 1, 0, 0, 0, 1};
seg.update(1, 1, n, l, r, op);
} else if (op == 2) {
scanf("%d%d", &l, &r);
seg.update(1, 1, n, l, r, op2);
} else if (op == 3) {
scanf("%d%d", &l, &r);
seg.update(1, 1, n, l, r, op3);
} else {
scanf("%d%d", &l, &r);
Lint res = seg.query(1, 1, n, l, r);
printf("%lld\n", res);
}
}
return 0;
}