几种基础莫队模板
先附上一篇博客（原文链接缺失），感谢此文让我弄懂了莫队算法。
1、不带修莫队:
// Plain (offline, no updates) Mo's algorithm.
// Input: n m k, then a[1..n], then m queries "l r".
// For each query prints the sum over values v of cnt(v)^2, where cnt(v) is
// the number of occurrences of v in a[l..r].
#include <cstdio>
#include <cmath>
#include <algorithm>
using namespace std;

typedef long long ll;

const int N = 500005;

// Reads a decimal integer from stdin, skipping preceding non-digit bytes
// (a '-' among them makes the result negative). Returns 0 at EOF instead of
// spinning forever.
inline int read() {
    int x = 0, w = 1;
    int ch = getchar();
    while (ch != EOF && (ch < '0' || ch > '9')) {
        if (ch == '-') w = -1;
        ch = getchar();
    }
    while (ch >= '0' && ch <= '9') {
        x = x * 10 + (ch - '0');
        ch = getchar();
    }
    return x * w;
}

int n, m, k, a[N], cnt[N], pos[N], blo;
// The sum of cnt^2 can reach n^2, which overflows int for n around 5e4,
// so both the running value and the stored answers are 64-bit.
ll nowans, ans[N];

struct query { int l, r, id; } q[N];

// Orders queries by the block of l. Within a block, r goes ascending for odd
// blocks and descending for even ones (odd-even trick) to reduce pointer travel.
bool cmp(const query &x, const query &y) {
    if (pos[x.l] != pos[y.l]) return pos[x.l] < pos[y.l];
    return (pos[x.l] & 1) ? x.r < y.r : x.r > y.r;
}

// Inserts position x into the window: (c + 1)^2 - c^2 = 2c + 1.
inline void add(int x) {
    nowans += 2LL * cnt[a[x]] + 1;
    cnt[a[x]]++;
}

// Removes position x from the window: c^2 - (c - 1)^2 = 2c - 1.
inline void del(int x) {
    nowans -= 2LL * cnt[a[x]] - 1;
    cnt[a[x]]--;
}

int main() {
    n = read();
    m = read();
    k = read();  // value range; read only to consume the input
    for (int i = 1; i <= n; i++) a[i] = read();

    blo = sqrt(n);
    if (blo < 1) blo = 1;  // avoid division by zero for n == 0
    for (int i = 1; i <= n; i++) pos[i] = (i - 1) / blo + 1;

    for (int i = 1; i <= m; i++) {
        q[i].l = read();
        q[i].r = read();
        q[i].id = i;
    }
    sort(q + 1, q + m + 1, cmp);

    int l = 1, r = 0;
    for (int i = 1; i <= m; i++) {
        while (l < q[i].l) del(l++);
        while (l > q[i].l) add(--l);
        while (r < q[i].r) add(++r);
        while (r > q[i].r) del(r--);
        ans[q[i].id] = nowans;
    }
    for (int i = 1; i <= m; i++) printf("%lld\n", ans[i]);
    return 0;
}
// Mo's algorithm without updates.
// Input: n m, then a[1..n], then m queries "l r".
// For each query prints "p/q": the number of ordered same-value pairs of
// distinct positions in a[l..r] divided by len * (len - 1), reduced by gcd.
// A single-element range prints "0/1".
#include <cstdio>
#include <cmath>
#include <algorithm>
using namespace std;

typedef long long ll;

const int MAXN = 100005;

// Reads a decimal integer from stdin, skipping preceding non-digit bytes;
// any '-' among the skipped bytes makes the result negative.
inline int read() {
    int sign = 1, value = 0;
    int c = getchar();
    for (; c < '0' || c > '9'; c = getchar())
        if (c == '-') sign = -1;
    for (; c >= '0' && c <= '9'; c = getchar())
        value = value * 10 + c - '0';
    return value * sign;
}

ll n, m, a[MAXN], blo, pos[MAXN];
ll num[MAXN], den[MAXN], cnt[MAXN];
ll pairs;  // sum over values of c * (c - 1) for the current window

struct query { ll l, r, id; } q[MAXN];

