用 $O(n\log n)$ 求解平面最近点对

最近点对是一个很典型的几何分治问题。题目本身很直接:给定平面上的 \(n\) 个点,找出欧氏距离最小的一对点。

暴力做法不难写,两层循环枚举所有点对,复杂度是 \(O(n^2)\)。当 \(n\) 只有几千时还可以接受,但如果 \(n\) 到了 \(10^5\) 量级,就必须换思路。

这个问题的经典解法是分治,时间复杂度可以做到 \(O(n\log n)\)。它有一个容易被写错的地方:递归时不能在每一层都重新按 \(y\) 排序,否则复杂度会退化到 \(O(n\log^2 n)\)

一维的情况

如果点都在一条直线上,最近点对很简单:按坐标排序,答案只可能出现在相邻两个点之间。排序花费 \(O(n\log n)\),扫描一遍花费 \(O(n)\)

二维平面麻烦在于,按 \(x\) 排序后,最近点对不一定是相邻点。例如下面这些点:

(0, 100)
(1, 0)
(2, 100)
(3, 0.1)

\(x\) 排序相邻并不能直接解决问题。二维里既要考虑横向距离,也要考虑纵向距离。不过“排序后只看局部”的思想仍然有用,只是需要更精细地定义这个“局部”。

分治法

先把所有点按 \(x\) 坐标排序。对于当前区间 \([l,r]\),取中点 \(mid\),把点集分成左右两半:

[l ... mid]     [mid + 1 ... r]

递归求左半部分最近距离 \(d_l\),右半部分最近距离 \(d_r\),当前答案先设为 \(d=\min(d_l,d_r)\)

问题在于,真正的最近点对可能一左一右跨过分界线。如果不处理这个情况,分治结果是不完整的。

image

设分界线的横坐标为 \(x_m\)。对于某个点 \(p_i\),如果它到分界线的横向距离已经不小于 \(d\),也就是 \(|p_i.x-x_m|\ge d\),那么它不可能和另一侧的点组成比 \(d\) 更近的点对。因为欧氏距离至少不小于横向距离。

所以跨区间检查只需要看一条宽度为 \(2d\) 的竖条区域:

\[|p_i.x - x_m| < d \]

把这些点取出来,记为 strip

为什么 strip 里不用两两枚举

如果直接对 strip 两两枚举,最坏情况下仍然可能是 \(O(n^2)\)。关键优化是:把 strip\(y\) 坐标排序。对于其中每个点,只检查它后面 \(y\) 差小于 \(d\) 的点。

代码通常写成这样:

for (int i = 0; i < strip.size(); i++) {
    for (int j = i + 1; j < strip.size(); j++) {
        if ((strip[j].y - strip[i].y) * (strip[j].y - strip[i].y) >= best) break;
        best = min(best, dist2(strip[i], strip[j]));
    }
}

image

这里 best 存的是距离平方,所以判断也用平方。

这个循环看起来像双重循环,但总复杂度仍然是线性的。原因来自一个几何约束:在递归已经得到 \(d\) 的前提下,左半部分内部任意两点距离不小于 \(d\),右半部分内部任意两点距离也不小于 \(d\)。在一个宽度 \(2d\)、高度 \(d\) 的小矩形区域里,能放下的“互相距离至少为 \(d\)”的点数是常数级的。

更常见的说法是:对 strip 中按 \(y\) 排序后的每个点,只需要向后检查常数个点,通常不超过 \(7\) 个。实际实现里不必手动写死 7,用 y 差剪枝即可,几何性质保证了它不会让总复杂度失控。

image

避免 \(O(n\log^2 n)\) 的细节

一个朴素写法是:每次递归合并时,把 strip 重新按 \(y\) 排序。这样每一层递归都会产生一次排序,复杂度递推接近:

\[T(n)=2T(n/2)+O(n\log n) \]

解出来是 \(O(n\log^2 n)\)

想做到 \(O(n\log n)\),需要让每个递归区间在返回时已经按 \(y\) 坐标排好序。这样父区间只需要用线性时间把左右两个按 \(y\) 排序的子区间归并起来。

递推式就变成:

\[T(n)=2T(n/2)+O(n) \]

于是整体复杂度为 \(O(n\log n)\)

