TF2.1: SegNet model architecture problem. Bug in the metric calculation: it stays constant and converges to a fixed value


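I'm building a custom model (SegNet) in TensorFlow 2.1.0.

The first problem is reusing the indices of the max-pooling operation, as described in the paper. Because this is an encoder-decoder architecture, the decoder has to upsample the feature maps while placing each value back at the position recorded by the corresponding pooling indices.

In TF, however, the tf.keras.layers.MaxPool2D layer does not export these indices by default (as PyTorch can, for example). To get the indices of the max-pooling operation, you need tf.nn.max_pool_with_argmax. That op, though, returns the indices (argmax) in a flattened format, which needs further processing before it can be used elsewhere in the network.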
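To make the flattened format concrete, here is a minimal framework-free sketch of the index arithmetic. It follows the formula given in the tf.nn.max_pool_with_argmax documentation: `(y * width + x) * channels + c` by default, with the batch position folded in when `include_batch_in_index=True`. The helper names are hypothetical and exist only for illustration.

```python
def unravel_argmax(index, width, channels):
    """Decode a flattened argmax (batch not included) into (y, x, c)."""
    c = index % channels
    x = (index // channels) % width
    y = index // (width * channels)
    return y, x, c


def unravel_argmax_with_batch(index, height, width, channels):
    """Decode a flattened argmax that also encodes the batch position."""
    per_sample = height * width * channels
    b = index // per_sample
    y, x, c = unravel_argmax(index % per_sample, width, channels)
    return b, y, x, c


# A maximum at row 3, column 5, channel 2 of a 208 x 208 x 64 feature map.
flat = (3 * 208 + 5) * 64 + 2
print(unravel_argmax(flat, width=208, channels=64))  # (3, 5, 2)
```

Note that the default variant carries no batch information at all, which is why any code that rebuilds 4-D positions from it has to obtain the batch index some other way.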

To implement a layer that performs MaxPooling2D and exports these (flattened) indices, I defined a custom layer in Keras.

class MaxPoolingWithArgmax2D(Layer):

    def __init__(
            self,
            pool_size=(2, 2),
            strides=2,
            padding='same',
            **kwargs):
        super(MaxPoolingWithArgmax2D, self).__init__(**kwargs)
        self.padding = padding
        self.pool_size = pool_size
        self.strides = strides

    def call(self, inputs, **kwargs):
        padding = self.padding
        pool_size = self.pool_size
        strides = self.strides
        # argmax holds the flattened position of each pooled maximum
        output, argmax = tf.nn.max_pool_with_argmax(
            inputs,
            ksize=pool_size,
            strides=strides,
            padding=padding.upper(),
            output_dtype=tf.int64)
        return output, argmax

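Obviously, this layer is used in the encoding part of the network, so a matching decoding layer is needed to perform the inverse operation (UpSampling2D) using the indices (see the paper for more details on this operation).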
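As a reference for what the inverse operation is expected to compute, here is a small pure-Python sketch on nested lists. It is not TensorFlow code, and every function name in it is hypothetical. It pools one H x W x C image while recording flattened indices in the `(y * W + x) * C + c` layout, then scatters the pooled values back to those positions and leaves zeros everywhere else. The batch variant simply loops over samples, so no batch size is ever hard-coded.

```python
def max_pool_with_argmax(image, size=2):
    """Pool one H x W x C image; also return flattened argmax indices."""
    height, width, channels = len(image), len(image[0]), len(image[0][0])
    pooled, argmax = [], []
    for top in range(0, height, size):
        pooled_row, argmax_row = [], []
        for left in range(0, width, size):
            values, indices = [], []
            for c in range(channels):
                # (value, flattened index) for every cell in the window
                window = [
                    (image[y][x][c], (y * width + x) * channels + c)
                    for y in range(top, min(top + size, height))
                    for x in range(left, min(left + size, width))
                ]
                best_value, best_index = max(window, key=lambda item: item[0])
                values.append(best_value)
                indices.append(best_index)
            pooled_row.append(values)
            argmax_row.append(indices)
        pooled.append(pooled_row)
        argmax.append(argmax_row)
    return pooled, argmax


def max_unpool(pooled, argmax, out_height, out_width):
    """Scatter pooled values back to their recorded positions."""
    channels = len(pooled[0][0])
    flat = [0.0] * (out_height * out_width * channels)
    for pooled_row, argmax_row in zip(pooled, argmax):
        for values, indices in zip(pooled_row, argmax_row):
            for value, index in zip(values, indices):
                flat[index] = value
    # Reshape the flat buffer back to H x W x C.
    return [
        [[flat[(y * out_width + x) * channels + c] for c in range(channels)]
         for x in range(out_width)]
        for y in range(out_height)
    ]


def max_unpool_batch(pooled_batch, argmax_batch, out_height, out_width):
    """Unpool sample by sample, whatever the batch size happens to be."""
    return [max_unpool(p, a, out_height, out_width)
            for p, a in zip(pooled_batch, argmax_batch)]


# 4 x 4 single-channel image holding the values 1..16.
image = [[[float(4 * y + x + 1)] for x in range(4)] for y in range(4)]
pooled, argmax = max_pool_with_argmax(image)
restored = max_unpool(pooled, argmax, 4, 4)
```

Here `restored` contains 6, 8, 14 and 16 at their original positions and 0.0 elsewhere, which is the sparse feature map the decoder convolutions are then meant to densify.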

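After some research I found legacy code (TF<2.1.0) and adapted it to perform the operation. Even so, I'm not fully convinced that this code works well; in fact, there are a few things about it I don't like.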

class MaxUnpooling2D(Layer):
    def __init__(self, size=(2, 2), **kwargs):
        super(MaxUnpooling2D, self).__init__(**kwargs)
        self.size = size

    def call(self, inputs, output_shape=None):
        updates, mask = inputs[0], inputs[1]
        with tf.name_scope(self.name):
            mask = tf.cast(mask, 'int32')
            #input_shape = tf.shape(updates, out_type='int32')
            input_shape = updates.get_shape()

            # This statement is required if I don't want to specify a batch size
            if input_shape[0] is None:
                batches = 1
            else:
                batches = input_shape[0]

            # calculate the new shape
            if output_shape is None:
                output_shape = (
                        batches,
                        input_shape[1]*self.size[0],
                        input_shape[2]*self.size[1],
                        input_shape[3])

            # calculate indices for batch, height, width and feature maps
            one_like_mask = tf.ones_like(mask, dtype='int32')
            batch_shape = tf.concat(
                    [[batches], [1], [1], [1]],
                    axis=0)
            batch_range = tf.reshape(
                    tf.range(output_shape[0], dtype='int32'),
                    shape=batch_shape)
            b = one_like_mask * batch_range
            y = mask // (output_shape[2] * output_shape[3])
            x = (mask // output_shape[3]) % output_shape[2]
            feature_range = tf.range(output_shape[3], dtype='int32')
            f = one_like_mask * feature_range

            # transpose indices & reshape update values to one dimension
            updates_size = tf.size(updates)
            indices = tf.transpose(tf.reshape(
                tf.stack([b, y, x, f]),
                [4, updates_size]))
            values = tf.reshape(updates, [updates_size])
            ret = tf.scatter_nd(indices, values, output_shape)
            return ret

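My concerns are:

  1. Un-flattening the indices (MaxUnpooling2D) strictly requires a known batch size, but for model validation I would like it to be None, i.e. unspecified.
  2. I'm not sure this code is actually fully compatible with the rest of the library. In fact, during fit, if I use tf.keras.metrics.MeanIoU, the value converges to 0.341 and stays constant for every epoch after the first. The standard accuracy metric, by contrast, works fine.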
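One thing that may be worth checking for the second concern, offered as an unconfirmed hypothesis rather than a diagnosis: a mean-IoU metric is built from a confusion matrix over integer class ids. If one-hot targets and raw softmax probabilities were passed to it directly (an assumption, since the question does not show how the labels are encoded or how the metric is wired up), the probabilities would be truncated to integers and the metric would barely react to what the network predicts. The pure-Python sketch below computes mean IoU by hand and compares the two situations. All names and the toy data are hypothetical.

```python
def mean_iou(y_true, y_pred, num_classes):
    """Mean IoU from a confusion matrix over integer class ids."""
    confusion = [[0] * num_classes for _ in range(num_classes)]
    for t, p in zip(y_true, y_pred):
        confusion[t][p] += 1
    ious = []
    for c in range(num_classes):
        true_pos = confusion[c][c]
        false_neg = sum(confusion[c]) - true_pos
        false_pos = sum(row[c] for row in confusion) - true_pos
        denom = true_pos + false_pos + false_neg
        if denom:  # classes that never occur are left out of the mean
            ious.append(true_pos / denom)
    return sum(ious) / len(ious) if ious else 0.0


labels = [0, 1, 2, 1]                       # one class id per pixel
probs = [[0.8, 0.1, 0.1], [0.2, 0.7, 0.1],
         [0.1, 0.2, 0.7], [0.1, 0.3, 0.6]]  # softmax output per pixel

# Intended use: compare class ids with the argmax of the probabilities.
predicted = [row.index(max(row)) for row in probs]
iou_from_ids = mean_iou(labels, predicted, 3)

# Hypothetical misuse: flatten one-hot labels and raw probabilities,
# then truncate to int, so every probability below 1.0 becomes class 0.
one_hot_flat = [int(c == t) for t in labels for c in range(3)]
truncated = [int(p) for row in probs for p in row]
iou_from_raw = mean_iou(one_hot_flat, truncated, 3)

print(iou_from_ids, iou_from_raw)
```

In this toy case the second value is 1/3 for three classes no matter how good the probabilities are, which is at least in the neighbourhood of the plateau described above. Whether that is what happens in the actual training run cannot be determined from the question alone.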

A closer look at the network architecture

Here is the full definition of the model.

import tensorflow as tf
import tensorflow.keras as keras
import tensorflow.keras.layers as layers
from tensorflow.keras.layers import Layer


class SegNet:
    def __init__(self, data_shape, classes = 3, batch_size = None):
        self.MODEL_NAME = 'SegNet'
        self.MODEL_VERSION = '0.2'

        self.classes = classes
        self.batch_size = batch_size

        self.build_model(data_shape)

    def build_model(self, data_shape):
        input_shape = (data_shape, data_shape, 3)

        inputs = keras.Input(shape=input_shape, batch_size=self.batch_size, name='Input')

        # Build sequential model

        # Encoding
        encoders = 5
        feature_maps = [64, 128, 256, 512, 512]
        n_convolutions = [2, 2, 3, 3, 3]
        eb_input = inputs
        eb_argmax_indices = []
        for encoder_index in range(encoders):
            encoder_block, argmax_indices = self.encoder_block(
                eb_input, encoder_index, feature_maps[encoder_index], n_convolutions[encoder_index])
            eb_argmax_indices.append(argmax_indices)
            eb_input = encoder_block

        # Decoding
        decoders = encoders
        db_input = encoder_block
        eb_argmax_indices.reverse()
        feature_maps.reverse()
        n_convolutions.reverse()
        d_feature_maps = [512, 512, 256, 128, 64]
        d_n_convolutions = n_convolutions
        for decoder_index in range(decoders):
            decoder_block = self.decoder_block(
                db_input, eb_argmax_indices[decoder_index], decoder_index, d_feature_maps[decoder_index], d_n_convolutions[decoder_index])
            db_input = decoder_block

        output = layers.Softmax()(decoder_block)

        self.model = keras.Model(inputs=inputs, outputs=output, name="SegNet")

    def encoder_block(self, x, encoder_index, feature_maps, n_convolutions):
        bank_input = x
        for conv_index in range(n_convolutions):
            bank = self.eb_layers_bank(
                bank_input, conv_index, feature_maps, encoder_index)
            bank_input = bank

        max_pool, indices = MaxPoolingWithArgmax2D(pool_size=(
            2, 2), strides=2, padding='same', name='EB_{}_MPOOL'.format(encoder_index + 1))(bank)

        return max_pool, indices

    def eb_layers_bank(self, x, bank_index, feature_maps, encoder_index):

        bank_input = x

        conv_l = layers.Conv2D(feature_maps, (3, 3), padding='same', name='EB_{}_BANK_{}_CONV'.format(
            encoder_index + 1, bank_index + 1))(bank_input)
        batch_norm = layers.BatchNormalization(
            name='EB_{}_BANK_{}_BN'.format(encoder_index + 1, bank_index + 1))(conv_l)
        relu = layers.ReLU(name='EB_{}_BANK_{}_RL'.format(
            encoder_index + 1, bank_index + 1))(batch_norm)

        return relu

    def decoder_block(self, x, max_pooling_idices, decoder_index, feature_maps, n_convolutions):
        #bank_input = self.unpool_with_argmax(x, max_pooling_idices)
        bank_input = MaxUnpooling2D(name='DB_{}_UPSAMP'.format(decoder_index + 1))([x, max_pooling_idices])
        #bank_input = layers.UpSampling2D()(x)
        for conv_index in range(n_convolutions):
            if conv_index == n_convolutions - 1:
                last_l_banck = True
            else:
                last_l_banck = False
            bank = self.db_layers_bank(
                bank_input, conv_index, feature_maps, decoder_index, last_l_banck)
            bank_input = bank

        return bank

    def db_layers_bank(self, x, bank_index, feature_maps, decoder_index, last_l_bank):
        bank_input = x

        if (last_l_bank) & (decoder_index == 4):
            conv_l = layers.Conv2D(self.classes, (1, 1), padding='same', name='DB_{}_BANK_{}_CONV'.format(
                decoder_index + 1, bank_index + 1))(bank_input)
            #batch_norm = layers.BatchNormalization(
            #    name='DB_{}_BANK_{}_BN'.format(decoder_index + 1, bank_index + 1))(conv_l)
            return conv_l
        else:

            if (last_l_bank) & (decoder_index > 0):
                conv_l = layers.Conv2D(int(feature_maps / 2), (3, 3), padding='same', name='DB_{}_BANK_{}_CONV'.format(
                    decoder_index + 1, bank_index + 1))(bank_input)
            else:
                conv_l = layers.Conv2D(feature_maps, (3, 3), padding='same', name='DB_{}_BANK_{}_CONV'.format(
                    decoder_index + 1, bank_index + 1))(bank_input)
            batch_norm = layers.BatchNormalization(
                name='DB_{}_BANK_{}_BN'.format(decoder_index + 1, bank_index + 1))(conv_l)
            relu = layers.ReLU(name='DB_{}_BANK_{}_RL'.format(
                decoder_index + 1, bank_index + 1))(batch_norm)

            return relu

    def get_model(self):
        return self.model

Here is the output of model.summary().

Model: "SegNet"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
==================================================================================================
Input (InputLayer)              [(None, 416, 416, 3) 0                                            
__________________________________________________________________________________________________
EB_1_BANK_1_CONV (Conv2D)       (None, 416, 416, 64) 1792        Input[0][0]                      
__________________________________________________________________________________________________
EB_1_BANK_1_BN (BatchNormalizat (None, 416, 416, 64) 256         EB_1_BANK_1_CONV[0][0]           
__________________________________________________________________________________________________
EB_1_BANK_1_RL (ReLU)           (None, 416, 416, 64) 0           EB_1_BANK_1_BN[0][0]             
__________________________________________________________________________________________________
EB_1_BANK_2_CONV (Conv2D)       (None, 416, 416, 64) 36928       EB_1_BANK_1_RL[0][0]             
__________________________________________________________________________________________________
EB_1_BANK_2_BN (BatchNormalizat (None, 416, 416, 64) 256         EB_1_BANK_2_CONV[0][0]           
__________________________________________________________________________________________________
EB_1_BANK_2_RL (ReLU)           (None, 416, 416, 64) 0           EB_1_BANK_2_BN[0][0]             
__________________________________________________________________________________________________
EB_1_MPOOL (MaxPoolingWithArgma ((None, 208, 208, 64 0           EB_1_BANK_2_RL[0][0]             
__________________________________________________________________________________________________
EB_2_BANK_1_CONV (Conv2D)       (None, 208, 208, 128 73856       EB_1_MPOOL[0][0]                 
__________________________________________________________________________________________________
EB_2_BANK_1_BN (BatchNormalizat (None, 208, 208, 128 512         EB_2_BANK_1_CONV[0][0]           
__________________________________________________________________________________________________
EB_2_BANK_1_RL (ReLU)           (None, 208, 208, 128 0           EB_2_BANK_1_BN[0][0]             
__________________________________________________________________________________________________
EB_2_BANK_2_CONV (Conv2D)       (None, 208, 208, 128 147584      EB_2_BANK_1_RL[0][0]             
__________________________________________________________________________________________________
EB_2_BANK_2_BN (BatchNormalizat (None, 208, 208, 128 512         EB_2_BANK_2_CONV[0][0]           
__________________________________________________________________________________________________
EB_2_BANK_2_RL (ReLU)           (None, 208, 208, 128 0           EB_2_BANK_2_BN[0][0]             
__________________________________________________________________________________________________
EB_2_MPOOL (MaxPoolingWithArgma ((None, 104, 104, 12 0           EB_2_BANK_2_RL[0][0]             
__________________________________________________________________________________________________
EB_3_BANK_1_CONV (Conv2D)       (None, 104, 104, 256 295168      EB_2_MPOOL[0][0]                 
__________________________________________________________________________________________________
EB_3_BANK_1_BN (BatchNormalizat (None, 104, 104, 256 1024        EB_3_BANK_1_CONV[0][0]           
__________________________________________________________________________________________________
EB_3_BANK_1_RL (ReLU)           (None, 104, 104, 256 0           EB_3_BANK_1_BN[0][0]             
__________________________________________________________________________________________________
EB_3_BANK_2_CONV (Conv2D)       (None, 104, 104, 256 590080      EB_3_BANK_1_RL[0][0]             
__________________________________________________________________________________________________
EB_3_BANK_2_BN (BatchNormalizat (None, 104, 104, 256 1024        EB_3_BANK_2_CONV[0][0]           
__________________________________________________________________________________________________
EB_3_BANK_2_RL (ReLU)           (None, 104, 104, 256 0           EB_3_BANK_2_BN[0][0]             
__________________________________________________________________________________________________
EB_3_BANK_3_CONV (Conv2D)       (None, 104, 104, 256 590080      EB_3_BANK_2_RL[0][0]             
__________________________________________________________________________________________________
EB_3_BANK_3_BN (BatchNormalizat (None, 104, 104, 256 1024        EB_3_BANK_3_CONV[0][0]           
__________________________________________________________________________________________________
EB_3_BANK_3_RL (ReLU)           (None, 104, 104, 256 0           EB_3_BANK_3_BN[0][0]             
__________________________________________________________________________________________________
EB_3_MPOOL (MaxPoolingWithArgma ((None, 52, 52, 256) 0           EB_3_BANK_3_RL[0][0]             
__________________________________________________________________________________________________
EB_4_BANK_1_CONV (Conv2D)       (None, 52, 52, 512)  1180160     EB_3_MPOOL[0][0]                 
__________________________________________________________________________________________________
EB_4_BANK_1_BN (BatchNormalizat (None, 52, 52, 512)  2048        EB_4_BANK_1_CONV[0][0]           
__________________________________________________________________________________________________
EB_4_BANK_1_RL (ReLU)           (None, 52, 52, 512)  0           EB_4_BANK_1_BN[0][0]             
__________________________________________________________________________________________________
EB_4_BANK_2_CONV (Conv2D)       (None, 52, 52, 512)  2359808     EB_4_BANK_1_RL[0][0]             
__________________________________________________________________________________________________
EB_4_BANK_2_BN (BatchNormalizat (None, 52, 52, 512)  2048        EB_4_BANK_2_CONV[0][0]           
__________________________________________________________________________________________________
EB_4_BANK_2_RL (ReLU)           (None, 52, 52, 512)  0           EB_4_BANK_2_BN[0][0]             
__________________________________________________________________________________________________
EB_4_BANK_3_CONV (Conv2D)       (None, 52, 52, 512)  2359808     EB_4_BANK_2_RL[0][0]             
__________________________________________________________________________________________________
EB_4_BANK_3_BN (BatchNormalizat (None, 52, 52, 512)  2048        EB_4_BANK_3_CONV[0][0]           
__________________________________________________________________________________________________
EB_4_BANK_3_RL (ReLU)           (None, 52, 52, 512)  0           EB_4_BANK_3_BN[0][0]             
__________________________________________________________________________________________________
EB_4_MPOOL (MaxPoolingWithArgma ((None, 26, 26, 512) 0           EB_4_BANK_3_RL[0][0]             
__________________________________________________________________________________________________
EB_5_BANK_1_CONV (Conv2D)       (None, 26, 26, 512)  2359808     EB_4_MPOOL[0][0]                 
__________________________________________________________________________________________________
EB_5_BANK_1_BN (BatchNormalizat (None, 26, 26, 512)  2048        EB_5_BANK_1_CONV[0][0]           
__________________________________________________________________________________________________
EB_5_BANK_1_RL (ReLU)           (None, 26, 26, 512)  0           EB_5_BANK_1_BN[0][0]             
__________________________________________________________________________________________________
EB_5_BANK_2_CONV (Conv2D)       (None, 26, 26, 512)  2359808     EB_5_BANK_1_RL[0][0]             
__________________________________________________________________________________________________
EB_5_BANK_2_BN (BatchNormalizat (None, 26, 26, 512)  2048        EB_5_BANK_2_CONV[0][0]           
__________________________________________________________________________________________________
EB_5_BANK_2_RL (ReLU)           (None, 26, 26, 512)  0           EB_5_BANK_2_BN[0][0]             
__________________________________________________________________________________________________
EB_5_BANK_3_CONV (Conv2D)       (None, 26, 26, 512)  2359808     EB_5_BANK_2_RL[0][0]             
__________________________________________________________________________________________________
EB_5_BANK_3_BN (BatchNormalizat (None, 26, 26, 512)  2048        EB_5_BANK_3_CONV[0][0]           
__________________________________________________________________________________________________
EB_5_BANK_3_RL (ReLU)           (None, 26, 26, 512)  0           EB_5_BANK_3_BN[0][0]             
__________________________________________________________________________________________________
EB_5_MPOOL (MaxPoolingWithArgma ((None, 13, 13, 512) 0           EB_5_BANK_3_RL[0][0]             
__________________________________________________________________________________________________
DB_1_UPSAMP (MaxUnpooling2D)    (1, 26, 26, 512)     0           EB_5_MPOOL[0][0]                 
                                                                 EB_5_MPOOL[0][1]                 
__________________________________________________________________________________________________
DB_1_BANK_1_CONV (Conv2D)       (1, 26, 26, 512)     2359808     DB_1_UPSAMP[0][0]                
__________________________________________________________________________________________________
DB_1_BANK_1_BN (BatchNormalizat (1, 26, 26, 512)     2048        DB_1_BANK_1_CONV[0][0]           
__________________________________________________________________________________________________
DB_1_BANK_1_RL (ReLU)           (1, 26, 26, 512)     0           DB_1_BANK_1_BN[0][0]             
__________________________________________________________________________________________________
DB_1_BANK_2_CONV (Conv2D)       (1, 26, 26, 512)     2359808     DB_1_BANK_1_RL[0][0]             
__________________________________________________________________________________________________
DB_1_BANK_2_BN (BatchNormalizat (1, 26, 26, 512)     2048        DB_1_BANK_2_CONV[0][0]           
__________________________________________________________________________________________________
DB_1_BANK_2_RL (ReLU)           (1, 26, 26, 512)     0           DB_1_BANK_2_BN[0][0]             
__________________________________________________________________________________________________
DB_1_BANK_3_CONV (Conv2D)       (1, 26, 26, 512)     2359808     DB_1_BANK_2_RL[0][0]             
__________________________________________________________________________________________________
DB_1_BANK_3_BN (BatchNormalizat (1, 26, 26, 512)     2048        DB_1_BANK_3_CONV[0][0]           
__________________________________________________________________________________________________
DB_1_BANK_3_RL (ReLU)           (1, 26, 26, 512)     0           DB_1_BANK_3_BN[0][0]             
__________________________________________________________________________________________________
DB_2_UPSAMP (MaxUnpooling2D)    (1, 52, 52, 512)     0           DB_1_BANK_3_RL[0][0]             
                                                                 EB_4_MPOOL[0][1]                 
__________________________________________________________________________________________________
DB_2_BANK_1_CONV (Conv2D)       (1, 52, 52, 512)     2359808     DB_2_UPSAMP[0][0]                
__________________________________________________________________________________________________
DB_2_BANK_1_BN (BatchNormalizat (1, 52, 52, 512)     2048        DB_2_BANK_1_CONV[0][0]           
__________________________________________________________________________________________________
DB_2_BANK_1_RL (ReLU)           (1, 52, 52, 512)     0           DB_2_BANK_1_BN[0][0]             
__________________________________________________________________________________________________
DB_2_BANK_2_CONV (Conv2D)       (1, 52, 52, 512)     2359808     DB_2_BANK_1_RL[0][0]             
__________________________________________________________________________________________________
DB_2_BANK_2_BN (BatchNormalizat (1, 52, 52, 512)     2048        DB_2_BANK_2_CONV[0][0]           
__________________________________________________________________________________________________
DB_2_BANK_2_RL (ReLU)           (1, 52, 52, 512)     0           DB_2_BANK_2_BN[0][0]             
__________________________________________________________________________________________________
DB_2_BANK_3_CONV (Conv2D)       (1, 52, 52, 256)     1179904     DB_2_BANK_2_RL[0][0]             
__________________________________________________________________________________________________
DB_2_BANK_3_BN (BatchNormalizat (1, 52, 52, 256)     1024        DB_2_BANK_3_CONV[0][0]           
__________________________________________________________________________________________________
DB_2_BANK_3_RL (ReLU)           (1, 52, 52, 256)     0           DB_2_BANK_3_BN[0][0]             
__________________________________________________________________________________________________
DB_3_UPSAMP (MaxUnpooling2D)    (1, 104, 104, 256)   0           DB_2_BANK_3_RL[0][0]             
                                                                 EB_3_MPOOL[0][1]                 
__________________________________________________________________________________________________
DB_3_BANK_1_CONV (Conv2D)       (1, 104, 104, 256)   590080      DB_3_UPSAMP[0][0]                
__________________________________________________________________________________________________
DB_3_BANK_1_BN (BatchNormalizat (1, 104, 104, 256)   1024        DB_3_BANK_1_CONV[0][0]           
__________________________________________________________________________________________________
DB_3_BANK_1_RL (ReLU)           (1, 104, 104, 256)   0           DB_3_BANK_1_BN[0][0]             
__________________________________________________________________________________________________
DB_3_BANK_2_CONV (Conv2D)       (1, 104, 104, 256)   590080      DB_3_BANK_1_RL[0][0]             
__________________________________________________________________________________________________
DB_3_BANK_2_BN (BatchNormalizat (1, 104, 104, 256)   1024        DB_3_BANK_2_CONV[0][0]           
__________________________________________________________________________________________________
DB_3_BANK_2_RL (ReLU)           (1, 104, 104, 256)   0           DB_3_BANK_2_BN[0][0]             
__________________________________________________________________________________________________
DB_3_BANK_3_CONV (Conv2D)       (1, 104, 104, 128)   295040      DB_3_BANK_2_RL[0][0]             
__________________________________________________________________________________________________
DB_3_BANK_3_BN (BatchNormalizat (1, 104, 104, 128)   512         DB_3_BANK_3_CONV[0][0]           
__________________________________________________________________________________________________
DB_3_BANK_3_RL (ReLU)           (1, 104, 104, 128)   0           DB_3_BANK_3_BN[0][0]             
__________________________________________________________________________________________________
DB_4_UPSAMP (MaxUnpooling2D)    (1, 208, 208, 128)   0           DB_3_BANK_3_RL[0][0]             
                                                                 EB_2_MPOOL[0][1]                 
__________________________________________________________________________________________________
DB_4_BANK_1_CONV (Conv2D)       (1, 208, 208, 128)   147584      DB_4_UPSAMP[0][0]                
__________________________________________________________________________________________________
DB_4_BANK_1_BN (BatchNormalizat (1, 208, 208, 128)   512         DB_4_BANK_1_CONV[0][0]           
__________________________________________________________________________________________________
DB_4_BANK_1_RL (ReLU)           (1, 208, 208, 128)   0           DB_4_BANK_1_BN[0][0]             
__________________________________________________________________________________________________
DB_4_BANK_2_CONV (Conv2D)       (1, 208, 208, 64)    73792       DB_4_BANK_1_RL[0][0]             
__________________________________________________________________________________________________
DB_4_BANK_2_BN (BatchNormalizat (1, 208, 208, 64)    256         DB_4_BANK_2_CONV[0][0]           
__________________________________________________________________________________________________
DB_4_BANK_2_RL (ReLU)           (1, 208, 208, 64)    0           DB_4_BANK_2_BN[0][0]             
__________________________________________________________________________________________________
DB_5_UPSAMP (MaxUnpooling2D)    (1, 416, 416, 64)    0           DB_4_BANK_2_RL[0][0]             
                                                                 EB_1_MPOOL[0][1]                 
__________________________________________________________________________________________________
DB_5_BANK_1_CONV (Conv2D)       (1, 416, 416, 64)    36928       DB_5_UPSAMP[0][0]                
__________________________________________________________________________________________________
DB_5_BANK_1_BN (BatchNormalizat (1, 416, 416, 64)    256         DB_5_BANK_1_CONV[0][0]           
__________________________________________________________________________________________________
DB_5_BANK_1_RL (ReLU)           (1, 416, 416, 64)    0           DB_5_BANK_1_BN[0][0]             
__________________________________________________________________________________________________
DB_5_BANK_2_CONV (Conv2D)       (1, 416, 416, 3)     195         DB_5_BANK_1_RL[0][0]             
__________________________________________________________________________________________________
softmax (Softmax)               (1, 416, 416, 3)     0           DB_5_BANK_2_CONV[0][0]           
==================================================================================================
Total params: 29,459,075
Trainable params: 29,443,203
Non-trainable params: 15,872
__________________________________________________________________________________________________

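As you can see, I'm forced to specify a batch size in MaxUnpooling2D; otherwise I get errors, because the shape contains None values and cannot be converted correctly.

When I try to predict on an image, I have to specify the correct batch dimension; otherwise I get an error like this:

InvalidArgumentError:  Shapes of all inputs must match: values[0].shape = [4,208,208,64] != values[1].shape = [1,208,208,64]
     [[{{node SegNet/DB_5_UPSAMP/PartitionedCall/PartitionedCall/DB_5_UPSAMP/stack}}]] [Op:__inference_predict_function_70839]

This is caused by the implementation needed to unravel the indices from the max-pooling operation.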


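Training charts

Here is a reference run with 20 epochs of training.

As you can see, the MeanIoU metric is a flat line: it makes no progress and is not updated after epoch 1.

[Figure: Mean intersection over union]

The other metric works fine, and the loss decreases correctly.

[Figure: Loss and accuracy]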

––––––––––

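Conclusions

  1. Is there a better way to implement the unraveling and upsampling from the max-pooling indices that is more compatible with recent versions of TF?
  2. If the implementation is correct, why is my metric stuck at a specific value? Am I doing something wrong in the model?

Thank you!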


This is related to the following answer: https://stackoverflow.com/questions/50924072/how-to-implement-segnet-with-preserving-max-indexes-in-keras - Willington Cardona
1 Answer


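In a custom layer, there are two ways to reshape with an unknown batch size.

If you know the rest of the shape, reshape using -1 as the batch size:

Suppose you know the size of the expected array: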

import tensorflow.keras.backend as K
reshaped = K.reshape(original, (-1, x, y, channels))
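
The -1 works because the total number of elements is fixed, so one missing dimension can be inferred from the others. A minimal framework-free sketch of that inference follows. The helper `resolve_shape` is hypothetical and not part of Keras.

```python
def resolve_shape(total_elements, shape):
    """Replace a single -1 in shape with the size implied by total_elements."""
    known = 1
    for dim in shape:
        if dim != -1:
            known *= dim
    if shape.count(-1) != 1 or total_elements % known != 0:
        raise ValueError("shape needs exactly one -1 and must divide evenly")
    return tuple(total_elements // known if dim == -1 else dim for dim in shape)


# A batch of 4 feature maps of 26 x 26 x 512, with the batch size left open.
print(resolve_shape(4 * 26 * 26 * 512, (-1, 26, 26, 512)))  # (4, 26, 26, 512)
```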

If you don't know the size, use K.shape to get the shape as a tensor:
inputs_shape = K.shape(inputs)
batch_size = inputs_shape[:1]
x = inputs_shape[1:2]
y = inputs_shape[2:3]
ch = inputs_shape[3:]

#you can then concatenate these and operate them (notice I kept them as 1D vector, not as scalar)
newShape = K.concatenate([batch_size, x, y, ch]) #of course you will make your operations

I once made my own version of SegNet that did not use indices but kept a one-hot version instead. It takes extra operations, but it might work well:

def get_indices(original, unpooled):
    is_equal = K.equal(original, unpooled)
    return K.cast(is_equal, K.floatx())

previous_output = ...
pooled = MaxPooling2D()(previous_output)
unpooled = UpSampling2D()(pooled)

one_hot_indices = Lambda(get_indices)([previous_output, unpooled])
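
To see what that equality mask looks like, here is a small pure-Python sketch on a single-channel 2-D grid. It assumes even dimensions and 2 x 2 pooling, and the function names are hypothetical. Pooling followed by nearest-neighbour upsampling spreads each window's maximum over the whole window, so comparing with the original marks exactly the cells that held the maximum.

```python
def max_pool(grid, size=2):
    """Non-overlapping size x size max pooling on a 2-D grid."""
    return [
        [max(grid[y][x]
             for y in range(top, top + size)
             for x in range(left, left + size))
         for left in range(0, len(grid[0]), size)]
        for top in range(0, len(grid), size)
    ]


def upsample(grid, size=2):
    """Nearest-neighbour upsampling: repeat every value size x size times."""
    return [
        [grid[y // size][x // size] for x in range(len(grid[0]) * size)]
        for y in range(len(grid) * size)
    ]


def one_hot_mask(original, unpooled):
    """1.0 where the original cell equals its window maximum, else 0.0."""
    return [[1.0 if o == u else 0.0 for o, u in zip(orow, urow)]
            for orow, urow in zip(original, unpooled)]


original = [[1, 2, 0, 5],
            [3, 4, 6, 1],
            [7, 7, 2, 0],
            [1, 0, 9, 3]]
mask = one_hot_mask(original, upsample(max_pool(original)))
for row in mask:
    print(row)
```

The bottom-left window contains a tie (two 7s), so the mask marks both cells there. Unlike a true argmax, the equality test does not guarantee a single position per window.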

Then, after upsampling, I concatenate these indices and pass them through a new convolution:

some_output = ...
upsampled = UpSampling2D()(some_output)
with_indices = Concatenate()([upsampled, one_hot_indices])
upsampled = Conv2D(...)(with_indices)

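It applies the indices in a different way rather than using them directly in the upsampling. It adds the indices afterwards and uses a convolution to blend their effect. - Daniel Möller
When I used it, it did make a difference in my model. - Daniel Möller
A convolution without weights, right? Because it's used as a bare mathematical operation, I imagine. And what about the metric stuck at a constant value? Any clue? - rpasianotto
A normal convolution, with regular weights. I let the neural network decide entirely how to handle the indices. - Daniel Möller
OK, then it will be easier to train. - rpasianotto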