如何找到几何中位数

33
问题是:
给定N个二维坐标系中的点,每个点有x和y坐标,找到一个点P(在给定的N个点中)使得其他(N-1)个点到P的距离之和最小。
这个点通常被称为几何中位数。除了朴素的 O(N^2) 方法外,是否存在有效的算法来解决这个问题?

3
@JeroenVuurens说:“我认为这样行不通——我认为对于 [(-L,0), (L,0)]*25 + [(0,1), (0,2), (0,3)] ,其中 L 很大,你会选择 (0,1) 而不是 (0,2)。” - Nabb
3
@MicSim,实际上OP不是寻找几何中位数,尽管问题相似。 - Qnan
2
@Qnan:我没有看到原帖问题和几何中位数之间的区别,您能详细解释一下吗? - Fred Foo
4
实际的几何中位数不一定属于所讨论的点集。 - Qnan
3
你说得对,这里指的是中心点,而不是中位数。 - Fred Foo
显示剩余10条评论
7个回答

25

我曾经用模拟退火算法为本地在线评测机解决了一个类似的问题。这也是官方的解决方案,程序得到了AC。

唯一的区别是我需要找到的点不必是给定的N个点中的一部分。

以下是我的C++代码,而N可以达到50000。该程序在2ghz pentium 4上执行时间为0.1s

// header files for IO functions and math
#include <cstdio>
#include <cmath>

// the maximul value n can take
const int maxn = 50001;

// given a point (x, y) on a grid, we can find its left/right/up/down neighbors
// by using these constants: (x + dx[0], y + dy[0]) = upper neighbor etc.
const int dx[] = {-1, 0, 1, 0};
const int dy[] = {0, 1, 0, -1};

// controls the precision - this should give you an answer accurate to 3 decimals
const double eps = 0.001;

// input and output files
FILE *in = fopen("adapost2.in","r"), *out = fopen("adapost2.out","w");

// stores a point in 2d space
struct punct
{
    double x, y;
};

// how many points are in the input file
int n;

// stores the points in the input file
punct a[maxn];

// stores the answer to the question
double x, y;

// finds the sum of (euclidean) distances from each input point to (x, y)
double dist(double x, double y)
{
    double ret = 0;

    for ( int i = 1; i <= n; ++i )
    {
        double dx = a[i].x - x;
        double dy = a[i].y - y;

        ret += sqrt(dx*dx + dy*dy); // classical distance formula
    }

    return ret;
}

// reads the input
void read()
{
    fscanf(in, "%d", &n); // read n from the first 

    // read n points next, one on each line
    for ( int i = 1; i <= n; ++i )
        fscanf(in, "%lf %lf", &a[i].x, &a[i].y), // reads a point
        x += a[i].x,
        y += a[i].y; // we add the x and y at first, because we will start by approximating the answer as the center of gravity

    // divide by the number of points (n) to get the center of gravity
    x /= n; 
    y /= n;
}

// implements the solving algorithm
void go()
{
    // start by finding the sum of distances to the center of gravity
    double d = dist(x, y);

    // our step value, chosen by experimentation
    double step = 100.0;

    // done is used to keep track of updates: if none of the neighbors of the current
    // point that are *step* steps away improve the solution, then *step* is too big
    // and we need to look closer to the current point, so we must half *step*.
    int done = 0;

    // while we still need a more precise answer
    while ( step > eps )
    {
        done = 0;
        for ( int i = 0; i < 4; ++i )
        {
            // check the neighbors in all 4 directions.
            double nx = (double)x + step*dx[i];
            double ny = (double)y + step*dy[i];

            // find the sum of distances to each neighbor
            double t = dist(nx, ny);

            // if a neighbor offers a better sum of distances
            if ( t < d )
            {
                update the current minimum
                d = t;
                x = nx;
                y = ny;

                // an improvement has been made, so
                // don't half step in the next iteration, because we might need
                // to jump the same amount again
                done = 1;
                break;
            }
        }

        // half the step size, because no update has been made, so we might have
        // jumped too much, and now we need to head back some.
        if ( !done )
            step /= 2;
    }
}

