如何提高这个小距离 Python 函数的性能?

5

当我在使用sklearn聚类算法时,使用自定义距离度量函数遇到了性能瓶颈。

通过Run Snake Run显示的结果如下:

enter image description here

很明显问题出现在dbscan_metric函数中。这个函数看起来非常简单,我并不知道加速它的最佳方法是什么:

def dbscan_metric(a,b):
  if a.shape[0] != NUM_FEATURES:
    return np.linalg.norm(a-b)
  else:
    return np.linalg.norm(np.multiply(FTR_WEIGHTS, (a-b)))

任何关于导致速度缓慢的想法,将不胜感激。

1
答案取决于数组a和b的大小。此外,您使用的numpy版本是哪个?鉴于它们的配置文件可能很小,并且您受到Python开销的支配,如果是这种情况,您需要使用Cython来减少开销。 - jtaylor
1
看起来在norm@linalg.py中花费了71秒,而在其他地方花费了170秒,这似乎很奇怪。我以为我知道如何理解这个图表,但这似乎很奇怪。我唯一能猜测的是额外的170秒与调用开销有关。你可以尝试内联吗? - user1245262
1
NUM_FEATURES和FTR_WEIGHTS的局部范围外查找也可能会花费一些时间。 - sirlark
@sirlark:更新:将这两个常量移动到本地作用域也没有什么帮助。 - houbysoft
你可以尝试这种方法。这样你就可以找出是否不必要地调用了低级函数。 - Mike Dunlavey
显示剩余10条评论
1个回答

1

我不熟悉这个函数的作用 - 但是是否存在重复计算的可能性?如果是,您可以使用记忆化技术来优化该函数:

cache = {}
def dbscan_metric(a,b):

  diff = a - b

  if a.shape[0] != NUM_FEATURES:
    to_calc = diff
  else:
    to_calc = np.multiply(FTR_WEIGHTS, diff)

  if not cache.get(to_calc): cache[to_calc] = np.linalg.norm(to_calc)

  return cache[to_calc]

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接