luogu3373 【模板】线段树2
题目大意:
已知一个数列,你需要进行下面三种操作:
1.将某区间每一个数乘上x
2.将某区间每一个数加上x
3.求出某区间每一个数的和
本线段树的标记是个二元组:add和mul,其代表将一个线段中的每一个点先乘以mul再加上add。设区间长度为x,原来区间和为sum,则打上标记后区间和为sum*mul+add*x。如果两个标记要叠加:对区间中的任意一个点v,先作用旧标记(mul, add)得到v*mul+add,再作用新标记(mul', add')得到(v*mul+add)*mul'+add'=v*mul*mul'+add*mul'+add'。所以将mul*=mul', add=add*mul'+add'即可。
注意:
- 尽管数据关于P取模了,但是因为有数据相乘的操作,所以程序中所有的值类型都要是long long
- 宏定义ModPlus, ModMult时,以ModMult为例:如果能保证x、y都已经对P取过模(此时x*y不会超出long long的范围),就不要写成((x%P)*(y%P))%P,而应该写成(x*y)%P,否则多余的取模运算会导致被卡常数。
#include <cstdio>
#include <cstring>
#include <cassert>
using namespace std;
// Capacity limits: positions are 1-based, and the tree addresses node `cur`'s
// children as cur*2 and cur*2+1, hence the factor of four for the node arrays.
const int MAX_RANGE=100010, MAX_NODE = MAX_RANGE * 4;
// 1-based counting loop: i runs over 1..n inclusive.
#define LOOP(i, n) for(int i=1; i<=n; i++)
// P is the modulus applied to every stored value; TotRange is the number of
// elements, occupying positions 1..TotRange. Both must be set before `g` is used.
long long P, TotRange;
long long OrgData[MAX_RANGE];

// Segment tree over positions 1..TotRange supporting range multiply, range add
// and range sum, with every result reduced into [0, P).
//
// Each node carries a lazy tag (mul, add) meaning "every element below becomes
// element*mul + add". All arithmetic is done in long long; products of two
// reduced values must fit in long long, i.e. P*P must not exceed LLONG_MAX.
struct RangeTree
{
private:
    // Reduces x into [0, P), also for negative x (operator% keeps the sign).
    static long long Norm(long long x)
    {
        const long long r = x % P;
        return r < 0 ? r + P : r;
    }

    // (x + y) mod P with both operands reduced first.
    static long long ModPlus(long long x, long long y)
    {
        return (Norm(x) + Norm(y)) % P;
    }

    // (x * y) mod P with both operands reduced first, so the product is < P*P.
    static long long ModMult(long long x, long long y)
    {
        return Norm(x) * Norm(y) % P;
    }

    // Pending affine update v -> v*mul + add for every element of a node.
    struct Tag
    {
        long long add, mul;

        // The default tag is the identity update.
        Tag() : add(0), mul(1) {}
        Tag(long long m, long long a) : add(a), mul(m) {}

        // Composes `later` on top of this tag:
        // (v*mul + add)*later.mul + later.add
        //   = v*(mul*later.mul) + (add*later.mul + later.add).
        void Refresh(const Tag &later)
        {
            mul = ModMult(mul, later.mul);
            add = ModPlus(ModMult(add, later.mul), later.add);
        }

        void Clear() { add = 0; mul = 1; }

        // Sum of the range [l, r] after applying this tag to a range whose
        // current sum is `sum`: sum*mul + add*(r - l + 1).
        long long GetSum(long long sum, int l, int r) const
        {
            return ModPlus(ModMult(sum, mul), ModMult(add, r - l + 1));
        }
    };

    Tag _tags[MAX_NODE];
    long long Sum[MAX_NODE];

    // Moves the pending tag of node `cur` (covering [l, r]) onto both children,
    // updating their sums, and resets the tag of `cur` to the identity.
    void PushDown(int cur, int l, int r)
    {
        const Tag pending = _tags[cur];
        if (pending.add == 0 && pending.mul == 1)
            return;

        const int mid = (l + r) / 2;
        const int left = cur * 2, right = cur * 2 + 1;
        Sum[left] = pending.GetSum(Sum[left], l, mid);
        Sum[right] = pending.GetSum(Sum[right], mid + 1, r);
        _tags[left].Refresh(pending);
        _tags[right].Refresh(pending);
        _tags[cur].Clear();
    }

    // Recomputes the sum of `cur` from its two children.
    void PullUp(int cur)
    {
        Sum[cur] = ModPlus(Sum[cur * 2], Sum[cur * 2 + 1]);
    }

