我有一个形状为 (6, 2, 4)
的 numpy 数组:
x = np.array([[[0, 3, 2, 0],
[1, 3, 1, 1]],
[[3, 2, 3, 3],
[0, 3, 2, 0]],
[[1, 0, 3, 1],
[3, 2, 3, 3]],
[[0, 3, 2, 0],
[1, 3, 2, 2]],
[[3, 0, 3, 1],
[1, 0, 1, 1]],
[[1, 3, 1, 1],
[3, 1, 3, 3]]])
我有一个名为choices
的数组,内容如下:
choices = np.array([[1, 1, 1, 1],
[0, 1, 1, 0],
[1, 1, 1, 1],
[1, 0, 0, 0],
[1, 0, 1, 1],
[0, 0, 0, 1]])
如何最有效地使用 choices
数组只索引大小为 2 的中间维度,并获得一个形状为 (6, 4)
的新 numpy 数组?
结果应该是这样的:
[[1 3 1 1]
[3 3 2 3]
[3 2 3 3]
[1 3 2 0]
[1 0 1 1]
[1 3 1 3]]
我试过使用
x[:, choices, :]
,但这并没有返回我想要的结果。我还尝试了 x.take(choices, axis=1)
,但没有成功。
choices[:, None]
等于choices[:, None, :]
需要一些时间,这将导致一个形状为[6, 1, 4]
的数组,可以与x
广播! - wstcegg