最近点对实现 Python

8

我正在尝试使用分治算法在Python中实现最近点对问题,除了在某些输入情况下出现错误答案外,一切似乎都正常。我的代码如下:

def closestSplitPair(Px,Py,d):
    X = Px[len(Px)-1][0]
    Sy = [item for item in Py if item[0]>=X-d and item[0]<=X+d]
    best,p3,q3 = d,None,None
    for i in xrange(0,len(Sy)-2):
        for j in xrange(1,min(7,len(Sy)-1-i)):
            if dist(Sy[i],Sy[i+j]) < best:
                best = (Sy[i],Sy[i+j])
                p3,q3 = Sy[i],Sy[i+j]
    return (p3,q3,best)

我通过以下递归函数调用了上面的函数:

def closestPair(Px,Py): """Px and Py are input arrays sorted according to
their x and y coordinates respectively"""
    if len(Px) <= 3:
        return min_dist(Px)
    else:
        mid = len(Px)/2
        Qx = Px[:mid] ### x-sorted left side of P
        Qy = Py[:mid] ### y-sorted left side of P
        Rx = Px[mid:] ### x-sorted right side of P
        Ry = Py[mid:] ### y-sorted right side of P
        (p1,q1,d1) = closestPair(Qx,Qy)
        (p2,q2,d2) = closestPair(Rx,Ry)
        d = min(d1,d2)
        (p3,q3,d3) = closestSplitPair(Px,Py,d)
        return min((p1,q1,d1),(p2,q2,d2),(p3,q3,d3),key=lambda tup: tup[2])

其中min_dist(P)是最近点对算法的暴力实现,用于包含3个或更少元素的列表P,并返回一个包含最接近点对及其距离的元组。

如果我的输入为P = [(0,0),(7,6),(2,20),(12,5),(16,16),(5,8),(19,7),(14,22),(8,19),(7,29),(10,11),(1,13)],那么输出应为((5,8),(7,6),2.8284271),这是正确的输出。但当我的输入为P = [(94, 5), (96, -79), (20, 73), (8, -50), (78, 2), (100, 63), (-14, -69), (99, -8), (-11, -7), (-78, -46)]时,我得到的输出为((78, 2), (94, 5), 16.278820596099706),而正确的输出应该是((94, 5), (99, -8), 13.92838827718412)


@jonrsharpe 我正在生成 P,代码为 P = [ ( random.uniform(-100, 100), random.uniform(-100, 100) ) for k in range(10000) ],然后与暴力方法输出进行比较,在某些情况下验证失败。 - maverick93
@jonrsharpe 我已经编辑了问题以显示错误。 - maverick93
4个回答

4

你有两个问题,一个是忘记调用dist函数以更新最佳距离。但主要的问题是由于存在多次递归调用,因此可能会使用默认值best,p3,q3 = d,None,None覆盖找到更接近的拆分对时的结果。我将从closest_pair传递的最佳对作为参数传递给closest_split_pair,以避免潜在的值被覆盖。

