在Python中,给定一个阈值,高效地删除彼此相近的数组。

5

我是用Python来完成这项工作的,很客观地说,我想找到一种“Pythonic”的方法来从一个数组中删除那些彼此之间距离小于阈值的“重复”项。例如,给定以下数组:

[[ 5.024,  1.559,  0.281], [ 6.198,  4.827,  1.653], [ 6.199,  4.828,  1.653]]

请注意,[6.198, 4.827, 1.653][6.199, 4.828, 1.653]非常接近,它们的欧几里得距离为0.0014,因此它们几乎是“重复”的。我希望我的最终输出只有:
[[ 5.024,  1.559,  0.281], [ 6.198,  4.827,  1.653]]

我现在拥有的算法是:
to_delete = [];
for i in unique_cluster_centers:
    for ii in unique_cluster_centers:
        if i == ii:
            pass;
        elif np.linalg.norm(np.array(i) - np.array(ii)) <= self.tolerance:
            to_delete.append(ii);
            break;

for i in to_delete:
    try:
        uniques.remove(i);
    except:
        pass;

但是它真的很慢,我想知道一些更快和“Pythonic”的解决方法。我的容忍度为0.0001。


np.array(i) 应该是什么意思?我认为它并不是在真正的 Python/numpy 脚本中所产生的结果。 - hpaulj
scipy.spatial.distance有成对距离函数。 - hpaulj
2个回答

2

一种通用方法可能是:

def filter_quadratic(data,condition):
    result = []
    for element in data:
        if all(condition(element,other) for other in result):
            result.append(element)
    return result

这是一个通用的高阶filter,它具有条件。只有当这个条件对于已经在列表中的所有元素都被满足时,才会添加该元素。
现在我们仍然需要定义条件:
def the_condition(xs,ys):
    # working with squares, 2.5e-05 is 0.005*0.005 
    return sum((x-y)*(x-y) for x,y in zip(xs,ys)) > 2.5e-05

这将会得到:
>>> filter_quadratic([[ 5.024,  1.559,  0.281], [ 6.198,  4.827,  1.653], [ 6.199,  4.828,  1.653]],the_condition)
[[5.024, 1.559, 0.281], [6.198, 4.827, 1.653]]

该算法的时间复杂度为 O(n2),其中 n 是传递给函数的元素数量。但你可以使用 k-d 树使其更加高效,不过这需要一些更高级的数据结构。

我会尝试这个,我避免使用k-d树。我更喜欢从头开始写。 - Pj-

1
如果您能避免在嵌套循环中将每个列表元素与其他元素进行比较(这是不可避免的O(n ^ 2)操作),那么效率会更高。

一种方法是生成一个键,使得两个“几乎重复”的键将产生相同的键。然后,您只需对数据进行一次迭代,并仅插入尚未在结果集中的值。

result = {}
for row in unique_cluster_centers:
    # round each value to 2 decimal places: 
    # [5.024,  1.559,  0.281] => (5.02,  1.56,  0.28)
    # you can be inventive and, say, multiply each value by 3 before rounding
    # if you want precision other than a whole decimal point.
    key = tuple([round(v, 2) for v in row])  # tuples can be keys of a dict
    if key not in result:
        result[key] = row
return result.values()  # I suppose the order of the items is not important, you can use OrderedDict otherwise  

要使其正确,您仍需要在相邻的块之间进行比较(例如1.004999、1.000000、0.00000和1.005001、1.000000、0.000000在您的方案中将具有不同的键)。 - Paul Panzer
嗯,是的,你说得对 :) 我在想是否可以通过对数据集进行三次遍历来解决——另外两个遍历将使用基于值的键,这些键比原始键大或小半个桶尺寸。这将捕获“桶”边缘处的这些值。三个顺序遍历仍然比嵌套循环快得多。 - Sergey
再想一想 - 不行,我之前评论的那个想法行不通 :( 例如:0.9和1.89四舍五入为1。 - Sergey

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接