【LG5330】[SNOI2019]数论

【LG5330】[SNOI2019]数论

题面

洛谷

题目大意:

给定集合\(\mathbb {A,B}\)

问有多少个小于\(T\)的非负整数\(x\)满足:\(x\)除以\(P\)的余数属于\(\mathbb A\)\(x\)除以\(Q\)的余数属于\(\mathbb B\)

其中\(1\leq |\mathbb A|,|\mathbb B|\leq 10^6,1\leq P,Q\leq 10^6,1\leq T\leq 10^{18}\)

题面

考虑枚举一个\(A\),然后考虑有多少个合法的\(B\)

首先这个数可以写成\(a_i+kP\)的形式,那么它模\(Q\)的值成环。

所以我们预处理每个环内有多少个合法的\(b\),再把\(b\)按照访问顺序记录一下,那么对于每一个\(a\)就可以直接算答案了。

代码

#include <iostream>
#include <cstdio>
#include <cstdlib>
#include <cstring> 
#include <cmath> 
#include <algorithm>
#include <vector> 
using namespace std; 
inline int gi() {
    register int data = 0, w = 1; 
    register char ch = 0; 
    while (!isdigit(ch) && ch != '-') ch = getchar(); 
    if (ch == '-') w = -1, ch = getchar(); 
    while (isdigit(ch)) data = 10 * data + ch - '0', ch = getchar(); 
    return w * data; 
} 
const int MAX_N = 1e6 + 5; 
int P, Q, N, M, a[MAX_N], b[MAX_N]; 
long long T, len, p[MAX_N]; 
int val[MAX_N], w[MAX_N], col[MAX_N], pos[MAX_N], cnt; 
vector<int> cir[MAX_N], sum[MAX_N]; 
int dfs(int x) { 
    if (col[x]) return 0; 
    col[x] = cnt, cir[cnt].push_back(x); 
    return val[x] + dfs((x + P) % Q); 
} 
int solve(int l, int x) { return sum[col[x]][pos[x] + l] - sum[col[x]][pos[x]]; } 
int main () {
#ifndef ONLINE_JUDGE 
    freopen("cpp.in", "r", stdin);
#endif 
    P = gi(), Q = gi(), N = gi(), M = gi(); scanf("%lld", &T); 
    for (int i = 1; i <= N; i++) a[i] = gi(); 
    for (int i = 1; i <= M; i++) b[i] = gi(); 
    if (P > Q) swap(P, Q), swap(N, M), swap(a, b); 
    len = Q / __gcd(P, Q); 
    for (int i = 1; i <= M; i++) val[b[i]] = 1; 
    for (int i = 1; i <= N; i++) p[i] = (T - 1 - a[i]) / P; 
    for (int i = 0; i < Q; i++) if (!col[i]) ++cnt, w[cnt] = dfs(i); 
    for (int i = 1; i <= cnt; i++) { 
        for (int j = 0; j < (int)cir[i].size(); j++) pos[cir[i][j]] = j; 
        for (int j = 0, sz = cir[i].size(); j < sz - 1; j++) cir[i].push_back(cir[i][j]); 
        sum[i].push_back(val[cir[i][0]]); 
        for (int j = 1; j < (int)cir[i].size(); j++) sum[i].push_back(sum[i][j - 1] + val[cir[i][j]]); 
    } 
    long long ans = 0; 
    for (int i = 1; i <= N; i++) { 
        ans += p[i] / len * w[col[a[i]]]; 
        ans += solve(p[i] % len, a[i]) + val[a[i]]; 
    } 
    printf("%lld\n", ans); 
    return 0; 
} 
posted @ 2019-11-04 14:28  heyujun  阅读(...)  评论(... 编辑 收藏