将一个numpy数组分成两个不同大小的子集

5
我有一个形状为 (400, 3, 3, 3) 的numpy数组,我想将其分成两部分,以便得到类似 (100, 3, 3, 3)(300, 3, 3, 3) 的数组。
我已经尝试使用numpy的 split 方法,例如:
subsets = np.array_split(arr, 2)

这个函数给了我想要的结果,但它把原数组分成了两半,大小相同,而我不知道如何指定这些大小。也许用一些索引会很容易(我猜是这样),但我不确定该怎么做。


2
x, y = arr[:100, ...], arr[100:, ...] 应该可以... - cs95
1
我会像@cᴏʟᴅsᴘᴇᴇᴅ建议的那样使用切片符号,因为这样可能会占用更少的内存(因为数组会共享底层缓冲区)。不确定这是否适用于split,但如果必须这样做,您可以执行subsets = np.array_spit(arr, [100]) - juanpa.arrivillaga
切片符号正是我正在寻找的,谢谢。 - T.Poe
1个回答

6

正如我在评论中提到的那样,您可以使用省略号符号来指定所有轴:

x, y = arr[:100, ...], arr[100:, ...]

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接