这个技巧很像归并排序:递归前区间按 \(x\) 排序,用来切分左右半边;递归返回后区间按 \(y\) 排序,用来在线性时间处理跨分界线的候选点。

image

实现

下面实现中用距离平方比较,避免在循环里频繁调用 sqrt。如果输入坐标是整数,用 long double 存距离平方比较稳妥;如果坐标范围明确在 \(10^9\) 内,long long 也足够存平方和。

#include <bits/stdc++.h>
using namespace std;

struct Point {
    long double x, y;
};

bool cmpX(const Point& a, const Point& b) {
    if (a.x != b.x) return a.x < b.x;
    return a.y < b.y;
}

bool cmpY(const Point& a, const Point& b) {
    if (a.y != b.y) return a.y < b.y;
    return a.x < b.x;
}

long double dist2(const Point& a, const Point& b) {
    long double dx = a.x - b.x;
    long double dy = a.y - b.y;
    return dx * dx + dy * dy;
}

long double closestPair(vector<Point>& p, int l, int r) {
    if (r - l <= 3) {
        long double best = numeric_limits<long double>::infinity();

        for (int i = l; i <= r; i++) {
            for (int j = i + 1; j <= r; j++) {
                best = min(best, dist2(p[i], p[j]));
            }
        }

        sort(p.begin() + l, p.begin() + r + 1, cmpY);
        return best;
    }

    int mid = (l + r) / 2;
    long double midX = p[mid].x;

    long double leftBest = closestPair(p, l, mid);
    long double rightBest = closestPair(p, mid + 1, r);
    long double best = min(leftBest, rightBest);

    inplace_merge(
        p.begin() + l,
        p.begin() + mid + 1,
        p.begin() + r + 1,
        cmpY
    );

    vector<Point> strip;
    strip.reserve(r - l + 1);

    for (int i = l; i <= r; i++) {
        long double dx = p[i].x - midX;
        if (dx * dx < best) {
            strip.push_back(p[i]);
        }
    }

    for (int i = 0; i < (int)strip.size(); i++) {
        for (int j = i + 1; j < (int)strip.size(); j++) {
            long double dy = strip[j].y - strip[i].y;
            if (dy * dy >= best) break;

            best = min(best, dist2(strip[i], strip[j]));
        }
    }

    return best;
}

int main() {
    ios::sync_with_stdio(false);
    cin.tie(nullptr);

    int n;
    cin >> n;

    vector<Point> p(n);
    for (int i = 0; i < n; i++) {
        cin >> p[i].x >> p[i].y;
    }

    if (n < 2) {
        cout << "0.000000\n";
        return 0;
    }

    sort(p.begin(), p.end(), cmpX);

    long double answer2 = closestPair(p, 0, n - 1);
    cout << fixed << setprecision(6) << sqrt((double)answer2) << '\n';

    return 0;
}

midX 必须在递归调用前取出来。因为递归返回后,当前区间会被调整为按 \(y\) 排序,此时 p[mid] 已经不再表示按 \(x\) 排序时的中点。

基础区间里不只是暴力枚举,还要把 [l,r] 排成按 \(y\) 有序。这样递归的上层才能用 inplace_merge 线性归并。这个行为是整个 \(O(n\log n)\) 复杂度成立的前提。

strip 不需要再排序,因为当前区间 [l,r] 已经通过 inplace_merge\(y\) 排好了。按顺序扫描 [l,r] 收集出来的 strip 自然也是按 \(y\) 排序的。

如果存在两个完全相同的点,答案会变成 \(0\)。代码可以自然处理这种情况。一旦 best 变成 \(0\),后续判断 dx * dx < best 不会再加入候选点,但答案已经是最小可能值,不影响正确性。

复杂度分析

初始按 \(x\) 排序需要 \(O(n\log n)\)

递归过程中,每一层做三件事:归并两个按 \(y\) 排序的子区间、收集 strip、检查跨分界线候选点。前两者显然是线性的,第三者由于每个点只会检查常数个后继点,也是线性的。因此每层总工作量是 \(O(n)\)

递归深度是 \(O(\log n)\),所以总复杂度为 \(O(n\log n)\)。额外空间主要来自递归栈和 strip,如果不做更细的复用,空间复杂度可以看作 \(O(n)\)

posted @ 2026-05-24 21:17  Ofnoname  阅读(19)  评论(2)    收藏  举报