AtCoder Beginner Contest 301 Ex Difference of Distance
基础图论。
记 \(d(s, t)\) 为 \(s\) 到 \(t\) 的所有路径上最大边权的最小值。考虑快速求出 \(d(s, t)\):被 \(+1\) 的边只有在其边权恰好等于 \(d(s, t)\) 时,答案才可能为 \(1\),否则答案必为 \(0\)。\(d(s, t)\) 可以通过建出 Kruskal 重构树,查询两点 \(\text{LCA}\) 的点权得到。
考虑把询问离线,把 \(d(s, t)\) 相同的询问和边权相同的边放在一起处理。设当前边权为 \(w\),边权 \(< w\) 的边已在之前用并查集合并,每个连通块看作一个点。现在问题变成:给定一张(可能有重边的)图,询问删掉其中某一条边后,两点是否会不连通。对这张图做边双连通分量缩点:如果这条边在某个边双内部(即不是桥),或者两点在同一个边双内,答案显然为 \(0\);否则在缩点后得到的树(森林)上查这条桥边是否在两点之间的路径上即可。
时间复杂度 \(O((m + q) \log m)\)。
code
// Problem: Ex - Difference of Distance
// Contest: AtCoder - パナソニックグループプログラミングコンテスト2023(AtCoder Beginner Contest 301)
// URL: https://atcoder.jp/contests/abc301/tasks/abc301_h
// Memory Limit: 1024 MB
// Time Limit: 5000 ms
//
// Powered by CP Editor (https://cpeditor.org)
#include <bits/stdc++.h>
// Competitive-programming shorthands (not all of them are used in this file).
#define pb emplace_back
#define fst first
#define scd second
#define mems(a, x) memset((a), (x), sizeof(a))
using namespace std;
typedef long long ll;
typedef unsigned long long ull;
typedef double db;
typedef long double ldb;
typedef pair<int, int> pii;
// Capacity bound for the global arrays (vertices, reconstruction-tree nodes, edges, queries, weights).
const int maxn = 1000100;
// Width of the binary-lifting table f[][logn]; only levels 0..20 are actually used.
const int logn = 22;
// n vertices, m edges, q queries.
// ntot: node counter of the Kruskal reconstruction tree (leaves 1..n, merge nodes n+1..ntot).
// head/len: forward-star adjacency list of that tree (see add_edge).
// a[z]: weight of reconstruction node z.
// b[id]: reconstruction node created by original edge id, or 0 if that edge joined two
//        vertices that were already connected (i.e. it is not a Kruskal MST edge).
// c[id]: position of original edge id in G after sorting by weight.
int n, m, q, ntot, head[maxn], len, a[maxn], b[maxn], c[maxn];
// Heavy-light data on the reconstruction tree: parent, subtree size, heavy child, depth.
// ans[i]: answer of the i-th query (stays 0 unless set to 1 in solve()).
int fa[maxn], sz[maxn], son[maxn], dep[maxn], ans[maxn];
// top: head of u's heavy chain; st/ed: DFS entry/exit timestamps (written by dfs(),
// not read elsewhere in this file); tim: DFS clock.
int top[maxn], st[maxn], ed[maxn], tim;
// Visited flag used by dfs2() to skip the already-processed heavy child.
bool vis[maxn];
// Offline query record. In this file solve() stores: x = s, y = t, k = position of the
// query edge in the weight-sorted array G, id = 1-based query index.
struct node {
    int x, y, k, id;
    // Parameter names are suffixed so they no longer shadow the global arrays a/b/c.
    node(int x_ = 0, int y_ = 0, int k_ = 0, int id_ = 0)
        : x(x_), y(y_), k(k_), id(id_) {}
};
// qq[w]: queries deferred to weight group w. solve() pushes a query here only when its
// edge has weight w and the reconstruction-tree LCA weight of (s, t) is also w,
// i.e. d(s, t) == w.
vector<node> qq[maxn];
// Input edge: endpoints u, v, weight d, and its original 1-based input index id
// (kept because G is sorted by weight in solve()).
struct E {
int u, v, d, id;
} G[maxn];
// std::sort comparator for E: orders edges by non-decreasing weight d.
bool cmp(E a, E b) {
    return b.d > a.d;
}
// Forward-star adjacency list of the Kruskal reconstruction tree.
// edges[e].to is the target of slot e; edges[e].next is the previous slot out of the
// same source (0 terminates the list).
struct edge {
    int to, next;
} edges[maxn];
// Prepend the directed edge u -> v to u's list. Slots are numbered from 1 via len.
inline void add_edge(int u, int v) {
    const int slot = ++len;
    edge &e = edges[slot];
    e.next = head[u];
    e.to = v;
    head[u] = slot;
}
namespace DSU {
// Disjoint-set forest: fa[v] is v's parent, and roots satisfy fa[v] == v.
int fa[maxn];

// Make ids 1..n singleton sets.
void init(int n) {
    for (int i = 1; i <= n; ++i) {
        fa[i] = i;
    }
}

// Return the representative of x, pointing every node on the path directly at it.
// Iterative on purpose: merge() links roots without union by size, so a parent chain
// of length O(n) can form before it is first compressed. A recursive find would then
// recurse O(n) deep and can overflow the call stack. The resulting fa[] state is the
// same as with the recursive path-compressing version.
int find(int x) {
    int root = x;
    while (fa[root] != root) {
        root = fa[root];
    }
    while (fa[x] != root) {
        const int next = fa[x];
        fa[x] = root;
        x = next;
    }
    return root;
}

// Union the sets containing x and y. x's root is attached under y's root.
inline void merge(int x, int y) {
    x = find(x);
    y = find(y);
    if (x != y) {
        fa[x] = y;
    }
}
}
int dfs(int u, int f, int d) {
st[u] = ++tim;
fa[u] = f;
sz[u] = 1;
dep[u] = d;
int maxson = -1;
for (int i = head[u]; i; i = edges[i].next) {
int v = edges[i].to;
if (v == f) {
continue;
}
sz[u] += dfs(v, u, d + 1);
if (sz[v] > maxson) {
son[u] = v;
maxson = sz[v];
}
}
ed[u] = tim;
return sz[u];
}
// Heavy-light preprocessing, pass 2: top[u] = head of u's chain. The heavy child extends
// the current chain tp, and every other (not yet visited) child starts its own chain.
void dfs2(int u, int tp) {
    top[u] = tp;
    vis[u] = 1;
    if (son[u] == 0) {
        return;  // no children were recorded for u
    }
    dfs2(son[u], tp);
    for (int e = head[u]; e; e = edges[e].next) {
        const int v = edges[e].to;
        if (vis[v]) {
            continue;  // the heavy child handled above
        }
        dfs2(v, v);
    }
}
// LCA of x and y in the reconstruction tree via heavy-light chains. Repeatedly lift the
// endpoint whose chain head is deeper until both share a chain; the shallower one is the LCA.
inline int qlca(int x, int y) {
    for (; top[x] != top[y]; x = fa[top[x]]) {
        if (dep[top[y]] > dep[top[x]]) {
            swap(x, y);
        }
    }
    return dep[x] <= dep[y] ? x : y;
}
// Per-weight-group scratch state; solve() resets it for compressed ids 1..tot before each group.
// lsh[1..tot]: sorted distinct DSU roots touched by the group's edges (coordinate compression).
// stk/ttop: Tarjan vertex stack.
int lsh[maxn], tot, stk[maxn], ttop;
// dfn/low: Tarjan discovery time and low-link; times: DFS clock;
// col[v]: 2-edge-connected component id of compressed vertex v; cnt: number of components.
int dfn[maxn], low[maxn], times, col[maxn], cnt;
// g[v]: the group's multigraph on compressed vertices, with entries (neighbor, edge index in G).
vector<pii> g[maxn];
// ng[c]: adjacency of the bridge forest, whose nodes are component ids.
vector<int> ng[maxn];
void tarjan(int u, int fa) {
dfn[u] = low[u] = ++times;
stk[++ttop] = u;
for (pii p : g[u]) {
int v = p.fst, id = p.scd;
if (id == fa) {
continue;
}
if (!dfn[v]) {
tarjan(v, id);
low[u] = min(low[u], low[v]);
} else {
low[u] = min(low[u], dfn[v]);
}
}
if (dfn[u] == low[u]) {
++cnt;
while (1) {
int x = stk[ttop--];
col[x] = cnt;
if (x == u) {
break;
}
}
}
}
// Binary-lifting tables on the bridge forest: f[u][i] = 2^i-th ancestor of u
// (0 once past the root), de[u] = depth of u.
int f[maxn][logn], de[maxn];
// Root the subtree of u (parent `parent`, 0 for a tree root) and fill f/de for it.
// This relies on de[0] staying 0, which gives roots depth 1. Ancestors are filled before
// descendants, so f[u][lvl] can read its ancestors' rows.
void dfs3(int u, int parent) {
    f[u][0] = parent;
    de[u] = de[parent] + 1;
    for (int lvl = 1; lvl <= 20; ++lvl) {
        const int half = f[u][lvl - 1];
        f[u][lvl] = f[half][lvl - 1];
    }
    for (int v : ng[u]) {
        if (v != parent) {
            dfs3(v, u);
        }
    }
}
// Return the k-th ancestor of x in the bridge forest (0 if it does not exist) by
// decomposing k into powers of two. Only bits 0..20 of k are examined.
inline int jump(int x, int k) {
    for (int bit = 0; bit <= 20; ++bit) {
        if ((k >> bit) & 1) {
            x = f[x][bit];
        }
    }
    return x;
}
// LCA of x and y in the bridge forest using the binary-lifting table f and depths de
// (x and y are expected to lie in the same tree).
inline int glca(int x, int y) {
    if (de[y] > de[x]) {
        swap(x, y);
    }
    // Lift the deeper node up to y's depth.
    for (int lvl = 20; lvl >= 0; --lvl) {
        const int up = f[x][lvl];
        if (de[up] >= de[y]) {
            x = up;
        }
    }
    if (x == y) {
        return x;
    }
    // Lift both nodes to just below their LCA.
    for (int lvl = 20; lvl >= 0; --lvl) {
        if (f[x][lvl] == f[y][lvl]) {
            continue;
        }
        x = f[x][lvl];
        y = f[y][lvl];
    }
    return f[x][0];
}
// Reads the graph and the queries, answers every query offline, and prints one line
// (0 or 1) per query.
// Phase 1: sort edges by weight and build the Kruskal reconstruction tree. A query can
//          only have answer 1 if its edge is a tree edge whose weight equals the weight
//          of the LCA of s and t (that LCA weight is d(s, t)).
// Phase 2: sweep the weight groups in increasing order. Components formed by lighter
//          edges are DSU roots. In the group's multigraph over those roots, the answer
//          is 1 iff the query edge is a bridge lying on the s-t path of the bridge forest.
void solve() {
scanf("%d%d", &n, &m);
for (int i = 1; i <= m; ++i) {
scanf("%d%d%d", &G[i].u, &G[i].v, &G[i].d);
G[i].id = i;
}
ntot = n;
sort(G + 1, G + m + 1, cmp);
// Room for the n leaves plus the merge nodes created below (at most one per edge).
DSU::init(n + m);
for (int i = 1; i <= m; ++i) {
int u = G[i].u, v = G[i].v, d = G[i].d, id = G[i].id;
int x = DSU::find(u), y = DSU::find(v);
c[id] = i;
if (x != y) {
// The edge joins two components: create a reconstruction node of weight d above both roots.
int z = ++ntot;
b[id] = ntot;
DSU::fa[x] = DSU::fa[y] = z;
add_edge(z, x);
add_edge(z, y);
a[z] = d;
}
}
// The last created node is taken as the root (assumes the input graph is connected).
int rt = ntot;
dfs(rt, 0, 1);
dfs2(rt, rt);
scanf("%d", &q);
for (int i = 1, x, y, k; i <= q; ++i) {
scanf("%d%d%d", &k, &x, &y);
// Raising a non-MST edge never changes d(s, t), so ans[i] stays 0.
if (!b[k]) {
continue;
}
int lca = qlca(x, y);
// Only an edge whose weight equals d(s, t) can matter; defer the query to that weight group.
if (G[c[k]].d == a[lca]) {
qq[a[lca]].pb(x, y, c[k], i);
}
}
DSU::init(n);
// Iterate over maximal runs G[i..j] of equal weight.
for (int i = 1, j = 1; i <= m; i = (++j)) {
while (j < m && G[j + 1].d == G[i].d) {
++j;
}
// Collect and compress the DSU roots (components of strictly lighter edges) touched by this group.
tot = 0;
for (int k = i; k <= j; ++k) {
int u = G[k].u, v = G[k].v;
lsh[++tot] = DSU::find(u);
lsh[++tot] = DSU::find(v);
}
sort(lsh + 1, lsh + tot + 1);
tot = unique(lsh + 1, lsh + tot + 1) - lsh - 1;
// Clear the scratch state for ids 1..tot. Component ids never exceed tot, and the
// higher binary-lifting levels f[u][1..20] are recomputed by dfs3.
ttop = times = cnt = 0;
for (int u = 1; u <= tot; ++u) {
dfn[u] = low[u] = col[u] = 0;
vector<pii>().swap(g[u]);
vector<int>().swap(ng[u]);
de[u] = f[u][0] = 0;
}
// Build the group's multigraph on compressed roots. Edges inside one root are skipped
// (they close a cycle of lighter edges).
for (int k = i; k <= j; ++k) {
int u = G[k].u, v = G[k].v;
int x = lower_bound(lsh + 1, lsh + tot + 1, DSU::find(u)) - lsh;
int y = lower_bound(lsh + 1, lsh + tot + 1, DSU::find(v)) - lsh;
if (x != y) {
g[x].pb(y, k);
g[y].pb(x, k);
}
}
// Label 2-edge-connected components.
for (int u = 1; u <= tot; ++u) {
if (!dfn[u]) {
tarjan(u, -1);
}
}
// Build the bridge forest: one adjacency entry per distinct ordered component pair.
// (This local set shadows the global st[] array.)
set<pii> st;
for (int u = 1; u <= tot; ++u) {
for (pii p : g[u]) {
int v = p.fst;
if (col[u] != col[v]) {
int x = col[u], y = col[v];
if (st.find(make_pair(x, y)) == st.end()) {
st.insert(make_pair(x, y));
ng[x].pb(y);
}
}
}
}
// Root every tree of the bridge forest; de[u] == 0 marks a component not yet visited.
for (int u = 1; u <= cnt; ++u) {
if (!de[u]) {
dfs3(u, 0);
}
}
// Answer the queries deferred to this weight.
for (node p : qq[G[i].d]) {
int u = p.x, v = p.y, k = p.k, id = p.id;
int x = lower_bound(lsh + 1, lsh + tot + 1, DSU::find(u)) - lsh;
int y = lower_bound(lsh + 1, lsh + tot + 1, DSU::find(v)) - lsh;
int xx = lower_bound(lsh + 1, lsh + tot + 1, DSU::find(G[k].u)) - lsh;
int yy = lower_bound(lsh + 1, lsh + tot + 1, DSU::find(G[k].v)) - lsh;
// Can be 1 only if s and t are in different components and the query edge is a bridge.
if (col[x] != col[y] && col[xx] != col[yy]) {
x = col[x];
y = col[y];
xx = col[xx];
yy = col[yy];
// Make xx the child endpoint of the bridge and yy its parent.
if (de[xx] < de[yy]) {
swap(xx, yy);
}
int lca = glca(x, y);
// The bridge lies on the x-y path iff xx is an ancestor-or-self of x (or of y) and
// xx sits strictly below lca, i.e. its parent yy is a descendant-or-self of lca.
if (de[xx] <= de[x] && de[yy] >= de[lca]) {
if (jump(x, de[x] - de[xx]) == xx && jump(yy, de[yy] - de[lca]) == lca) {
ans[id] = 1;
}
}
if (de[xx] <= de[y] && de[yy] >= de[lca]) {
if (jump(y, de[y] - de[xx]) == xx && jump(yy, de[yy] - de[lca]) == lca) {
ans[id] = 1;
}
}
}
}
// Merge the whole group so later groups see these components as single vertices.
for (int k = i; k <= j; ++k) {
int u = G[k].u, v = G[k].v;
DSU::merge(u, v);
}
}
for (int i = 1; i <= q; ++i) {
printf("%d\n", ans[i]);
}
}
// Entry point: runs solve() once per test case. The input is single-test; the
// multi-test read of T is kept commented out.
int main() {
    int T = 1;
    // scanf("%d", &T);
    for (int t = 0; t < T; ++t) {
        solve();
    }
    return 0;
}

浙公网安备 33010602011771号