// Euclid's algorithm, iterative form; gcd(x, 0) == x.
ll gcd(ll x, ll y) {
    while (y != 0) {
        ll rem = x % y;
        x = y;
        y = rem;
    }
    return x;
}

// Block of l first; inside a block r alternates direction by block parity.
bool cmp(query x, query y) {
    if (pos[x.l] != pos[y.l]) return pos[x.l] < pos[y.l];
    if (pos[x.l] & 1) return x.r < y.r;
    return x.r > y.r;
}

// (c + 1) * c - c * (c - 1) = 2c
inline void add(ll p) {
    pairs += cnt[a[p]] * 2;
    ++cnt[a[p]];
}

// c * (c - 1) - (c - 1) * (c - 2) = 2c - 2
inline void del(ll p) {
    pairs -= cnt[a[p]] * 2 - 2;
    --cnt[a[p]];
}

int main() {
    n = read();
    m = read();
    for (int i = 1; i <= n; ++i) a[i] = read();

    blo = sqrt(n);
    for (int i = 1; i <= n; ++i) pos[i] = (i - 1) / blo + 1;

    for (int i = 1; i <= m; ++i) {
        q[i].l = read();
        q[i].r = read();
        q[i].id = i;
    }
    sort(q + 1, q + m + 1, cmp);

    int l = 1, r = 0;
    for (int i = 1; i <= m; ++i) {
        const query &cur = q[i];
        if (cur.l == cur.r) {
            num[cur.id] = 0;
            den[cur.id] = 1;
            continue;
        }
        while (l < cur.l) del(l++);
        while (l > cur.l) add(--l);
        while (r < cur.r) add(++r);
        while (r > cur.r) del(r--);

        ll len = cur.r - cur.l + 1;
        ll total = len * (len - 1);
        ll g = gcd(pairs, total);
        num[cur.id] = pairs / g;
        den[cur.id] = total / g;
    }
    for (int i = 1; i <= m; ++i) printf("%lld/%lld\n", num[i], den[i]);
    return 0;
}
2、带修莫队
// Mo's algorithm with modifications (time is a third sorting key).
// Input: n m, then a[1..n], then m operations:
//   "Q l r" -> print the number of distinct values in a[l..r];
//   "R p c" -> set a[p] = c.
#include <cstdio>
#include <cmath>
#include <algorithm>
using namespace std;

const int V = 1000005;  // values are used directly as cnt[] indices
const int N = 200005;

// Reads a decimal integer from stdin, skipping preceding non-digit bytes
// (a '-' among them makes the result negative). Returns 0 at EOF instead of
// spinning forever.
inline int read() {
    int x = 0, w = 1;
    int ch = getchar();
    while (ch != EOF && (ch < '0' || ch > '9')) {
        if (ch == '-') w = -1;
        ch = getchar();
    }
    while (ch >= '0' && ch <= '9') {
        x = x * 10 + (ch - '0');
        ch = getchar();
    }
    return x * w;
}

int a[N], cnt[V], ans[N], pos[N], nowans;
// time = number of modifications issued before this query.
struct query { int l, r, time, id; } q[N];
// color holds the value to write; after being applied it holds the
// overwritten value, so applying it again undoes it.
struct change { int pos, color; } c[N];
int cntq, cntc, n, m, blo;

// Sort by (block of l, block of r, time).
bool cmp(const query &x, const query &y) {
    if (pos[x.l] != pos[y.l]) return pos[x.l] < pos[y.l];
    if (pos[x.r] != pos[y.r]) return pos[x.r] < pos[y.r];
    return x.time < y.time;
}

inline void add(int x) {
    if (cnt[a[x]]++ == 0) ++nowans;
}

inline void del(int x) {
    if (--cnt[a[x]] == 0) --nowans;
}

// Applies modification t, or undoes it: the swap makes the operation its own
// inverse. When the modified position is inside the current window [ql, qr],
// the old value is removed and the new one inserted so nowans stays correct.
inline void apply_change(int t, int ql, int qr) {
    int p = c[t].pos;
    if (ql <= p && p <= qr) {
        // Separate statements on purpose. The former one-liner
        // `nowans -= !--cnt[x] - !cnt[y]++` modified the same cnt element
        // twice without sequencing when x == y (undefined behavior).
        if (--cnt[a[p]] == 0) --nowans;
        if (cnt[c[t].color]++ == 0) ++nowans;
    }
    swap(a[p], c[t].color);
}

