使numpy中的矩阵乘法运算符@适用于标量

20
在Python 3.5中,引入了矩阵乘法运算符@,该运算符遵循PEP465。例如,在numpy中,这是通过matmul运算符实现的。
然而,正如PEP所提出的,当使用标量操作数调用numpy运算符时,会抛出异常。
>>> import numpy as np
>>> np.array([[1,2],[3,4]]) @ np.array([[1,2],[3,4]])    # works
array([[ 7, 10],
       [15, 22]])
>>> 1 @ 2                                                # doesn't work
Traceback (most recent call last):
  File "<input>", line 1, in <module>
TypeError: unsupported operand type(s) for @: 'int' and 'int'

这对我来说是一个真正的难题,因为我正在实现数值信号处理算法,应该适用于标量和矩阵。两种情况下的方程在数学上完全等价,这并不奇怪,因为“1-D x 1-D matrix multiplication”等同于标量乘法。然而,当前状态迫使我编写重复的代码以正确处理两种情况。
因此,考虑到当前状态并不令人满意,有没有任何合理的方法可以使@运算符适用于标量?我考虑为标量数据类型添加自定义__matmul__(self, other)方法,但考虑到涉及的内部数据类型数量,这似乎很麻烦。我能否更改numpy数组数据类型的__matmul__方法的实现,以便不会为1x1数组操作数抛出异常?
另外,这个设计决策的原理是什么呢?一时之间,我想不出任何不实现该运算符的强有力理由。

1
[1] @ [2] 怎么样?标量已经有了 *,为什么要重复呢? - furas
10
听起来真正的问题是你的代码有时返回标量,有时返回矩阵。为什么不重构代码,使其返回1 x 1矩阵而不是标量呢?或者编写一个快速函数,将矩阵或标量作为输入,返回该矩阵或一个包含该标量的1x1矩阵。 - Patrick Haugh
4
关于为什么,从您提供的 PEP 中可以看到:“0d(标量)输入会引发错误。标量与矩阵相乘是一种在数学和算法上都与矩阵乘矩阵操作不同的操作,并且已经由逐元素的*运算符覆盖。因此,允许标量@矩阵既需要一个不必要的特例,又违反了“用任何方法都可以完成任务”的原则。” - Patrick Haugh
2
当我写 [1] @ [2] 时,我想到的是类似于 np.array([1]) @ np.array([2]) 的东西,但我应该描述一下 :) - furas
1
也许你可以使用 atleast_1d,例如 np.atleast_1d(5) @ np.atleast_1d(6) 就可以正常工作。 - Alex Riley
显示剩余6条评论
1个回答

9

正如ajcr所建议的那样,您可以通过强制对被乘对象进行最小维度处理来解决此问题。有两个合理的选择:atleast_1datleast_2d,它们在返回@的类型方面具有不同的结果:标量与1x1 2D数组。

x = 3
y = 5
z = np.atleast_1d(x) @ np.atleast_1d(y)   # returns 15 
z = np.atleast_2d(x) @ np.atleast_2d(y)   # returns array([[15]])

然而:
  • 如果x和y是1D数组,本来可以正常相乘,但使用atleast_2d会导致错误
  • 使用atleast_1d会得到一个标量或矩阵的积,但你不知道哪种情况
  • 这两种方法都比np.dot(x, y)更冗长,而np.dot(x, y)可以处理所有情况
此外,atleast_1d版本也有与标量@标量 = 标量相同的缺陷:你不知道输出可以做什么。z.T或z.shape会抛出错误吗?它们适用于1x1矩阵,但不适用于标量。在Python环境下,如果忽略标量和1x1数组之间的区别,也就必须放弃后者具有的所有方法和属性。

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接