    // Applies op (1 = multiply by value, 2 = add value) to [al, ar] inside the
    // node `cur` covering [sl, sr]. Any other op leaves the sums unchanged.
    void Update(int cur, int sl, int sr, int al, int ar, int op, long long value)
    {
        assert(al <= ar && sl <= sr && al <= sr && ar >= sl);
        if (al <= sl && sr <= ar)
        {
            if (op == 1)
            {
                Sum[cur] = ModMult(Sum[cur], value);
                _tags[cur].Refresh(Tag(value, 0));
            }
            else if (op == 2)
            {
                // The length*value product is formed in long long modular
                // arithmetic; as an int product it overflows for long ranges.
                Sum[cur] = ModPlus(Sum[cur], ModMult(sr - sl + 1, value));
                _tags[cur].Refresh(Tag(1, value));
            }
            return;
        }

        PushDown(cur, sl, sr);
        const int mid = (sl + sr) / 2;
        if (al <= mid)
            Update(cur * 2, sl, mid, al, ar, op, value);
        if (ar > mid)
            Update(cur * 2 + 1, mid + 1, sr, al, ar, op, value);
        PullUp(cur);
    }

    // Returns the sum of [al, ar] restricted to node `cur` covering [sl, sr].
    long long Query(int cur, int sl, int sr, int al, int ar)
    {
        assert(al <= ar && sl <= sr && al <= sr && ar >= sl);
        if (al <= sl && sr <= ar)
            return Sum[cur];

        PushDown(cur, sl, sr);
        const int mid = (sl + sr) / 2;
        long long ans = 0;
        if (al <= mid)
            ans = ModPlus(ans, Query(cur * 2, sl, mid, al, ar));
        if (ar > mid)
            ans = ModPlus(ans, Query(cur * 2 + 1, mid + 1, sr, al, ar));
        // No PullUp here: a query never changes the children's sums, so the
        // sum stored at `cur` is still consistent with them.
        return ans;
    }

    // Builds the subtree rooted at `cur` for positions [l, r] from a[l..r].
    void SetEachNode(long long *a, int cur, int l, int r)
    {
        _tags[cur] = Tag(1, 0);
        if (l == r)
        {
            // Store the reduced value so that a query answered directly by a
            // leaf (e.g. when TotRange == 1) is already in [0, P).
            Sum[cur] = Norm(a[l]);
            return;
        }
        const int mid = (l + r) / 2;
        SetEachNode(a, cur * 2, l, mid);
        SetEachNode(a, cur * 2 + 1, mid + 1, r);
        PullUp(cur);
    }

public:
    RangeTree() {}

    // (Re)builds the tree from a[1..TotRange]; a[0] is not read.
    // Requires P and TotRange to be set. Does nothing when TotRange < 1.
    void SetEachNode(long long *a)
    {
        if (TotRange < 1)
            return; // an empty range would otherwise recurse without end
        assert(TotRange <= MAX_RANGE);
        SetEachNode(a, 1, 1, static_cast<int>(TotRange));
    }

    // Applies op to every position in [l, r]: op == 1 multiplies by value,
    // op == 2 adds value. Requires 1 <= l <= r <= TotRange.
    void Update(int l, int r, int op, int value)
    {
        Update(1, 1, static_cast<int>(TotRange), l, r, op, Norm(value));
    }

    // Returns the sum of positions [l, r] modulo P.
    // Requires 1 <= l <= r <= TotRange.
    long long Query(int l, int r)
    {
        return Query(1, 1, static_cast<int>(TotRange), l, r);
    }
}g;
// Reads the element count, the operation count and the modulus, then the
// initial values, builds the tree and processes the operations from stdin.
// Operation 1 multiplies a range, 2 adds to a range, 3 prints a range sum.
int main()
{
    int remaining, kind, left, right, operand;
    scanf("%lld%d%lld", &TotRange, &remaining, &P);
    for (int pos = 1; pos <= TotRange; ++pos)
        scanf("%lld", &OrgData[pos]);
    g.SetEachNode(OrgData);

    while (remaining-- != 0)
    {
        scanf("%d", &kind);
        if (kind == 1 || kind == 2)
        {
            // Multiply and add share one input layout; the operation code is
            // forwarded unchanged to the tree.
            scanf("%d%d%d", &left, &right, &operand);
            g.Update(left, right, kind, operand);
        }
        else if (kind == 3)
        {
            scanf("%d%d", &left, &right);
            printf("%lld\n", g.Query(left, right));
        }
    }
    return 0;
}

浙公网安备 33010602011771号