int main() {
    n = read();
    m = read();
    blo = pow(n, 2.0 / 3.0);
    if (blo < 1) blo = 1;  // avoid division by zero for n == 0
    for (int i = 1; i <= n; i++) pos[i] = (i - 1) / blo + 1;
    for (int i = 1; i <= n; i++) a[i] = read();

    for (int i = 1; i <= m; i++) {
        char op[55];
        if (scanf("%54s", op) != 1) break;  // bounded read; stop on EOF
        if (op[0] == 'Q') {
            ++cntq;
            q[cntq].l = read();
            q[cntq].r = read();
            q[cntq].time = cntc;
            q[cntq].id = cntq;
        } else if (op[0] == 'R') {
            ++cntc;
            c[cntc].pos = read();
            c[cntc].color = read();
        }
    }
    sort(q + 1, q + cntq + 1, cmp);

    int l = 1, r = 0, t = 0;
    for (int i = 1; i <= cntq; i++) {
        int ql = q[i].l, qr = q[i].r, qt = q[i].time;
        while (l < ql) del(l++);
        while (l > ql) add(--l);
        while (r < qr) add(++r);
        while (r > qr) del(r--);
        while (t < qt) apply_change(++t, ql, qr);
        while (t > qt) apply_change(t--, ql, qr);
        ans[q[i].id] = nowans;
    }
    for (int i = 1; i <= cntq; i++) printf("%d\n", ans[i]);
    return 0;
}
3、树上带修莫队
// Tree Mo's algorithm with modifications, run over the DFS bracket sequence
// (each node appears at its entry and exit).
// Input: n m Q; v[1..m]; w[1..n]; n - 1 undirected edges; val[1..n];
// then Q operations "op x y":
//   op != 0 -> print, for the path x..y, the sum over types c of
//              v[c] * (w[1] + ... + w[k]), where k = number of path nodes of type c;
//   op == 0 -> set val[x] = y.
#include <cstdio>
#include <cmath>
#include <algorithm>
using namespace std;

typedef long long ll;

const int MAXN = 200005;

// Reads a decimal integer from stdin, skipping preceding non-digit bytes;
// any '-' among the skipped bytes makes the result negative.
inline int read() {
    int sign = 1, value = 0;
    int c = getchar();
    for (; c < '0' || c > '9'; c = getchar())
        if (c == '-') sign = -1;
    for (; c >= '0' && c <= '9'; c = getchar())
        value = value * 10 + c - '0';
    return value * sign;
}

int n, m, Q;
int v[MAXN], w[MAXN], val[MAXN];
int euler[MAXN], eulerLen, fir[MAXN], las[MAXN];  // bracket sequence and each node's entry/exit index
int dep[MAXN], fa[MAXN][35];                      // depth (root = 1) and binary-lifting ancestors
int head[MAXN], edgeCnt;
struct Edge { int to, next; } edges[MAXN];
int cnt[MAXN], inWin[MAXN], blk[MAXN], blo;
ll cur, ans[MAXN];
// lca == 0 means the path's LCA is already covered by [l, r].
struct Query { int l, r, id, lca, t; } qs[MAXN];
struct Update { int pos, val; } ups[MAXN];
int qCnt, uCnt;

void addEdge(int from, int to) {
    ++edgeCnt;
    edges[edgeCnt].to = to;
    edges[edgeCnt].next = head[from];
    head[from] = edgeCnt;
}

// Sort by (block of l, block of r, time).
bool byBlocks(const Query &x, const Query &y) {
    if (blk[x.l] != blk[y.l]) return blk[x.l] < blk[y.l];
    if (blk[x.r] != blk[y.r]) return blk[x.r] < blk[y.r];
    return x.t < y.t;
}

