使用numpy缩减一个轴

3
我有一个NxMx3的numpy数组,其dtype=object。我还有一个函数f(a,b,c),它接受该数组中最后一个轴上的三个元素,并返回np.int32。我的问题是如何将f应用于我的NxMx3数组,以产生一个dtype=np.int32NxM数组?
我的当前解决方案是使用
newarr = np.fromfunction(lambda i,j: f(arr[i,j,0], arr[i,j,1], arr[i,j,2]),
                          arr.shape[:2], dtype=np.int)

虽然这比我希望的要冗长一些,但仍需如此。
1个回答

3
你可以使用 vectorize 函数:
np.vectorize(f, otypes=[np.int32])(arr[:, :, 0], arr[:, :, 1], arr[:, :, 2])

这可以通过轴滚动和迭代来简化:
np.vectorize(f, otypes=[np.int32])(*np.rollaxis(arr, 2, 0))

或者,您可以使用dsplit显式地拆分数组:

np.vectorize(f, otypes=[np.int32])(*np.dsplit(arr, 3))[..., 0]

或者

np.vectorize(f, otypes=[np.int32])(*np.dsplit(arr, 3)).reshape(arr.shape[:-1])

或者

np.vectorize(f, otypes=[np.int32])(*np.dsplit(arr, 3)).squeeze()

然而,apply_along_axis 可能更简单:
np.apply_along_axis(lambda x: f(*x), 2, arr)

生气 +1 -- 我以为我会比你更快地使用 apply_along_axis。 :^) - DSM

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接