从Numpy向量数组中删除(在给定容差内)的重复项

5
我有一个Nx5的数组,包含N个形式为'id'、'x'、'y'、'z'和'energy'的向量。我需要在容差为0.1的情况下删除重复点(即x、y、z完全匹配的点)。理想情况下,我可以创建一个函数,在其中传递数组、需要匹配的列以及匹配的容差。
根据Scipy-user上的此线程,我可以使用记录数组基于完整数组删除重复项,但我需要匹配部分数组。此外,这将无法在一定的容差范围内匹配。
我可以在Python中使用for循环费力地迭代,但是否有更好的Numponic方法?

1
你提供的规格存在内在问题,这就是为什么你很难找到现成的解决方案:比如说,为了清晰起见,公差实际上是0.11,y和z始终相同,而“x”是0、0.1、0.2、0.3、0.4等等——现在,“重复项”是什么?按照你的定义,0.1是0和0.2的“重复项”,但这两个数字彼此之间并不是“重复项”——因此,“重复项”关系不具有传递性,因此无法引出一个分区!你需要自己定义一些启发式方法,因为没有真正的“数学上正确”的解决方案(不可能有:没有分区!-)。 - Alex Martelli
1
我理解你的观点。但在我的工作领域中,我希望进行聚类,即簇内点之间的平均距离应该约等于公差,而簇间平均距离 >> 簇内点之间的平均距离。公差的大小应该足够小,以便于你的目的中任何一个簇内的点都可以成为“基准”点。 - Brendan
3个回答

2
您可以参考 scipy.spatial.KDTree。 N 有多大? 添加:哎呀,tree.query_pairs 不在 scipy 0.7.1 中。
如果不确定,请使用暴力算法:将空间(这里是 side^3)分割成小单元格,每个单元格一个点。
""" scatter points to little cells, 1 per cell """
from __future__ import division         
import sys                              
import numpy as np                      

side = 100                              
npercell = 1  # 1: ~ 1/e empty          
exec "\n".join( sys.argv[1:] )  # side= ...
N = side**3 * npercell                  
print "side: %d  npercell: %d  N: %d" % (side, npercell, N)
np.random.seed( 1 )                     
points = np.random.uniform( 0, side, size=(N,3) )

cells = np.zeros( (side,side,side), dtype=np.uint )
id = 1
for p in points.astype(int):
    cells[tuple(p)] = id                
    id += 1                             

cells = cells.flatten()
    # A C, an E-flat, and a G walk into a bar. 
    # The bartender says, "Sorry, but we don't serve minors."
nz = np.nonzero(cells)[0]               
print "%d cells have points" % len(nz)
print "first few ids:", cells[nz][:10]

使用KDTree是个好主意,我可能稍后会实现它。 - Brendan

1

我终于找到了一个令我满意的解决方案,这是从我的代码中稍微整理过的剪贴板。可能还存在一些错误。

请注意:它仍然使用“for”循环。我可以使用Denis上面提出的KDTree的想法加上四舍五入来获得完整的解决方案。

import numpy as np