// Builds the bracket sequence plus depth and ancestor tables.
// Nodes with dep != 0 count as already visited.
void dfs(int u) {
    euler[++eulerLen] = u;
    fir[u] = eulerLen;
    for (int id = head[u]; id != 0; id = edges[id].next) {
        int child = edges[id].to;
        if (dep[child] != 0) continue;
        dep[child] = dep[u] + 1;
        fa[child][0] = u;
        for (int j = 1; (1 << j) <= dep[child]; ++j)
            fa[child][j] = fa[fa[child][j - 1]][j - 1];
        dfs(child);
    }
    euler[++eulerLen] = u;
    las[u] = eulerLen;
}

// Binary-lifting LCA; relies on dep[0] == 0 for "no ancestor".
int lca(int x, int y) {
    if (dep[x] < dep[y]) swap(x, y);
    for (int j = 20; j >= 0; --j)
        if (dep[fa[x][j]] >= dep[y]) x = fa[x][j];
    if (x == y) return x;
    for (int j = 20; j >= 0; --j) {
        if (fa[x][j] != fa[y][j]) {
            x = fa[x][j];
            y = fa[y][j];
        }
    }
    return fa[x][0];
}

// Flips whether node u is counted. A node that appears twice in the window
// cancels out, which is what makes the bracket sequence describe a path.
inline void toggle(int u) {
    if (inWin[u])
        cur -= 1LL * v[val[u]] * w[cnt[val[u]]--];
    else
        cur += 1LL * v[val[u]] * w[++cnt[val[u]]];
    inWin[u] ^= 1;
}

// Applies (or, because of the swap, undoes) update k, refreshing the node's
// contribution if it is currently counted.
void applyUpdate(int k) {
    int node = ups[k].pos;
    bool counted = inWin[node];
    if (counted) toggle(node);
    swap(val[node], ups[k].val);
    if (counted) toggle(node);
}

int main() {
    n = read();
    m = read();
    Q = read();
    for (int i = 1; i <= m; ++i) v[i] = read();
    for (int i = 1; i <= n; ++i) w[i] = read();
    for (int i = 1; i < n; ++i) {
        int x = read(), y = read();
        addEdge(x, y);
        addEdge(y, x);
    }
    for (int i = 1; i <= n; ++i) val[i] = read();

    dep[1] = 1;
    dfs(1);
    blo = pow(eulerLen, 2.0 / 3.0);
    for (int i = 1; i <= eulerLen; ++i) blk[i] = (i - 1) / blo + 1;

    for (int i = 1; i <= Q; ++i) {
        int op = read();
        int x = read();
        int y = read();
        if (op == 0) {
            ++uCnt;
            ups[uCnt].pos = x;
            ups[uCnt].val = y;
            continue;
        }
        int anc = lca(x, y);
        ++qCnt;
        qs[qCnt].t = uCnt;
        qs[qCnt].id = qCnt;
        if (fir[x] > fir[y]) swap(x, y);
        if (x == anc) {
            // x is an ancestor of y: [fir[x], fir[y]] already contains x.
            qs[qCnt].l = fir[x];
            qs[qCnt].r = fir[y];
        } else {
            // Disjoint subtrees: the LCA is missing from [las[x], fir[y]].
            qs[qCnt].l = las[x];
            qs[qCnt].r = fir[y];
            qs[qCnt].lca = anc;
        }
    }
    sort(qs + 1, qs + qCnt + 1, byBlocks);

    int L = 1, R = 0, T = 0;
    for (int i = 1; i <= qCnt; ++i) {
        const Query &cq = qs[i];
        while (L < cq.l) toggle(euler[L++]);
        while (L > cq.l) toggle(euler[--L]);
        while (R < cq.r) toggle(euler[++R]);
        while (R > cq.r) toggle(euler[R--]);
        while (T < cq.t) applyUpdate(++T);
        while (T > cq.t) applyUpdate(T--);
        if (cq.lca) toggle(cq.lca);
        ans[cq.id] = cur;
        if (cq.lca) toggle(cq.lca);
    }
    for (int i = 1; i <= qCnt; ++i) printf("%lld\n", ans[i]);
    return 0;
}
4、回滚莫队
// Rollback ("add-only") Mo's algorithm.
// Input: n m, then a[1..n], then m queries "l r".
// For each query prints the maximum over values x in a[l..r] of
// x * (number of occurrences of x in a[l..r]).
#include <cstdio>
#include <cstring>
#include <cmath>
#include <algorithm>
using namespace std;

