Numba输出的差异

4
我在研究工作中实现了一种基本的最近邻搜索算法。 事实上,基本的numpy实现效果很好,但只需添加'@jit'装饰器(使用Numba编译),输出结果就有所不同(最后出现某些邻居的重复,原因未知...)
以下是基本算法:
import numpy as np
from numba import jit

@jit(nopython=True)
def knn(p, points, k):
    '''Find the k nearest neighbors (brute force) of the point p
    in the list points (each row is a point)'''

    n = p.size  # Lenght of the points
    M = points.shape[0]  # Number of points
    neighbors = np.zeros((k,n))
    distances = 1e6*np.ones(k)

    for i in xrange(M):
        d = 0
        pt = points[i, :]  # Point to compare
        for r in xrange(n):  # For each coordinate
            aux = p[r] - pt[r]
            d += aux * aux
        if d < distances[k-1]:  # We find a new neighbor
            pos = k-1
            while pos>0 and d<distances[pos-1]:  # Find the position
                pos -= 1
            pt = points[i, :]
            # Insert neighbor and distance:
            neighbors[pos+1:, :] = neighbors[pos:-1, :]
            neighbors[pos, :] = pt
            distances[pos+1:] = distances[pos:-1]
            distances[pos] = d

    return neighbors, distances

进行测试:

p = np.random.rand(10)
points = np.random.rand(250, 10)
k = 5
neighbors = knn(p, points, k)

没有使用@jit装饰器,将得到正确的答案:

In [1]: distances
Out[1]: array([ 0.3933974 ,  0.44754336,  0.54548715,  0.55619749,  0.5657846 ])

但是Numba编译会产生奇怪的输出:
Out[2]: distances
Out[2]: array([ 0.3933974 ,  0.44754336,  0.54548715,  0.54548715,  0.54548715])

有人可以帮忙吗?我不明白为什么会出现这种情况...谢谢。

你可能会对scipy的KDTree实现感兴趣。 - Daniel
@Ophion 谢谢你的建议。我一直在使用sklearn的KDTree实现(我想它们很相似),它们非常适合预处理数据以供未来多个查询点使用。在我的工作中,我需要不断更改点列表以查找邻居(在图像处理方面),而这种类型的实现变得太慢了。当空间维度较大时(例如大于25),KDTree实现似乎并不比暴力搜索更好。 - Mario González
1个回答

1
我认为问题在于Numba在处理重叠的切片时,与不重叠的情况下有所不同。我不熟悉numpy的内部机制,但也许有一些特殊的逻辑来处理像这样的易变内存操作,在Numba中不存在。更改以下行并使用jit修饰符的结果将与普通的Python版本一致:
neighbors[pos+1:, :] = neighbors[pos:-1, :].copy()
...
distances[pos+1:] = distances[pos:-1].copy() 

感谢@JoshAdel!这对我很有效。我之前验证过,在Numpy中这种重叠的切片不会引起问题,但由于某些原因,Numba将其翻译成了不同的算法...无论如何,奇怪的是它只复制了一些邻居而不是其他的...再次感谢!附:我是Python的粉丝,但像这样的事情让我认真考虑学习Julia... - Mario González
@MarioGonzález 我建议你在Numba的github问题跟踪器上发布你的示例。开发团队通常非常敏感,并且希望了解错误或意外行为。 - JoshAdel
感谢@JoshAdel的建议。它发布在这里 - Mario González

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接