int main()
{
    read();
    go();

    // print the answer with 4 decimal points
    fprintf(out, "%.4lf %.4lf\n", x, y);

    return 0;
}

那么,我认为从您的列表中选择与此算法返回的(x, y)最接近的那个是正确的选择。
该算法利用了维基百科关于几何中位数的段落所说的内容:
“然而,可以使用迭代过程计算出几何中位数的近似值,在每一步中产生更精确的近似值。这种类型的程序可以从样本点到距离之和是一个凸函数这一事实中得出,因为到每个样本点的距离是凸的,并且凸函数的总和仍然是凸的。因此,在每一步减少距离总和的过程中不能陷入局部最优的程序不会被困住。”
其中第一个段落解释了为什么这个算法有效:因为我们正在尝试优化的函数没有任何局部最小值,所以您可以通过迭代改进来贪心地找到最小值。
把它想象成一种二分查找。首先,您近似结果。当读取输入时,一个好的近似值将是重心,然后查看相邻点是否给您提供更好的解决方案。在这种情况下,如果一个点距离当前点step,则被视为相邻点。如果更好,则可以放弃当前点,因为正如我所说,由于您正在尝试最小化的函数的性质,这不会将您困在局部最小值中。
之后,将步长减半,就像二分查找一样,并继续进行,直到您拥有足够好的近似值(由eps常量控制)。
因此,该算法的复杂度取决于您希望结果的精度。

请解释一下,对我来说这比希腊文还难懂! - SexyBeast
它正在寻找实际的几何中位数,我相信。嗯,大约是这样。 - Qnan
1
请尽可能详细地注释代码!我真的需要完全理解它。我知道这是要求太多了,但请帮帮我,作为回报,明天我会奖励你丰厚的赏金! - SexyBeast
2
@Cupidvogel 我已经对代码进行了注释,如果还有其他需要帮助的地方,请告诉我。请注意,我仍然不知道这是否回答了您最初的问题。我不知道从您的集合中选择与我的算法返回的点最接近的点会有多大作用。 - IVlad
5
这个答案不是模拟退火算法,而是牛顿搜索中的一阶泰勒级数项;有时也称为一阶搜索;起始位置是平均值。 - koan
显示剩余12条评论

10
似乎使用欧几里得距离时,问题很难在优于O(n^2)的时间内解决。但是,最小化到其他点的曼哈顿距离之和或最小化到其他点的欧几里得距离平方和的点可以在O(n log n)的时间内找到(假设两个数字相乘为O(1))。让我不要脸地从最近的post中复制/粘贴我的曼哈顿距离解决方案:
Create a sorted array of x-coordinates and for each element in the array compute the "horizontal" cost of choosing that coordinate. The horizontal cost of an element is the sum of distances to all the points projected onto the X-axis. This can be computed in linear time by scanning the array twice (once from left to right and once in the reverse direction). Similarly create a sorted array of y-coordinates and for each element in the array compute the "vertical" cost of choosing that coordinate.
Now for each point in the original array, we can compute the total cost to all other points in O(1) time by adding the horizontal and vertical costs. So we can compute the optimal point in O(n). Thus the total running time is O(n log n).
We can follow a similar approach for computing the point that minimizes the sum of squares of Euclidean distances to other points. Let the sorted x-coordinates be: x1, x2, x3, ..., xn. We scan this list from left to right and for each point xi we compute:

li = 到 xi 左侧所有元素的距离之和 = (xi-x1) + (xi-x2) + .... + (xi-xi-1) , 而

sli = 到 xi 左侧所有元素的距离平方和 = (xi-x1)^2 + (xi-x2)^2 + .... + (xi-xi-1)^2