def closest_split_pair(p_x, p_y, delta, best_pair): # <- a parameter
    ln_x = len(p_x)
    mx_x = p_x[ln_x // 2][0]
    s_y = [x for x in p_y if mx_x - delta <= x[0] <= mx_x + delta]
    best = delta
    for i in range(len(s_y) - 1):
        for j in range(1, min(i + 7, (len(s_y) - i))):
            p, q = s_y[i], s_y[i + j]
            dst = dist(p, q)
            if dst < best:
                best_pair = p, q
                best = dst
    return best_pair

closest_pair 的结尾代码如下:

    p_1, q_1 = closest_pair(srt_q_x, srt_q_y)
    p_2, q_2 = closest_pair(srt_r_x, srt_r_y)
    closest = min(dist(p_1, q_1), dist(p_2, q_2))
    # get min of both and then pass that as an arg to closest_split_pair
    mn = min((p_1, q_1), (p_2, q_2), key=lambda x: dist(x[0], x[1]))
    p_3, q_3 = closest_split_pair(p_x, p_y, closest,mn)
    # either return mn or we have a closer split pair
    return min(mn, (p_3, q_3), key=lambda x: dist(x[0], x[1]))

你还有其他逻辑问题,你的切片逻辑不正确。我对你的代码进行了一些更改,其中brute只是一个简单的暴力双循环:
def closestPair(Px, Py):
    if len(Px) <= 3:
        return brute(Px)

    mid = len(Px) / 2
    # get left and right half of Px 
    q, r = Px[:mid], Px[mid:]
     # sorted versions of q and r by their x and y coordinates 
    Qx, Qy = [x for x in q if Py and  x[0] <= Px[-1][0]], [x for x in q if x[1] <= Py[-1][1]]
    Rx, Ry = [x for x in r if Py and x[0] <= Px[-1][0]], [x for x in r if x[1] <= Py[-1][1]]
    (p1, q1) = closestPair(Qx, Qy)
    (p2, q2) = closestPair(Rx, Ry)
    d = min(dist(p1, p2), dist(p2, q2))
    mn = min((p1, q1), (p2, q2), key=lambda x: dist(x[0], x[1]))
    (p3, q3) = closest_split_pair(Px, Py, d, mn)
    return min(mn, (p3, q3), key=lambda x: dist(x[0], x[1]))

我今天刚刚完成了这个算法,所以肯定还有改进的空间,但是它可以给你正确的答案。


1
我也通过对两个数组q和r进行切片得到了正确的答案,其中q为Qx, Qy = Px[:mid], sorty(Px[:mid]),r为Rx, Ry = Px[mid:], sorty(Px[mid:]),其中sorty()是一个根据y坐标对列表进行排序的函数。感谢您的帮助! - maverick93

2

这里是一个基于堆数据结构的递归分治python实现的最近点问题。它还考虑了负整数。可以使用heappop()从堆中弹出k个节点来返回前k个最近点。

from __future__ import division
from collections import namedtuple
from random import randint
import math as m
import heapq as hq

def get_key(item):
    return(item[0])


def closest_point_problem(points):
    point = []
    heap = []
    pt = namedtuple('pt', 'x y')
    for i in range(len(points)):
        point.append(pt(points[i][0], points[i][1]))
    point = sorted(point, key=get_key)
    visited_index = []
    find_min(0, len(point) - 1, point, heap, visited_index)
    print(hq.heappop(heap))

def find_min(start, end, point, heap, visited_index):
    if len(point[start:end + 1]) & 1:
        mid = start + ((len(point[start:end + 1]) + 1) >> 1)
    else:
        mid = start + (len(point[start:end + 1]) >> 1)
        if start in visited_index:
            start = start + 1
        if end in visited_index:
            end = end - 1
    if len(point[start:end + 1]) > 3:
        if start < mid - 1:
            distance1 = m.sqrt((point[start].x - point[start + 1].x) ** 2 + (point[start].y - point[start + 1].y) ** 2)
            distance2 = m.sqrt((point[mid].x - point[mid - 1].x) ** 2 + (point[mid].y - point[mid - 1].y) ** 2)
            if distance1 < distance2:
                hq.heappush(heap, (distance1, ((point[start].x, point[start].y), (point[start + 1].x, point[start + 1].y))))
            else:
                hq.heappush(heap, (distance2, ((point[mid].x, point[mid].y), (point[mid - 1].x, point[mid - 1].y))))
            visited_index.append(start)
            visited_index.append(start + 1)
            visited_index.append(mid)
            visited_index.append(mid - 1)
            find_min(start, mid, point, heap, visited_index)
        if mid + 1 < end:
            distance1 = m.sqrt((point[mid].x - point[mid + 1].x) ** 2 + (point[mid].y - point[mid + 1].y) ** 2)
            distance2 = m.sqrt((point[end].x - point[end - 1].x) ** 2 + (point[end].y - point[end - 1].y) ** 2)
            if distance1 < distance2:
                hq.heappush(heap, (distance1, ((point[mid].x, point[mid].y), (point[mid + 1].x, point[mid + 1].y))))
            else:
                hq.heappush(heap, (distance2, ((point[end].x, point[end].y), (point[end - 1].x, point[end - 1].y))))
            visited_index.append(end)
            visited_index.append(end - 1)
            visited_index.append(mid)
            visited_index.append(mid + 1)
            find_min(mid, end, point, heap, visited_index)

x = []
num_points = 10
for i in range(num_points):
    x.append((randint(- num_points << 2, num_points << 2), randint(- num_points << 2, num_points << 2)))
closest_point_problem(x)

:)


