快速的numpy滚动乘积

5
我需要一个滚动乘积函数或扩展乘积函数。
有各种pandas的rolling_XXXX和expanding_XXXX函数,但我惊讶地发现没有expanding_product()函数。
为了让事情正常运作,我一直在使用这个相对较慢的替代方法。
pd.expanding_apply(temp_col, lambda x : x.prod())

我的数组通常有32,000个元素,因此这证明是一个瓶颈。我想尝试使用log()cumsum()exp(),但我认为应该在这里询问,因为可能有更好的解决方案。


"有各种不同的numpy rolling_XXXX" - 你确定你是指“numpy”而不是“pandas”吗? - Ami Tavory
2
对于扩展产品,可以使用 cumprod()。对于滚动版本,我认为您需要使用 rolling_applyprod() 应用于每个窗口。 - Alex Riley
@JasonEdinburgh "log(),cumsum()和exp()" - 你是指log、rolling_mean和exp吗? - Ami Tavory
1
说到疲劳,这个页面可能会帮助您确定日志、滚动总和、指数方案的数值稳定性,但我太累了,无法仔细阅读。祝你好运。 - Ami Tavory
@AmiTavory 您说得很对。我认为expanding_product不需要执行重复的除法,但rolling_product肯定需要,这可能是为什么它被省略的原因。我刚试了np.exp(pd.expanding_sum(np.log(temp_col))),目前速度足够快,并且似乎给出的结果与rolling_apply版本相差不到0.00001。如果我在分析中看到它出现,那么我将尝试下一个numba/cython版本。感谢您的帮助 :) - JasonEdinburgh
显示剩余4条评论
2个回答

6

我有一个更快的机制,不过你需要运行一些测试来确认准确性是否足够。

这是原始的指数/求和/对数版本:

def rolling_prod1(xs, n):
    return np.exp(pd.rolling_sum(np.log(xs), n))

这里有一种方法可以对累积乘积进行处理,先将其向左移动(填充为NaN),然后再除以原来的数值。

def rolling_prod2(xs, n):
    cxs = np.cumprod(xs)
    nans = np.empty(n)
    nans[:] = np.nan
    nans[n-1] = 1.
    a = np.concatenate((nans, cxs[:len(cxs)-n]))
    return cxs / a

这个例子中,两个函数返回的结果相同:

In [9]: xs
Out[9]: array([ 1.,  2.,  3.,  4.,  5.,  6.,  7.,  8.,  9.])

In [10]: rolling_prod1(xs, 3)
Out[10]: array([  nan,   nan,    6.,   24.,   60.,  120.,  210.,  336.,  504.])

In [11]: rolling_prod2(xs, 3)
Out[11]: array([  nan,   nan,    6.,   24.,   60.,  120.,  210.,  336.,  504.])

但是第二个版本要快得多:

In [12]: temp_col = np.random.rand(30000)

In [13]: %timeit rolling_prod1(temp_col, 3)
1000 loops, best of 3: 694 µs per loop

In [14]: %timeit rolling_prod2(temp_col, 3)
10000 loops, best of 3: 162 µs per loop

1
当我第一次阅读这篇文章时,我感到很困惑,因为我认为“但是没有numpy cumprod函数”。显然,昨晚我非常疲倦。我发现实际上有一个numpy cumprod函数,但在昨晚搜索它时,我不知道为什么没能找到它!既然我只需要一个expanding_prod函数,np.cumprod就是我要找的。但我真的很喜欢你采用的窗口和单除法方法来进行滚动版本。所以尽管我因未能找到numpy.cumprod而感到尴尬,但我会把这篇文章留在这里,以防您的解决方案对其他人有用。谢谢! - JasonEdinburgh

2
早期结果表明,这是扩展产品的相对快速的近似方法。
np.exp(pd.expanding_sum(np.log(temp_col)))

rolling_product需要重复除法,这可能会导致数值不稳定(正如@AmiTavory在现已删除的答案中指出的那样)


网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接