typedef long long ll;

const int N = 100005;

// Reads a decimal integer from stdin, skipping preceding non-digit bytes
// (a '-' among them makes the result negative). Returns 0 at EOF instead of
// spinning forever.
inline int read() {
    int x = 0, w = 1;
    int ch = getchar();
    while (ch != EOF && (ch < '0' || ch > '9')) {
        if (ch == '-') w = -1;
        ch = getchar();
    }
    while (ch >= '0' && ch <= '9') {
        x = x * 10 + (ch - '0');
        ch = getchar();
    }
    return x * w;
}

int a[N], typ[N], cnt[N], cnt2[N], pos[N];
int lb[N], rb[N];
int inp[N];
ll ans[N];
struct query { int l, r, id; } q[N];
int n, m, blo, tot;

// Group queries by the block of l; within a group r increases, so the right
// pointer only ever moves forward.
bool cmp(const query &x, const query &y) {
    return pos[x.l] != pos[y.l] ? pos[x.l] < pos[y.l] : x.r < y.r;
}

int main() {
    n = read();
    m = read();
    blo = sqrt(n);
    if (blo < 1) blo = 1;
    tot = (n + blo - 1) / blo;
    for (int b = 1; b <= tot; b++) {
        lb[b] = blo * (b - 1) + 1;
        // Clamp before filling pos[]. Unclamped, the last block ends at
        // blo * tot, which can exceed n and (for n near 1e5) write past pos[N].
        rb[b] = min(blo * b, n);
        for (int j = lb[b]; j <= rb[b]; j++) pos[j] = b;
    }

    // Coordinate-compress the values into typ[] so they can index cnt[].
    for (int i = 1; i <= n; i++) inp[i] = a[i] = read();
    sort(inp + 1, inp + n + 1);
    int len = unique(inp + 1, inp + n + 1) - (inp + 1);
    for (int i = 1; i <= n; i++) typ[i] = lower_bound(inp + 1, inp + len + 1, a[i]) - inp;

    for (int i = 1; i <= m; i++) {
        q[i].l = read();
        q[i].r = read();
        q[i].id = i;
    }
    sort(q + 1, q + m + 1, cmp);

    int i = 1;
    for (int k = 1; k <= tot; k++) {
        // The right part grows from rb[k] + 1; the left part is rebuilt per query.
        int l = rb[k] + 1, r = rb[k];
        ll now = 0;
        memset(cnt, 0, sizeof(cnt));
        // Bounded by m instead of relying on the zero-filled q[m + 1] as a sentinel.
        for (; i <= m && pos[q[i].l] == k; i++) {
            int ql = q[i].l, qr = q[i].r;
            if (pos[ql] == pos[qr]) {
                // Query lies inside one block: brute force with cnt2[].
                ll best = 0;
                for (int j = ql; j <= qr; j++) cnt2[typ[j]] = 0;
                for (int j = ql; j <= qr; j++) {
                    cnt2[typ[j]]++;
                    best = max(best, 1LL * cnt2[typ[j]] * a[j]);
                }
                ans[q[i].id] = best;
                continue;
            }
            while (r < qr) {
                ++r;
                cnt[typ[r]]++;
                now = max(now, 1LL * cnt[typ[r]] * a[r]);
            }
            ll saved = now;  // answer for the right part [rb[k] + 1, r] alone
            while (l > ql) {
                --l;
                cnt[typ[l]]++;
                now = max(now, 1LL * cnt[typ[l]] * a[l]);
            }
            ans[q[i].id] = now;
            // Roll back the left part so the next query starts from the right part only.
            while (l <= rb[k]) cnt[typ[l++]]--;
            now = saved;
        }
    }
    for (int j = 1; j <= m; j++) printf("%lld\n", ans[j]);
    return 0;
}
// Rollback ("add-only") Mo's algorithm.
// Input: n, then a[1..n], then m, then m queries "l r".
// For each query prints max (j - i) over l <= i <= j <= r with a[i] == a[j]
// (0 when no value repeats inside the range).
#include <cstdio>
#include <cmath>
#include <algorithm>
using namespace std;

