NumPy数组写时复制

5
我有一个类,返回大型NumPy数组。这些数组在类内缓存。我希望返回的数组是写时复制(copy-on-write)数组。如果调用者最终只是从数组中读取,就不会生成副本。这将不会使用额外的内存。但是,该数组是“可修改”的,但不会修改内部缓存数组。
目前我的解决方案是使任何缓存数组为只读状态 (a.flags.writeable = False)。这意味着,如果函数的调用者需要修改它,他们必须自己复制该数组。当然,如果源不是来自缓存且数组已经可写,则不必要地复制数据。
因此,最理想的情况是像这样使用一些东西 a.view(flag=copy_on_write)。似乎有一个标志是反其道而行之的 UPDATEIFCOPY,它会导致一次复制在被释放后更新原始数据。
谢谢!
2个回答

5

复制-写时是一个不错的概念,但显式复制似乎是 "NumPy哲学"。所以个人认为,如果不太笨拙,我会保留 "只读" 解决方案。

但我承认自己编写了自己的复制-写时封装类。 我不尝试检测对数组的写入访问。而是该类具有一个方法 "get_array(readonly)",返回其(否则私有的)numpy数组。第一次用 "readonly=False" 调用它时,它会进行复制。这非常明确,易于阅读和理解。

如果你的复制-写时numpy数组看起来像传统的numpy数组,你的代码读者(可能是2年后的你)可能会很难理解。


1
我采用了这种方法,唯一的例外是它总是只读的,如果调用者想要读写,他们可以自己复制。 - coderforlife

4
为了实现写时复制,我们需要修改ndarray对象的basedatastrides。我认为这不能在纯Python代码中完成。我使用了一些Cython代码来修改这些属性。
以下是IPython笔记本中的代码:
%load_ext cythonmagic

使用Cython定义copy_view():

%%cython
cimport numpy as np

np.import_array()
np.import_ufunc()

def copy_view(np.ndarray a):
    cdef np.ndarray b
    cdef object base
    cdef int i
    base = np.get_array_base(a)
    if base is None or isinstance(base, a.__class__):
        return a
    else:
        print "copy"
        b = a.copy()
        np.set_array_base(a, b)
        a.data = b.data
        for i in range(b.ndim):
            a.strides[i] = b.strides[i]

定义一个ndarray的子类:

class cowarray(np.ndarray):
    def __setitem__(self, key, value):
        copy_view(self)
        np.ndarray.__setitem__(self, key, value)

    def __array_prepare__(self, array, context=None):
        if self is array:
            copy_view(self)
        return array

    def __array__(self):
        copy_view(self)
        return self

一些测试:

a = np.array([1.0, 2, 3, 4])
b = a.view(cowarray)
b[1] = 100 #copy 
print a, b
b[2] = 200 #no copy
print a, b

c = a[::2].view(cowarray)
c[0] = 1000 #copy
print a, c

d = a.view(cowarray)
np.sin(d, d) #copy
print a, d           

输出结果:

copy
[ 1.  2.  3.  4.] [   1.  100.    3.    4.]
[ 1.  2.  3.  4.] [   1.  100.  200.    4.]
copy
[ 1.  2.  3.  4.] [ 1000.     3.]
copy
[ 1.  2.  3.  4.] [ 0.84147098  0.90929743  0.14112001 -0.7568025 ]

1
我已经尝试过这个,看起来相当不错。然而还有一些问题。首先,在Cython代码中的if语句之前,我必须添加一个while循环来不断搜索基础,直到找到匹配项或None为止,否则像b=a.view(cowarray); c=b[:2]; c[0]=1000;这样做将无法复制。但可能还有其他问题,例如使用循环修复可能会产生比必要更多的副本。 - coderforlife

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接