复制一个NumPy数组列表

4

我正在使用(numpy数组的)列表的列表。以下是一个基本示例代码:

a = [np.zeros(5)]
b = a.copy()
b[0] += 1

这里,我将从a复制一个数组的列表到b。然而,数组本身并没有被复制,因此:

print(a)
print(b)

两者都会得到[array([1., 1., 1., 1., 1.])]。如果我也想复制该数组,可以执行以下操作:

b = [arr.copy() for arr in a]

如果一个列表只含有简单的元素,使用浅拷贝可以实现不改变原始列表及其元素的目的,比如该列表中的 a 元素。但是当处理嵌套多维数组列表时,每个子列表中包含的数组数量不一定相同,则需要考虑更复杂的情况。

有没有一种简单的方法可以复制多级嵌套列表及其所有包含的对象而不保留对原始列表中对象的引用?基本上,我想避免嵌套循环以及处理每个子列表的大小。

2个回答

3
你要找的是深拷贝。
import numpy as np
import copy
a = [np.zeros(5)]
b = copy.deepcopy(a)
b[0] += 1  # a[0] is not changed

这实际上是numpy文档推荐的用于深拷贝object数组的方法。


正是我所需要的!我猜这个解决方案也可以同样适用于包含其他对象(比如数组)的自定义类吧? - JD80121
确实应该这样做,关于自定义对象的深拷贝替代方案,请参考此处:https://dev59.com/U5nga4cB1Zd3GeqPgPvp - Mederic Fourmy

1

您需要使用深拷贝(deepcopy)。

import numpy as np
import copy

a = [np.zeros(5)]
b = copy.deepcopy(a)
b[0] += 1

print(a)
print(b)

结果:

[array([0., 0., 0., 0., 0.])]
[array([1., 1., 1., 1., 1.])]

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接