另一种选择是多维位置列表索引:
import numpy as np
ncol = 10
nrow = 500
x = np.arange(ncol*nrow).reshape(nrow,ncol)
y = (ncol * np.random.random_sample((nrow, 1))).astype(int)
print(x)
print(y)
print(x[np.arange(nrow),y.T].T)
这里解释了语法。你需要为每个维度创建一个索引数组。在第一维中,这只是[0,...,500],而第二维是您的y数组。我们需要转置它(.T),因为它必须具有与第一个和输出数组相同的形状。第二次转置实际上并不需要,但可以给你想要的形状。
编辑:
性能问题出现了,我尝试了迄今为止提到的三种方法。您需要line_profiler来运行以下内容。
kernprof -l -v tmp.py
tmp.py 是指:
import numpy as np
@profile
def calc(x,y):
z = np.arange(nrow)
a = x[z,y.T].T
b = x[:,y].diagonal().T
c = np.array([i[j] for i, j in zip(x, y)])
return (a,b,c)
ncol = 5
nrow = 10
x = np.arange(ncol*nrow).reshape(nrow,ncol)
y = (ncol * np.random.random_sample((nrow, 1))).astype(int)
a, b, c = calc(x,y)
print(a==b)
print(b==c)
我的 Python 2.7.6 的输出结果:
Line
==============================================================
3 @profile
4 def calc(x,y):
5 1 4 4.0 0.1 z = np.arange(nrow)
6 1 35 35.0 0.8 a = x[z,y.T].T
7 1 3409 3409.0 76.7 b = x[:,y].diagonal().T
8 501 995 2.0 22.4 c = np.array([i[j] for i, j in zip(x, y)])
9
10 1 1 1.0 0.0 return (a,b,c)
%Time或Time是相关列。我不知道如何分析内存消耗,需要别人来完成。目前看来,我的解决方案在所请求的维度中是最快的。
numpy
吗?像这样的东西基本上就是它的用途。 - Jan Christoph TerasaX
和Y
可能的具体(但截短的)示例,并说明您期望的输出是什么? - ymbirtty(3)
很可能意味着第4行。 - martineau