0

使用stdlib函数可以使暴力算法更快。因此,它可以有效地应用于超过3个点。

from itertools import combinations

def closest(points_list):
    return min((dist(p1, p2), p1, p2)
               for p1, p2 in combinations(points_list, r=2))

最有效的分割点的方法是在网格上进行划分。如果没有离群值,可以将空间均等地划分,并仅在相同或邻近网格中比较点。 网格数量必须尽可能大。但是,为了避免孤立的网格,当每个点在邻近网格中都没有点时,必须通过点数来限制网格数量。 完整列表:
from math import sqrt
from itertools import combinations, product
from collections import defaultdict
import sys

max_float = sys.float_info.max

def dist((x1, y1), (x2, y2)):
    return sqrt((x1 - x2) ** 2 + (y1 - y2) **2)

def closest(points_list):
    if len(points_list) < 2:
        return (max_float, None, None)  # default value compatible with min function
    return min((dist(p1, p2), p1, p2)
               for p1, p2 in combinations(points_list, r=2))

def closest_between(pnt_lst1, pnt_lst2):
    if not pnt_lst1 or not pnt_lst2:
        return (max_float, None, None)  # default value compatible with min function
    return min((dist(p1, p2), p1, p2)
               for p1, p2 in product(pnt_lst1, pnt_lst2))

def divide_on_tiles(points_list):
    side = int(sqrt(len(points_list)))  # number of tiles on one side of square
    tiles = defaultdict(list)
    min_x = min(x for x, y in points_list)
    max_x = max(x for x, y in points_list)
    min_y = min(x for x, y in points_list)
    max_y = max(x for x, y in points_list)
    tile_x_size = float(max_x - min_x) / side
    tile_y_size = float(max_y - min_y) / side
    for x, y in points_list:
        x_tile = int((x - min_x) / tile_x_size)
        y_tile = int((y - min_y) / tile_y_size)
        tiles[(x_tile, y_tile)].append((x, y))
    return tiles

def closest_for_tile(tiles, (x_tile, y_tile)):
    points = tiles[(x_tile, y_tile)]
    return min(closest(points),
               # use dict.get to avoid creating empty tiles
               # we compare current tile only with half of neighbours (right/top),
               # because another half (left/bottom) make it in another iteration by themselves
               closest_between(points, tiles.get((x_tile+1, y_tile))),
               closest_between(points, tiles.get((x_tile, y_tile+1))),
               closest_between(points, tiles.get((x_tile+1, y_tile+1))),
               closest_between(points, tiles.get((x_tile-1, y_tile+1))))

def find_closest_in_tiles(tiles):
    return min(closest_for_tile(tiles, coord) for coord in tiles.keys())


P1 = [(0,0),(7,6),(2,20),(12,5),(16,16),(5,8),(19,7),(14,22),(8,19),(7,29),(10,11),(1,13)]
P2 = [(94, 5), (96, -79), (20, 73), (8, -50), (78, 2), (100, 63), (-14, -69), (99, -8), (-11, -7), (-78, -46)]

print find_closest_in_tiles(divide_on_tiles(P1)) # (2.8284271247461903, (7, 6), (5, 8))
print find_closest_in_tiles(divide_on_tiles(P2)) # (13.92838827718412, (94, 5), (99, -8))
print find_closest_in_tiles(divide_on_tiles(P1 + P2)) # (2.8284271247461903, (7, 6), (5, 8))

-1

你只需要更改closestSplitPair函数def中第七行的best=(Sy[i],Sy[i+j])best=dist(Sy[i],Sy[i+j]),然后你将得到正确的答案: ((94, 5), (99, -8), 13.92838827718412)。你忘记调用dist函数了。

这是Padraic Cunningham的答案指出的第一个问题。

此致敬礼。


网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接