注意,给定 li 和 sli,我们可以按照如下方式在 O(1) 时间内计算出 li+1 和 sli+1:

令 d = xi+1-xi。则:

li+1 = li + id and sli+1 = sli + id^2 + 2*i*d

因此,我们可以通过从左到右扫描来线性计算所有的li和sli。类似地,对于每个元素,我们可以在线性时间内计算出ri:到所有右侧元素的距离之和和sri:到所有右侧元素的距离平方和。将每个i的sri和sli相加,以线性时间计算所有元素的水平距离平方和。同样地,计算所有元素的垂直距离平方和。

然后,我们可以扫描原始点数组,并像以前一样找到最小化垂直和水平距离平方和的点。


4
我想起教授们讲学生回答以前的考试题,希望得到分数的故事。知道这可以用其他指标高效地完成是好的,但这并没有回答问题。 - Nabb
当然,它并没有解决原始问题,我在一开始就提到了。我认为 OP 可能会感兴趣,因为这个解决方案可以很好地近似原始问题(因为指标相似),并且可以非常高效地找到。 - krjampani
@srbh.kmr 不一定。考虑实数线上的5个一维点:0, 0, a, a+b, a+b+c。使得到其他点距离之和最小的点是在a处。但是,如果2c > b + 4a,则使得到其他点距离平方和最小的点是在a+b处。 - krjampani
嗨@krjampani,关于“曼哈顿距离之和”,我不明白为什么需要“排序”,为什么不能只使用leftRightMemo[]和rightLeftMemo[]来计算曼哈顿距离之和,然后对于点P(x,y),x坐标的总和= leftRightMemo[i] + rightLeftMemo[i] - x? - Zhaonan
leftRightMemo和rightLeftMemo是什么?你如何计算它们? - krjampani
显示剩余2条评论

6

如前所述,使用的算法类型取决于你测量距离的方式。由于你的问题没有指定这种测量方式,因此以下是曼哈顿距离欧几里得距离的平方的C语言实现。对于二维点,请使用dim = 2。复杂度为O(n log n)

曼哈顿距离

double * geometric_median_with_manhattan(double **points, int N, int dim) {
    for (d = 0; d < dim; d++) {
        qsort(points, N, sizeof(double *), compare);
        double S = 0;
        for (int i = 0; i < N; i++) {
            double v = points[i][d];
            points[i][dim] += (2 * i - N) * v - 2 * S;
            S += v;
        }
    }
    return min(points, N, dim);
}

简要说明:我们可以按维度(在您的情况下为2)对距离进行总结。假设我们有N个点,一个维度中的值为v_0,..,v_(N-1),并且T = v_0 +..+ v_(N-1)。然后对于每个值v_i,我们都有S_i = v_0 .. v_(i-1)。现在,我们可以通过将左侧的这些值相加来表示该值的曼哈顿距离:i * v_i-S_i和右侧:T-S_i-(N-i)*v_i,其结果为(2 * i - N) * v_i - 2 * S_i + T。将T添加到所有元素中不会改变顺序,因此我们将其省略。并且可以即时计算S_i。
下面是将其转化为实际C程序的其余代码:
#include <stdio.h>
#include <stdlib.h>

int d = 0;
int compare(const void *a, const void *b) {
    return (*(double **)a)[d] - (*(double **)b)[d];
}

double * min(double **points, int N, int dim) {
    double *min = points[0];
    for (int i = 0; i < N; i++) {
        if (min[dim] > points[i][dim]) {
            min = points[i];
        }
    }
    return min;
}

int main(int argc, const char * argv[])
{
    // example 2D coordinates with an additional 0 value
    double a[][3] = {{1.0, 1.0, 0.0}, {3.0, 1.0, 0.0}, {3.0, 2.0, 0.0}, {0.0, 5.0, 0.0}};
    double *b[] = {a[0], a[1], a[2], a[3]};
    double *min = geometric_median_with_manhattan(b, 4, 2);
    printf("geometric median at {%.1f, %.1f}\n", min[0], min[1]);
    return 0;
}

