NumPy - 计算直方图交集

8
以下数据表示将2个给定的直方图分成13个箱子:
key 0   1-9 10-18   19-27   28-36   37-45   46-54   55-63   64-72   73-81   82-90   91-99   100
A   1.274580708 2.466224824 5.045757621 7.413716262 8.958855646 10.41325305 11.14150951 10.91949012 11.29095648 10.95054297 10.10976255 8.128781795 1.886568472
B   0   1.700493692 4.059243006 5.320899616 6.747120132 7.899067471 9.434997257 11.24520022 12.94569391 12.83598464 12.6165661  10.80636314 4.388370817

我正在尝试按照这篇文章中的方法计算这两个直方图之间的交集:

enter image description here

请注意,本文中的HTML标记已保留。
def histogram_intersection(h1, h2, bins):
   bins = numpy.diff(bins)
   sm = 0
   for i in range(len(bins)):
       sm += min(bins[i]*h1[i], bins[i]*h2[i])
   return sm

由于我的数据已经被计算成直方图形式,我不能使用numpy内置的函数,因此我无法为函数提供必要的数据。

我该如何处理我的数据以适应算法?


你展示的函数有什么问题吗?乍一看看起来还不错。 - Benjamin
该函数期望由直方图方法生成的箱子。我只有直方图数据。 - Shlomi Schwartz
1
既然A和B的bin相同,直接使用np.minimum(A, B)怎么样? - xdze2
谢谢回复,我会尝试一下。 - Shlomi Schwartz
3个回答

7

既然两个直方图的箱子数相同,因此可以使用以下方法:

def histogram_intersection(h1, h2):
    sm = 0
    for i in range(13):
        sm += min(h1[i], h2[i])
    return sm

3

您可以使用Numpy更快、更简单地计算:

#!/usr/bin/env python3

import numpy as np

A = np.array([1.274580708,2.466224824,5.045757621,7.413716262,8.958855646,10.41325305,11.14150951,10.91949012,11.29095648,10.95054297,10.10976255,8.128781795,1.886568472])
B = np.array([0,1.700493692,4.059243006,5.320899616,6.747120132,7.899067471,9.434997257,11.24520022,12.94569391,12.83598464,12.6165661,10.80636314,4.388370817])

def histogram_intersection(h1, h2):
    sm = 0
    for i in range(13):
        sm += min(h1[i], h2[i])
    return sm

print(histogram_intersection(A,B))
print(np.sum(np.minimum(A,B)))

输出

88.44792356099998
88.447923561

但是如果您计时,Numpy 只需要 60% 的时间:

%timeit histogram_intersection(A,B)
5.02 µs ± 65.3 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)

%timeit np.sum(np.minimum(A,B))
3.22 µs ± 11.3 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)

我正在尝试比较两个颜色直方图,但出现了以下错误:VisibleDeprecationWarning: Creating an ndarray from ragged nested sequences (which is a list-or-tuple of lists-or-tuples-or ndarrays with different lengths or shapes) is deprecated. If you meant to do this, you must specify 'dtype=object' when creating the ndarray. return np.sum(np.minimum(hist_1,hist_2)) - NccWarp9

1
首先需要注意:在您的数据中,区间是以范围形式存在的,而在算法中它们是以数字形式存在的。因此,您必须重新定义区间。
此外,min(bins[i]*h1[i], bins[i]*h2[i])等同于bins[i]*min(h1[i], h2[i]),因此可以通过以下方式获得结果:
hists=pandas.read_clipboard(index_col=0) # your data
bins=arange(-4,112,9)   #  try for bins but edges are different here
mins=hists.min('rows')
intersection=dot(mins,bins) 

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接