在numpy中,是否可以将ndarray的对角线形成一个视图?

3

简单的切片可以将父数组形成视图。视图的步长通常是父数组步长的倍数。

给定步长为(s0,s1)的二维父数组,具有步长(s0+s1)的一维数组可以在父数组的对角线上得到视图。

是否有一种方法可以在顶级Python / numpy中创建这样的视图?谢谢。

3个回答

3
使用as_strided函数,我可以完成您想要的操作。
In [298]: X=np.eye(5)
In [299]: X.strides
Out[299]: (40, 8)
In [300]: np.lib.stride_tricks.as_strided(X,shape=(5,),strides=(48,))
Out[300]: array([ 1.,  1.,  1.,  1.,  1.])

尽管有些人会认为as_strided比大多数numpy Python代码更接近“内部机理”,但我可以通过对扁平数组进行索引来完成相同的步进操作:
In [311]: X.ravel()[::6]
Out[311]: array([ 1.,  2.,  3.,  4.,  5.]) 

(这里的X值是通过一个view测试进行更改的。)

这是有用的信息,但 .as_strided 并不返回一个视图,而是返回一个副本。让 Y = np.lib.stride_tricks.as_strided(X,shape=(5,),strides=(48,)),然后 Y.base is X 的结果是 False - user40314
1
__array_interface__['data']的值是相同的。我不知道为什么.data属性会给出不同的hex位置。同时,改变对角线上的值也会改变原始数据中的值。 - hpaulj
有趣的是,你的回答揭示了步长不需要互相整除,就像切片形成时那样。特别地,下面这个ndarray是有效的:Z = np.lib.stride_tricks.as_strided(X,shape=(5,5),strides=(7,5)) - user40314
使用as_strided时要小心,它允许你在数据缓冲区之外索引内存。因此,如果不小心使用,可能会很危险。 - hpaulj
1
X.ravel()[::6] 做的是相同的事情 - 将数组视为1维,然后每隔6个元素取一个。 - hpaulj
显示剩余3条评论

2
如果您使用的是 numpy 1.9 或更高版本,并且只需要一个只读视图,则可以使用 numpy.diagonal。docstring 表示,在 numpy 的某个未来版本中,numpy.diagonal 将返回一个读/写视图,但这并不能帮助您现在。如果您需要一个读/写视图,则 @hpaulj 建议使用 as_strided。我建议使用类似于这样的语句。
diag = as_strided(a, shape=(min(a.shape),), strides=(sum(a.strides),))

请务必阅读as_strided文档字符串中的“注释”部分。

0
对于支持超过2个维度的版本:
import numpy as np


def diagonal_view(array, axis1=0, axis2=1):
    """Return a view of the array diagonal."""
    assert array.ndim >= 2
    axis1, axis2 = min([axis1, axis2]), max([axis1, axis2])
    shape = list(array.shape)
    new = min([shape[axis1], shape[axis2]])
    shape.pop(axis1)
    shape.pop(axis2 - 1)
    shape.append(new)
    strides = list(array.strides)
    new = strides[axis1] + strides[axis2]
    strides.pop(axis1)
    strides.pop(axis2 - 1)
    strides.append(new)
    diag = np.lib.stride_tricks.as_strided(array, shape=shape, strides=strides)
    return diag


def test_diagonal_view():
    # test correspondence with np.diagonal
    for array in [
            np.random.randn(10, 4),
            np.random.randn(10, 4).T,
            np.random.randn(10, 4, 8),
            np.random.randn(10, 4, 8).T,
            np.random.randn(10, 4, 8).swapaxes(0, 1),
            np.random.randn(3, 4, 8, 5),
            np.random.randn(3, 4, 8, 5).swapaxes(0, 2),
    ]:
        for axis1 in range(array.ndim):
            for axis2 in range(array.ndim):
                if axis1 != axis2:
                    result = diagonal_view(array, axis1=axis1, axis2=axis2)
                    # compare with np.diagonal
                    reference = np.diagonal(array, axis1=axis1, axis2=axis2)
                    np.testing.assert_array_equal(result, reference)
                    # test that this is a modifiable view
                    result += 1
                    reference = np.diagonal(array, axis1=axis1, axis2=axis2)
                    np.testing.assert_array_equal(result, reference)


test_diagonal_view()

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接