平方欧几里得距离

double * geometric_median_with_square(double **points, int N, int dim) {
    for (d = 0; d < dim; d++) {
        qsort(points, N, sizeof(double *), compare);
        double T = 0;
        for (int i = 0; i < N; i++) {
            T += points[i][d];
        }
        for (int i = 0; i < N; i++) {
            double v = points[i][d];
            points[i][dim] += v * (N * v - 2 * T);
        }
    }
    return min(points, N, dim);
}

简短解释:与之前的方法基本相同,但是推导略微复杂。假设TT = v_0^2 + .. + v_(N-1)^2,我们得到TT + N * v_i^2 - 2 * v_i^2 * T。再次将TT添加到所有内容中,以便省略它。如需更多说明,请提出要求。


2
第一步:按照x维度对点集进行排序(nlogn)。 第二步:计算每个点与其左侧所有点之间的x距离。
xLDist[0] := 0
for i := 1 to n - 1
       xLDist[i] := xLDist[i-1] + ( ( p[i].x - p[i-1].x ) * i)

第三步:计算每个点与其右侧所有点之间的x距离。
xRDist[n - 1] := 0
for i := n - 2 to 0
       xRDist[i] := xRDist[i+1] + ( ( p[i+1].x - p[i].x ) * i)  

第四步:将它们相加,你就可以得到每个点到其他N-1个点的总x距离。
for i := 0 to n - 1
       p[i].xDist = xLDist[i] + xRDist[i]

重复步骤1、2、3、4,使用y维度来获得p[i].yDist。 最小的xDist和yDist之和的点是答案。 总复杂度为O(nlogn)。 在C++中的答案 进一步解释: 思路是重复利用先前点的已计算总距离。 假设我们有3个排序的点ABCD,我们可以看到D到它之前的其他点的左侧总距离为: AD + BD + CD = (AC + CD) + (BC + CD) + CD = AC + BC + 3CD 其中(AC + BC)是C到它之前的其他点的左侧总距离,我们利用了这一点,只需要计算ldist(C) + 3CD。

2
我实现了Weiszfeld方法(我知道这不是你要找的,但它可能有助于近似你的点),复杂度为O(N*M/k),其中N是点的数量,M是点的维数(在你的情况下为2),k是所需误差。

https://github.com/j05u3/weiszfeld-implementation


0

您可以将问题解决为凸规划(目标函数并不总是凸的)。可以使用迭代方法(如L-BFGS)来解决凸规划。每次迭代的成本为O(N),通常所需的迭代次数不多。减少所需迭代次数的一个重要点是我们知道最优答案是输入中的一个点。因此,当优化结果接近于输入点之一时,可以停止优化。


请详细说明一下,我刚才听不懂你说的话!提供伪代码以及一些解释和注释会很好。 - SexyBeast

-1
我们需要找到的答案是几何中位数。
C++ 代码:
#include <bits/stdc++.h>
using namespace std;
int main()
{
    int n;
    cin >> n;

    int a[n],b[n];
    for(int i=0;i<n;i++) 
        cin >> a[i] >> b[i];
    int res = 0;
    sort(a,a+n);
    sort(b,b+n);

    int m1 = a[n/2];
    int m2 = b[n/2];

    for(int i=0;i<n;i++) 
        res += abs(m1 - a[i]);
    for(int i=0;i<n;i++) 
        res += abs(m2 - b[i]);

    cout << res << '\n';
}

“除了朴素的O(N^2)算法外,有没有更有效的算法来解决这个问题?” - phoenixstudio
这是O(nlgn)的复杂度。 - Amruth
我修复了你的代码,至少现在它可以编译了。不过最好还是加上一些解释。 - General Grievance

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接