BZOJ2118 墨墨的等式(同余最短路)
Description
给定 a 1 ∼ a n a_1 \sim a_n a1∼an 和 B B B 的取值范围 [ l , r ] [l,r] [l,r]。求有多少个 B B B 满足有非负整数 x 1 ∼ x n x_1 \sim x_n x1∼xn 使得 B = ∑ i = 1 n a i x i B = \sum_{i = 1}^n a_ix_i B=∑i=1naixi。时间限制 10s。
1 ≤ n ≤ 12 , 0 ≤ a i ≤ 5 × 1 0 5 , 1 ≤ l ≤ r ≤ 1 0 12 1 \leq n \leq 12, 0 \leq a_i \leq 5 \times 10^5, 1 \leq l \leq r \leq 10^{12} 1≤n≤12,0≤ai≤5×105,1≤l≤r≤1012。
Solution
大大凯的疑惑? 前版本为小凯的疑惑,跳楼机。
用最小的 a i a_i ai 作为模数 p p p。如果一个 B B B 有解那么 B m o d p + k × p B \bmod p +k \times p Bmodp+k×p 也有解,其中 k k k 是一个非负整数,它们同余。
所以用 d i s i dis_i disi 为 a 2 ∼ a n a_2 \sim a_n a2∼an 跑出最小解 B B B 满足 B m o d p = i B \bmod p = i Bmodp=i。可以跑最短路求出。那么对于每一个 i i i,看看能加 p p p 得到多少答案,根据加法原理将所有 i i i 答案相加。
点有 p − 1 p-1 p−1 个而边的规模为 n × p n \times p n×p,所以时间复杂度为 O ( n p log n p ) O(np \log np) O(nplognp)。可以用配对堆或斐波那契堆优化到 O ( n log n + n p ) O(n \log n + np) O(nlogn+np)。我用的 pb_ds 自带的配对堆,但是因为 10s 的时限用 priority_queue 也行。好像 spfa 也比较快。
Code
#include <bits/stdc++.h>
#include <ext/pb_ds/priority_queue.hpp>
#define int long long
using namespace std;
using namespace __gnu_pbds;
typedef long long ll;
typedef __gnu_pbds :: priority_queue <pair<int, int>, greater<pair<int, int> > ,pairing_heap_tag> pair_heap;
const int N = 1e6 + 5, M = 1e7 + 5, INF = 0x3f3f3f3f;
inline int read() {
int x = 0, f = 0; char ch = 0;
while (!isdigit(ch)) f |= ch == '-', ch = getchar();
while (isdigit(ch)) x = (x << 3) + (x << 1) + (ch ^ 48), ch = getchar();
return f ? -x : x;
}
int a[N];
int dis[N], n, l, r;
pair_heap :: point_iterator id[N];
void Dijkstra() {
pair_heap q; memset(dis, 0x7f, sizeof(dis));
dis[0] = 0; id[0] = q.push(make_pair(0, 0));
while(!q.empty()) {
int x = q.top().second; q.pop();
for (int j = 2; j <= n; j++) {
int y = (x + a[j]) % a[1];
if (dis[y] > dis[x] + a[j]) {
dis[y] = dis[x] + a[j];
if (id[y] != NULL) q.modify(id[y], make_pair(dis[y], y));
else id[y] = q.push(make_pair(dis[y], y));
}
}
}
}
signed main() {
n = read(), l = read() - 1, r = read();
for (int i = 1; i <= n; i++) a[i] = read();
sort(a + 1, a + n + 1);
Dijkstra();
int ans = 0;
for (int i = 0; i < a[1]; i++) {
if (dis[i] <= l) ans -= (l - dis[i]) / a[1] + 1;
if (dis[i] <= r) ans += (r - dis[i]) / a[1] + 1;
}
printf("%lld\n", ans);
return 0;
}

浙公网安备 33010602011771号