强制使用__rmul__()进行乘法运算而不是Numpy数组的__mul__()或绕过广播技术。

5
这个问题与 Overriding other __rmul__ with your class's __mul__中提出的问题相似,但我认为这不仅仅是数值数据的问题。此外该问题没有得到解答,我也不想使用矩阵乘法@来进行操作。因此,提出了这个问题。
我有一个对象,可以接受与标量和数字数组的乘法。通常情况下,左乘法很好用,因为它使用myobj()方法,但在右乘法中,NumPy使用广播规则,并给出元素结果,dtype=object
这也导致无法检查数组的大小是否兼容。
因此,问题是:

有没有一种方法可以强制NumPy数组查找其他对象的__rmul__()而不是广播并执行元素级__mul__()

在我的特定情况下,这个对象是一个MIMO(多输入,多输出)传递函数矩阵(或者说滤波器系数矩阵),因此矩阵乘法在加法和线性系统乘法方面具有特殊含义。因此,在每个条目中都有SISO系统。
import numpy as np

class myobj():
    def __init__(self):
        pass

    def __mul__(self, other):
        if isinstance(other, type(np.array([0.]))):
            if other.size == 1:
                print('Scalar multiplication')
            else:
                print('Multiplication of arrays')

    def __rmul__(self, other):
        if isinstance(other, type(np.array([0.]))):
            if other.size == 1:
                print('Scalar multiplication')
            else:
                print('Multiplication of arrays')

A = myobj()
a = np.array([[[1+1j]]])  # some generic scalar
B = np.random.rand(3, 3)

通过这些定义,以下命令显示了不希望出现的行为。

In [123]: A*a
Scalar multiplication

In [124]: a*A
Out[124]: array([[[None]]], dtype=object)

In [125]: B*A
Out[125]: 
array([[None, None, None],
       [None, None, None],
       [None, None, None]], dtype=object)

In [126]: A*B
Multiplication of arrays

In [127]: 5 * A

In [128]: A.__rmul__(B)  # This is the desired behavior for B*A
Multiplication of arrays

在你的链接中提到了 @,因为它是一个类似于 np.dot 的操作符,而不是解决 rmul 问题的地址。 - hpaulj
1
认真对待右侧关于子类化的讨论。使用mulrmul是Python语法基础问题。 - hpaulj
@hpaulj 这也是一种类似于 dot() 的操作,但我想以专门的方式处理行向量乘法。关键点在于广播。我对 mul 和 rmul 都很满意,因为如果不进行广播,numpy 将无法看到它如何进行乘法运算。请参见最后一个示例,它可以正常工作。 - percusse
2个回答

3

默认情况下,NumPy假定未知对象(不继承ndarray)是标量,并且需要在任何NumPy数组的每个元素上“向量化”乘法。

要自己控制操作,您需要设置__array_priority__(最向后兼容)或__array_ufunc__(仅适用于NumPy 1.13+)。例如:

class myworkingobj(myobj):
    __array_priority__ = 1000

A = myworkingobj()
B = np.random.rand(3, 3)
B * A  # Multiplication of arrays

这确实是一个好消息。非常感谢你。 - percusse

1
我会尝试演示正在发生的事情。
In [494]: B=np.random.rand(3,3)

基础类:

In [497]: class myobj():
     ...:     pass
     ...: 
In [498]: B*myobj()
...

TypeError: unsupported operand type(s) for *: 'float' and 'myobj'

添加一个__mul__
In [500]: class myobj():
     ...:     pass
     ...:     def __mul__(self,other):
     ...:         print('myobj mul')
     ...:         return 12.3
     ...: 
In [501]: B*myobj()
...
TypeError: unsupported operand type(s) for *: 'float' and 'myobj'
In [502]: myobj()*B
myobj mul
Out[502]: 12.3

添加一个 rmul:
In [515]: class myobj():
     ...:     pass
     ...:     def __mul__(self,other):
     ...:         print('myobj mul',other)
     ...:         return 12.3
     ...:     def __rmul__(self,other):
     ...:         print('myobj rmul',other)
     ...:         return 4.32
     ...: 
In [516]: B*myobj()
myobj rmul 0.792751549595306
myobj rmul 0.5668783619454384
myobj rmul 0.2196204913660168
myobj rmul 0.5474970289273348
myobj rmul 0.2079367474424587
myobj rmul 0.5374571198848628
myobj rmul 0.35748803226628456
myobj rmul 0.41306113085906715
myobj rmul 0.499598995529441
Out[516]: 
array([[4.32, 4.32, 4.32],
       [4.32, 4.32, 4.32],
       [4.32, 4.32, 4.32]], dtype=object)

B*myobj()被赋给B,作为B.__mul__(myobj()),它会对B的每个元素执行myobj().__rmul__(i)

myobj()*B中,翻译为myobj.__mul__(B)

In [517]: myobj()*B
myobj mul [[ 0.79275155  0.56687836  0.21962049]
 [ 0.54749703  0.20793675  0.53745712]
 [ 0.35748803  0.41306113  0.499599  ]]
Out[517]: 12.3

In [518]: myobj().__rmul__(B)
myobj rmul [[ 0.79275155  0.56687836  0.21962049]
 [ 0.54749703  0.20793675  0.53745712]
 [ 0.35748803  0.41306113  0.499599  ]]
Out[518]: 4.32

myobj中,无法覆盖B*myobj()翻译成B.__mul__(myobj())的任何操作。如果需要更多的操作控制,请使用函数或方法。与解释器作对很困难。

我觉得我无法很好地解释。我不是试图取消NumPy的乘法。我是在尝试阻止它向单个乘法广播。其他方面,我已经非常熟悉了。 - percusse
“broadcasting” 是由 B.__mul__ 完成的。您可能需要对 ndarray 进行子类化,并覆盖 __mul__ 方法? - hpaulj
然后我必须强制用户使用子类而不是常规数组,这就破坏了mul和rmul整个实现的目的。 - percusse

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接