修改多维numpy数组中的对角线

3
我有一个形状为(7,3,7,3)的多维numpy数组,我想修改其中轴0和轴2重合的广义对角线。这个广义对角线将被定义为数组中那些0号和2号索引相同的元素,并且将具有形状(3,3,7)。
做法:
arr.diagonal(axis1=0, axis2=2)

我可以访问对角线元素,但是至少在numpy的1.8.2版本中,我无法“原地”修改它们。据Numpy文档说明,在1.10版本中可能会实现这一点。然而,由于我依赖其他人使用相同的代码,因此更新到numpy 1.10不是一个选项。文档还建议使用.copy()来获得可移植的解决方案,但是.copy()将复制数组,但如果我想修改原始数组的对角线,则没有用处。
另外,我尝试直接索引对角线元素[使用从numpy.indices((7,3,7,3))获取的输入],但没有成功。
在numpy 1.8.2中,如何访问广义对角线的元素以修改原始数组?
1个回答

3
使用numpy.lib.stride_tricks模块中的as_strided函数,可以创建一个通用的对角线视图。与两个轴相关的对角线所在轴的步幅是这些轴步幅之和。
例如:
In [196]: from numpy.lib.stride_tricks import as_strided

创建一个形状为(7,3,7,3)的数组:
In [197]: a = np.arange(21*21).reshape(7,3,7,3)

In [198]: a[5, :, 5, :]
Out[198]: 
array([[330, 331, 332],
       [351, 352, 353],
       [372, 373, 374]])

创建一个与轴0和2相关联的“对角线”视图。该视图的形状为(3,3,7):
In [199]: d = as_strided(a, strides=(a.strides[1], a.strides[3], a.strides[0] + a.strides[2]), shape=(3, 3, 7))

请检查,例如d[:, :, 5]是否与a[5, :, 5, :]相同:

In [200]: d[:, :, 5]
Out[200]: 
array([[330, 331, 332],
       [351, 352, 353],
       [372, 373, 374]])

修改 d 并观察是否会改变 a,以验证 d 是否为 a 的视图:

In [201]: d[1, 1, 5] = -1

In [202]: a[5, :, 5, :]
Out[202]: 
array([[330, 331, 332],
       [351,  -1, 353],
       [372, 373, 374]])

使用as_strided时要小心!如果参数不正确,可能会写入a外的内存,可能导致Python崩溃。


网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接