const int N = 200005;

// Reads a decimal integer from stdin, skipping preceding non-digit bytes
// (a '-' among them makes the result negative). Returns 0 at EOF instead of
// spinning forever.
inline int read() {
    int x = 0, w = 1;
    int ch = getchar();
    while (ch != EOF && (ch < '0' || ch > '9')) {
        if (ch == '-') w = -1;
        ch = getchar();
    }
    while (ch >= '0' && ch <= '9') {
        x = x * 10 + (ch - '0');
        ch = getchar();
    }
    return x * w;
}

int n, m, blo, tot, pos[N], rb[N], a[N], inp[N], ans[N];
struct query { int l, r, id; } q[N];

// Group queries by the block of l; within a group r increases.
bool cmp(const query &x, const query &y) {
    return pos[x.l] != pos[y.l] ? pos[x.l] < pos[y.l] : x.r < y.r;
}

// Per compressed value; 0 means "unset" (positions are 1-based):
// fir = first position in the right part,
// las = last position in the right part, or, for values absent from the
//       right part, the rightmost position seen so far in the left part.
int las[N], fir[N];
int clr[N], cntc;  // values touched by the right part, cleared after each block
int last[N];       // scratch array for calc()

// Brute-force answer for a query lying inside a single block.
int calc(int l, int r) {
    int ret = 0;
    for (int i = l; i <= r; i++) last[a[i]] = 0;
    for (int i = l; i <= r; i++) {
        if (!last[a[i]])
            last[a[i]] = i;
        else
            ret = max(ret, i - last[a[i]]);
    }
    return ret;
}

int main() {
    n = read();
    blo = sqrt(n);
    if (blo < 1) blo = 1;
    tot = (n + blo - 1) / blo;
    for (int b = 1; b <= tot; b++) rb[b] = b * blo;
    rb[tot] = n;
    for (int i = 1; i <= n; i++) pos[i] = (i - 1) / blo + 1;

    // Coordinate-compress the values in place.
    for (int i = 1; i <= n; i++) inp[i] = a[i] = read();
    sort(inp + 1, inp + n + 1);
    int len = unique(inp + 1, inp + n + 1) - inp - 1;
    for (int i = 1; i <= n; i++) a[i] = lower_bound(inp + 1, inp + len + 1, a[i]) - inp;

    m = read();
    for (int i = 1; i <= m; i++) {
        q[i].l = read();
        q[i].r = read();
        q[i].id = i;
    }
    sort(q + 1, q + m + 1, cmp);

    int i = 1;
    for (int k = 1; k <= tot; k++) {
        int l = rb[k] + 1, r = rb[k];
        int now = 0;
        // Bounded by m. Previously the loop relied on the zero-filled q[m + 1]
        // as a sentinel, which runs off the array when m == 0.
        for (; i <= m && pos[q[i].l] == k; i++) {
            int ql = q[i].l, qr = q[i].r;
            if (pos[ql] == pos[qr]) {
                ans[q[i].id] = calc(ql, qr);
                continue;
            }
            while (r < qr) {
                ++r;
                if (!fir[a[r]]) {
                    fir[a[r]] = r;
                    clr[++cntc] = a[r];
                }
                now = max(now, r - fir[a[r]]);
                las[a[r]] = r;
            }
            int saved = now;  // answer for the right part alone
            while (l > ql) {
                --l;
                if (las[a[l]])
                    now = max(now, las[a[l]] - l);
                else
                    las[a[l]] = l;
            }
            ans[q[i].id] = now;
            now = saved;
            // Undo the las[] entries created by the left part. Right-part
            // positions are > rb[k], so they can never equal l here.
            while (l <= rb[k]) {
                if (las[a[l]] == l) las[a[l]] = 0;
                ++l;
            }
        }
        for (int j = 1; j <= cntc; j++) fir[clr[j]] = las[clr[j]] = 0;
        cntc = 0;
    }
    for (int j = 1; j <= m; j++) printf("%d\n", ans[j]);
    return 0;
}