def remove_duplicates(data, dp_tol=None, cols=None, sort_by=None):
    '''
    Removes duplicate vectors from a list of data points
    Parameters:
        data        An MxN array of N vectors of dimension M 
        cols        An iterable of the columns that must match 
                    in order to constitute a duplicate 
                    (default: [1,2,3] for typical Klist data array) 
        dp_tol      An iterable of three tolerances or a single 
                    tolerance for all dimensions. Uses this to round 
                    the values to specified number of decimal places 
                    before performing the removal. 
                    (default: None)
        sort_by     An iterable of columns to sort by (default: [0])

    Returns:
        MxI Array   An array of I vectors (minus the 
                    duplicates)

    EXAMPLES:

    Remove a duplicate

    >>> import wien2k.utils
    >>> import numpy as np
    >>> vecs1 = np.array([[1, 0, 0, 0],
    ...     [2, 0, 0, 0],
    ...     [3, 0, 0, 1]])
    >>> remove_duplicates(vecs1)
    array([[1, 0, 0, 0],
           [3, 0, 0, 1]])

    Remove duplicates with a tolerance

    >>> vecs2 = np.array([[1, 0, 0, 0  ],
    ...     [2, 0, 0, 0.001 ],
    ...     [3, 0, 0, 0.02  ],
    ...     [4, 0, 0, 1     ]])
    >>> remove_duplicates(vecs2, dp_tol=2)
    array([[ 1.  ,  0.  ,  0.  ,  0.  ],
           [ 3.  ,  0.  ,  0.  ,  0.02],
           [ 4.  ,  0.  ,  0.  ,  1.  ]])

    Remove duplicates and sort by k values

    >>> vecs3 = np.array([[1, 0, 0, 0],
    ...     [2, 0, 0, 2],
    ...     [3, 0, 0, 0],
    ...     [4, 0, 0, 1]])
    >>> remove_duplicates(vecs3, sort_by=[3])
    array([[1, 0, 0, 0],
           [4, 0, 0, 1],
           [2, 0, 0, 2]])

    Change the columns that constitute a duplicate

    >>> vecs4 = np.array([[1, 0, 0, 0],
    ...     [2, 0, 0, 2],
    ...     [1, 0, 0, 0],
    ...     [4, 0, 0, 1]])
    >>> remove_duplicates(vecs4, cols=[0])
    array([[1, 0, 0, 0],
           [2, 0, 0, 2],
           [4, 0, 0, 1]])

    '''
    # Deal with the parameters
    if sort_by is None:
        sort_by = [0]
    if cols is None:
        cols = [1,2,3]
    if dp_tol is not None:
        # test to see if already an iterable
        try:
            null = iter(dp_tol)
            tols = np.array(dp_tol)
        except TypeError:
            tols = np.ones_like(cols) * dp_tol
        # Convert to numbers of decimal places
        # Find the 'order' of the axes
    else:
        tols = None

    rnd_data = data.copy()
    # set the tolerances
    if tols is not None:
        for col,tol in zip(cols, tols):
            rnd_data[:,col] = np.around(rnd_data[:,col], decimals=tol)

    # TODO: For now, use a slow Python 'for' loop, try to find a more
    # numponic way later - see: https://dev59.com/YUzSa4cB1Zd3GeqPnYim
    sorted_indexes = np.lexsort(tuple([rnd_data[:,col] for col in cols]))
    rnd_data = rnd_data[sorted_indexes]
    unique_kpts = []
    for i in xrange(len(rnd_data)):
        if i == 0:
            unique_kpts.append(i)    
        else:
            if (rnd_data[i, cols] == rnd_data[i-1, cols]).all():
                continue
            else:
                unique_kpts.append(i)    

    rnd_data =  rnd_data[unique_kpts]
    # Now sort
    sorted_indexes = np.lexsort(tuple([rnd_data[:,col] for col in sort_by]))
    rnd_data = rnd_data[sorted_indexes]
    return rnd_data



if __name__ == '__main__':
    import doctest
    doctest.testmod()

0

我没有测试过,但如果你按照x、y、z的顺序对数组进行排序,这应该会给你重复项列表。然后你需要选择要保留哪些。

def find_dup_xyz(anarray, x, y, z): #for example in an data = array([id,x,y,z,energy]) x=1 y=2 z=3
    dup_xyz=[]
    for i, row in enumerated(sortedArray):
        nx=1
        while (abs(row[x] - sortedArray[i+nx[x])<0.1) and (abs(row[z] and sortedArray[i+nx[y])<0.1) and (abs(row[z] - sortedArray[i+nx[z])<0.1):
              nx=+1
              dup_xyz.append(row)
return dup_xyz

还发现了这个: http://mail.scipy.org/pipermail/scipy-user/2008-April/016504.html


网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接