计算点列表之间的距离的Pythonic方式

4

我有一个无序点列表(2D),想要计算它们之间距离的总和。 由于我的背景是C++开发者,我会像这样做:

import math

class Point:
    def __init__(self, x,y):
        self.x = x
        self.y = y

def distance(P1, P2):
    return math.sqrt((P2.x-P1.x)**2 + (P2.y-P1.y)**2)

points = [Point(rand(1), rand(1)) for i in range(10)]

#this part should be in a nicer way
pathLen = 0
for i in range(1,10):
    pathLen += distance(points[i-1], points[i])

有没有更符合Python风格的方式来替换for循环?例如使用reduce或类似的东西?

最好的问候!


所有点对之间的距离?还是像点之间的单一路径? - Igl3
4
注意:你写成了 math.sqrt((P1.x+P2.x)**2 ...。距离应该使用 - 而不是 **+**。期望 distance(P, P) 的值为0! - Serge Ballesta
谢谢,已更改。没注意到数学方面 :D - b2aff6009
最干净和高效的方法是使用numpy,它提供了一些方法,允许以非常类似matlab的语法编写代码,并使用优化的例程同时计算许多距离。 - allo
3个回答

2

您可以使用生成器表达式与 sumzip 和 itertools 的 islice 来避免重复数据:

from itertools import islice
paathLen = sum(distance(x, y) for x, y in zip(points, islice(points, 1, None)))

这里有一个实时示例


1
这似乎比zip示例快10%。 - Dominik Stańczak
1
很好,我之前没有遇到过islice,显然这才是我实际上在答案中寻找的,而不是我笨拙的切片。 - Robin Zigmond

1

有一些修复,因为在这里使用C++方法可能不是最好的:

import math
# you need this import here, python has no rand in the main namespace
from random import random

class Point:
    def __init__(self, x,y):
        self.x = x
        self.y = y
    # there's usually no need to encapsulate variables in Python

def distance(P1, P2):
    # your distance formula was wrong 
    # you were adding positions on each axis instead of subtracting them
    return math.sqrt((P1.x-P2.x)**2 + (P1.y-P2.y)**2)

points = [Point(random(), random()) for i in range(10)]
# use a sum over a list comprehension:
pathLen = sum([distance(points[i-1], points[i]) for i in range(10)])

@Robin Zigmond的`zip`方法也是实现这一目标的一种简洁方式,但我一开始并没有意识到它可以在这里使用。

在我看来,使用单个迭代器并在此处进行压缩要更加清晰简单。+1 - Serge Ballesta

0

我遇到了类似的问题,并拼凑出了一个numpy解决方案,我认为它很好用。

也就是说,如果你将点的列表转换为numpy数组,然后可以执行以下操作:

pts = np.asarray(points)
dist = np.sqrt(np.sum((pts[np.newaxis, :, :]-pts[:, np.newaxis, :])**2, axis=2))

dist会生成一个nxn的numpy对称数组,其中每个点到其他每个点的距离都在对角线上方或下方给出。对角线是每个点到自身的距离,因此只有0。

然后您可以使用:

path_leng = np.sum(dist[np.triu_indices(pts.shape[0], 1)].tolist())

收集numpy数组的上半部分并将它们相加以获取路径长度。


网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接