AtCoder ABC315 G - Ai + Bj + Ck = X (1 <= i, j, k <= N) 题解
Ai + Bj + Ck = X (1 <= i, j, k <= N)
Description
给定整数 \(n,a,b,c,x\),求满足 \(ai+bj+ck=x\) 的有序正整数三元组 \((i,j,k)\) 的个数,其中 \(1 \le i,j,k \le n\)。
\(1 \le n \le 10^6,\ 1 \le a,b,c \le 10^9,\ 1 \le x \le 3\times 10^{15}\)。
Solution
三元组不便于直接计数。
注意到可以抽出 \(a\),固定 \(b\) 和 \(c\) 进行计数,转化为二元一次不定方程问题。
具体地,原式可变为
\[bj+ck=x-ai
\]
设 \(x-ai = y\),枚举 \(i\),对方程 \(bj+ck=y\) 的解 \((j,k)\) 计数即可,使用扩欧算法(详见上一篇文章)。
对于 \(i\) 的范围:
- 当 \(j=k=0\) 时,\(i = \dfrac xa\);
- 当 \(j=k=n\) 时,\(i = \dfrac{x-(b+c)n}{a}\)。
- 同时应当满足 \(i \in [1,n]\)。
界定上述范围可得 \(i \in \Big[\max\Big(1, \Big \lceil \dfrac{x-(b+c)n}{a} \Big \rceil\Big),\ \min\Big(n,\Big \lfloor \dfrac xa \Big \rfloor\Big)\Big]\)。
需要 __int128。
Code
#include <bits/stdc++.h>
#define int __int128
#define inf 1e18
#define debug cout << '!';
#define filein(x) freopen(#x".in", "r", stdin);
#define fileout(x) freopen(#x".out", "w", stdout);
#define file(x) filein(x) fileout(x)
#define Fast_IO
using namespace std;
#ifdef Fast_IO
inline int read() {
int x = 0, f = 1; char c = getchar();
while (c < '0' or c > '9') { if (c == '-') f = -1; c = getchar(); }
while (c >= '0' and c <= '9') { x = x * 10 + c - '0'; c = getchar(); }
return x * f;
}
void write(int x) {
if (x < 0) putchar('-'), x = -x;
if (x > 9) write(x / 10);
putchar(x % 10 + '0'); return;
}
#endif
int T, N, A, B, C, X;
int gcd(int x, int y) {
return y == 0 ? x : gcd(y, x % y);
}
int exgcd(int a, int b, int &x, int &y) {
if (b == 0) {
x = 1, y = 0;
return a;
}
int res = exgcd(b, a % b, y, x);
y -= a / b * x;
return res;
}
int inv(int a, int mod) {
int x, y;
exgcd(a, mod, x, y);
x %= mod;
if (x < 0) x += mod;
return x;
}
int fdiv(int a, int b) {
return a >= 0 ? a / b : (a - b + 1) / b;
}
int cdiv(int a, int b) {
return a >= 0 ? (a + b - 1) / b : a / b;
}
int solve2(int a, int b, int x) {
int il = max((int)1, cdiv(x - b * N, a)), ir = min(N, fdiv(x - b, a));
if (il > ir) return 0;
int g = gcd(a, b);
if (x % g) return 0;
a /= g, b /= g, x /= g;
int i0 = (x % b) * inv(a, b) % b;
int kl = cdiv(il - i0, b), kr = fdiv(ir - i0, b);
if (kl > kr) return 0;
else return kr - kl + 1;
}
int solve3(int a, int b, int c, int x) {
int res = 0;
for (int k = max((int)1, cdiv( x - (a + b) * N, c)); k <= min(N, fdiv(x, c)); k++) {
res += solve2(a, b, x - k * c);
}
return res;
}
signed main() {
cin.tie(0) -> sync_with_stdio(0);
// T = read();
T = 1;
while (T--) {
N = read(), A = read(), B = read(), C = read(), X = read();
write(solve3(A, B, C, X));
putchar('\n');
}
return 0;
}

浙公网安备 33010602011771号