遍历n维数组的通用函数

3
使用Cython,有没有一种方法可以编写快速的通用函数,适用于不同维度的数组?例如,对于这个简单的去混叠函数的情况:
import numpy as np
cimport numpy as np

ctypedef np.uint8_t DTYPEb_t
ctypedef np.complex128_t DTYPEc_t


def dealiasing1D(DTYPEc_t[:, :] data, 
                 DTYPEb_t[:] where_dealiased):
    """Dealiasing data for 1D solvers."""
    cdef Py_ssize_t ik, i0, nk, n0

    nk = data.shape[0]
    n0 = data.shape[1]

    for ik in range(nk):
        for i0 in range(n0):
            if where_dealiased[i0]:
                data[ik, i0] = 0.


def dealiasing2D(DTYPEc_t[:, :, :] data, 
                 DTYPEb_t[:, :] where_dealiased):
    """Dealiasing data for 2D solvers."""
    cdef Py_ssize_t ik, i0, i1, nk, n0, n1

    nk = data.shape[0]
    n0 = data.shape[1]
    n1 = data.shape[2]

    for ik in range(nk):
        for i0 in range(n0):
            for i1 in range(n1):
                if where_dealiased[i0, i1]:
                    data[ik, i0, i1] = 0.


def dealiasing3D(DTYPEc_t[:, :, :, :] data, 
                 DTYPEb_t[:, :, :] where_dealiased):
    """Dealiasing data for 3D solvers."""
    cdef Py_ssize_t ik, i0, i1, i2, nk, n0, n1, n2

    nk = data.shape[0]
    n0 = data.shape[1]
    n1 = data.shape[2]
    n2 = data.shape[3]

    for ik in range(nk):
        for i0 in range(n0):
            for i1 in range(n1):
                for i2 in range(n2):
                    if where_dealiased[i0, i1, i2]:
                        data[ik, i0, i1, i2] = 0.

在这里,我需要针对一维、二维和三维情况编写三个函数。有没有一个好的方法可以编写一个函数来处理所有(合理的)维度?
PS:在这里,我尝试使用memoryviews,但我不确定这是否是正确的方法。我惊讶的是,cython -a 命令产生的注释html中的 if where_dealiased[i0]: data[ik, i0] = 0. 行没有被标记为白色。这是有问题吗?
2个回答

2
第一件事我要说的是,保留这3个函数有其原因。如果使用更通用的函数,您可能会错过Cython编译器和C编译器的优化。
将这3个函数包装成一个函数是可行的,只需要将两个数组作为Python对象传入,检查形状并调用相关的其他函数即可。
但是,如果要尝试这样做,我会先编写最高维度的函数,然后使用“new axis”符号将低维度的数组重新转换为更高维度的数组。
cdef np.uint8_t [:] a1d = np.zeros((256, ), np.uint8) # 1d
cdef np.uint8_t [:, :] a2d = a1d[None, :]             # 2d
cdef np.uint8_t [:, :, :] a3d = a1d[None, None, :]    # 3d
a2d[0, 100] = 42
a3d[0, 0, 200] = 108
print(a1d[100], a1d[200])
# (42, 108)

cdef np.uint8_t [:, :] data2d = np.zeros((128, 256), np.uint8) #2d
cdef np.uint8_t [:, :, :, :] data4d = data2d[None, None, :, :] #4d
data4d[0, 0, 42, 108] = 64
print(data2d[42, 108])
# 64

如您所见,内存视图可以转换为更高维度,并可用于修改原始数据。在将新视图传递给最高维度的函数之前,您可能仍希望编写一个包装函数来执行这些技巧。我怀疑这个技巧在您的情况下会非常有效,但是您必须进行试验才能知道它是否符合您对数据的要求。
关于您的PS:有一个非常简单的解释。 "额外的代码" 是生成索引错误、类型错误并允许您使用[-1]从数组末尾而不是开头(环绕)索引的代码。 您可以通过使用编译器指令禁用这些额外的Python功能,并将其减少到C数组功能,例如,要从整个文件中消除此额外代码,您可以在文件开头包含注释。
# cython: boundscheck=False, wraparound=False, nonecheck=False

编译器指令也可以使用装饰器在函数级别上应用。文档进行了解释。

1
您可以使用numpy.ndindex()np.ndarray对象的strided属性以一般方式读取扁平化的数组,从而确定位置:
indices[0]*strides[0] + indices[1]*strides[1] + ... + indices[n]*strides[n]

strides是一个1-D数组时,可以轻松地通过执行(strides*indices).sum()来完成。以下代码显示了如何构建一个可行的示例:

#cython profile=True
#blacython wraparound=False
#blacython boundscheck=False
#blacython nonecheck=False
#blacython cdivision=True
cimport numpy as np
import numpy as np

def readNDArray(x):
    if not isinstance(x, np.ndarray):
        raise ValueError('x must be a valid np.ndarray object')
    if x.itemsize != 8:
        raise ValueError('x.dtype must be float64')
    cdef np.ndarray[double, ndim=1] v # view of x
    cdef np.ndarray[int, ndim=1] strides
    cdef int pos

    shape = list(x.shape)
    strides = np.array([s//x.itemsize for s in x.strides], dtype=np.int32)
    v = x.ravel()
    for indices in np.ndindex(*shape):
        pos = (strides*indices).sum()
        v[pos] = 2.
    return np.reshape(v, newshape=shape)

该算法如果原始数组是C连续的,则不会复制原始数组:
def main():
    # case 1
    x = np.array(np.random.random((3,4,5,6)), order='F')
    y = readNDArray(x)
    print(np.may_share_memory(x, y))
    # case 2
    x = np.array(np.random.random((3,4,5,6)), order='C')
    y = readNDArray(x)
    print np.may_share_memory(x, y)
    return 0

结果为:

False
True

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接