如何在Tensorflow 2.x Keras自定义层中使用多个输入?

11

我正在尝试在Tensorflow-Keras的自定义层中使用多个输入。使用方法可以是任何东西,现在它被定义为将掩码与图像相乘。我已经在Stack Overflow上搜索过,但是我找到的唯一答案是针对TF 1.x的,所以没有帮助。

class mul(layers.Layer):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        # I've added pass because this is the simplest form I can come up with.
        pass
          
    def call(self, inputs):
        # magic happens here and multiplications occur
        return(Z)
2个回答

15

编辑:自从TensorFlow v2.3/2.4版本以来,约定使用输入列表作为call方法的参数。对于keras(而不是tf.keras),我认为下面的回答仍然适用。

实现多个输入可以在您的类的call方法中完成,有两种选择:

  • 列表输入,这里期望inputs参数是包含所有输入的列表,优点是它可以是可变大小的。您可以通过索引列表或使用=运算符解压参数:

      def call(self, inputs):
          Z = inputs[0] * inputs[1]
    
          #Alternate
          input1, input2 = inputs
          Z = input1 * input2
    
          return Z
    
  • call 方法中的多个输入参数可行,但在定义层时参数数量固定:

  •   def call(self, input1, input2):
          Z = input1 * input2
    
          return Z
    

无论你选择哪种方法来实现它,这取决于你需要固定大小还是可变大小的参数。当然,每种方法都会改变调用层所需的方式,可以通过传递参数列表或在函数调用中逐个传递参数来实现。

在第一种方法中,您还可以使用*args 来允许具有可变数量的参数的 call 方法,但总体来说,像 ConcatenateAdd 这样需要多个输入的 Keras 层是使用列表实现的。


2
你必须使用列表,而不是多个参数。请参阅此“文档”:https://github.com/tensorflow/tensorflow/blob/v2.4.0/tensorflow/python/keras/engine/base_layer.py#L930-L941 - max
2
多个输入参数违反了tf.keras.Layer.call()的契约(https://www.tensorflow.org/api_docs/python/tf/keras/layers/Layer#call),该契约明确规定`inputs`应为多个输入张量的列表/元组。 - Grisha

4

试着用这种方式

class mul(layers.Layer):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        # I've added pass because this is the simplest form I can come up with.
        pass

    def call(self, inputs):
        inp1, inp2 = inputs
        Z = inp1*inp2
        return Z

inp1 = Input((10))
inp2 = Input((10))
x = mul()([inp1,inp2])
x = Dense(1)(x)
model = Model([inp1,inp2],x)
model.summary()

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接