如何高效地将矩阵变换应用于NumPy数组的每一行?

5

假设我有一个2d的NumPy ndarray,如下:

[[ 0, 1, 2, 3 ],
 [ 4, 5, 6, 7 ],
 [ 8, 9, 10, 11 ]]

就概念而言,我想要做的是这样:

For each row:
    Transpose the row
    Multiply the transposed row by a transformation matrix
    Transpose the result
    Store the result in the original ndarray, overwriting the original row data

我有一种非常缓慢而暴力的方法,可以实现这个功能:

import numpy as np
transform_matrix = np.matrix( /* 4x4 matrix setup clipped for brevity */ )
for i, row in enumerate( data ):
    tr = row.reshape( ( 4, 1 ) )
    new_row = np.dot( transform_matrix, tr )
    data[i] = new_row.reshape( ( 1, 4 ) )

然而,这似乎是NumPy擅长处理的操作。我认为作为一个新手,在文档中可能会忽略一些基本的东西。有什么指点吗?

请注意,如果创建新的ndarray比原地编辑更快,那么对于我正在做的事情来说,这也是可行的;操作的速度是主要关注点。

2个回答

11
您想要执行的一系列操作等同于以下内容:
data[:] = data.dot(transform_matrix.T)

或者使用一个新数组而不是修改原始数组,这应该会更快一些:

data.dot(transform_matrix.T)

以下是说明:

这里是解释:

For each row:
    Transpose the row

相当于对矩阵进行转置,然后遍历列。

    Multiply the transposed row by a transformation matrix

将矩阵的每一列左乘第二个矩阵,等同于将整个矩阵左乘第二个矩阵。此时,您拥有的是transform_matrix.dot(data.T)

    Transpose the result

矩阵转置的基本属性之一是 transform_matrix.dot(data.T).T 等同于 data.dot(transform_matrix.T)

    Store the result in the original ndarray, overwriting the original row data

切片赋值可以实现此功能。

这里的切片是多余的,我认为。 - alko
1
太好了!感谢您详细的解释,非常有帮助。 - user3089880

4

看起来你需要使用转置操作符

>>> np.random.seed(11)
>>> transform_matrix = np.random.randint(1, 10, (4,4))
>>> np.dot(transform_matrix, data.T).T
matrix([[ 24,  24,  17,  37],
        [ 76, 108,  61, 137],
        [128, 192, 105, 237]])

或者等价地,作为(A*B).T = (B.T * A.T):
>>> np.dot(data, transform_matrix.T)

+1 指向矩阵代数 - 这不是编程问题,而是对矩阵代数的理解不足。 - DCS

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接