Deep learning UNet convergence problem

I am writing a deep-learning UNet model for image segmentation, converting RGB 256×256 images into grayscale masks. My inspiration is https://github.com/zhixuhao/unet, so my network has the following structure:
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
==================================================================================================
input_1 (InputLayer)            (None, 256, 256, 3)  0                                            
__________________________________________________________________________________________________
conv2d_1 (Conv2D)               (None, 256, 256, 16) 448         input_1[0][0]                    
__________________________________________________________________________________________________
batch_normalization_1 (BatchNor (None, 256, 256, 16) 64          conv2d_1[0][0]                   
__________________________________________________________________________________________________
conv2d_2 (Conv2D)               (None, 256, 256, 16) 2320        batch_normalization_1[0][0]      
__________________________________________________________________________________________________
batch_normalization_2 (BatchNor (None, 256, 256, 16) 64          conv2d_2[0][0]                   
__________________________________________________________________________________________________
max_pooling2d_1 (MaxPooling2D)  (None, 128, 128, 16) 0           batch_normalization_2[0][0]      
__________________________________________________________________________________________________
conv2d_3 (Conv2D)               (None, 128, 128, 32) 4640        max_pooling2d_1[0][0]            
__________________________________________________________________________________________________
batch_normalization_3 (BatchNor (None, 128, 128, 32) 128         conv2d_3[0][0]                   
__________________________________________________________________________________________________
conv2d_4 (Conv2D)               (None, 128, 128, 32) 9248        batch_normalization_3[0][0]      
__________________________________________________________________________________________________
batch_normalization_4 (BatchNor (None, 128, 128, 32) 128         conv2d_4[0][0]                   
__________________________________________________________________________________________________
max_pooling2d_2 (MaxPooling2D)  (None, 64, 64, 32)   0           batch_normalization_4[0][0]      
__________________________________________________________________________________________________
conv2d_5 (Conv2D)               (None, 64, 64, 64)   18496       max_pooling2d_2[0][0]            
__________________________________________________________________________________________________
batch_normalization_5 (BatchNor (None, 64, 64, 64)   256         conv2d_5[0][0]                   
__________________________________________________________________________________________________
conv2d_6 (Conv2D)               (None, 64, 64, 64)   36928       batch_normalization_5[0][0]      
__________________________________________________________________________________________________
batch_normalization_6 (BatchNor (None, 64, 64, 64)   256         conv2d_6[0][0]                   
__________________________________________________________________________________________________
max_pooling2d_3 (MaxPooling2D)  (None, 32, 32, 64)   0           batch_normalization_6[0][0]      
__________________________________________________________________________________________________
conv2d_7 (Conv2D)               (None, 32, 32, 128)  73856       max_pooling2d_3[0][0]            
__________________________________________________________________________________________________
batch_normalization_7 (BatchNor (None, 32, 32, 128)  512         conv2d_7[0][0]                   
__________________________________________________________________________________________________
conv2d_8 (Conv2D)               (None, 32, 32, 128)  147584      batch_normalization_7[0][0]      
__________________________________________________________________________________________________
batch_normalization_8 (BatchNor (None, 32, 32, 128)  512         conv2d_8[0][0]                   
__________________________________________________________________________________________________
dropout_1 (Dropout)             (None, 32, 32, 128)  0           batch_normalization_8[0][0]      
__________________________________________________________________________________________________
max_pooling2d_4 (MaxPooling2D)  (None, 16, 16, 128)  0           dropout_1[0][0]                  
__________________________________________________________________________________________________
conv2d_9 (Conv2D)               (None, 16, 16, 256)  295168      max_pooling2d_4[0][0]            
__________________________________________________________________________________________________
batch_normalization_9 (BatchNor (None, 16, 16, 256)  1024        conv2d_9[0][0]                   
__________________________________________________________________________________________________
conv2d_10 (Conv2D)              (None, 16, 16, 256)  590080      batch_normalization_9[0][0]      
__________________________________________________________________________________________________
batch_normalization_10 (BatchNo (None, 16, 16, 256)  1024        conv2d_10[0][0]                  
__________________________________________________________________________________________________
dropout_2 (Dropout)             (None, 16, 16, 256)  0           batch_normalization_10[0][0]     
__________________________________________________________________________________________________
up_sampling2d_1 (UpSampling2D)  (None, 32, 32, 256)  0           dropout_2[0][0]                  
__________________________________________________________________________________________________
conv2d_11 (Conv2D)              (None, 32, 32, 128)  131200      up_sampling2d_1[0][0]            
__________________________________________________________________________________________________
concatenate_1 (Concatenate)     (None, 32, 32, 256)  0           dropout_1[0][0]                  
                                                                 conv2d_11[0][0]                  
__________________________________________________________________________________________________
conv2d_12 (Conv2D)              (None, 32, 32, 128)  295040      concatenate_1[0][0]              
__________________________________________________________________________________________________
batch_normalization_11 (BatchNo (None, 32, 32, 128)  512         conv2d_12[0][0]                  
__________________________________________________________________________________________________
conv2d_13 (Conv2D)              (None, 32, 32, 128)  147584      batch_normalization_11[0][0]     
__________________________________________________________________________________________________
batch_normalization_12 (BatchNo (None, 32, 32, 128)  512         conv2d_13[0][0]                  
__________________________________________________________________________________________________
up_sampling2d_2 (UpSampling2D)  (None, 64, 64, 128)  0           batch_normalization_12[0][0]     
__________________________________________________________________________________________________
conv2d_14 (Conv2D)              (None, 64, 64, 64)   32832       up_sampling2d_2[0][0]            
__________________________________________________________________________________________________
concatenate_2 (Concatenate)     (None, 64, 64, 128)  0           conv2d_6[0][0]                   
                                                                 conv2d_14[0][0]                  
__________________________________________________________________________________________________
conv2d_15 (Conv2D)              (None, 64, 64, 64)   73792       concatenate_2[0][0]              
__________________________________________________________________________________________________
batch_normalization_13 (BatchNo (None, 64, 64, 64)   256         conv2d_15[0][0]                  
__________________________________________________________________________________________________
conv2d_16 (Conv2D)              (None, 64, 64, 64)   36928       batch_normalization_13[0][0]     
__________________________________________________________________________________________________
batch_normalization_14 (BatchNo (None, 64, 64, 64)   256         conv2d_16[0][0]                  
__________________________________________________________________________________________________
up_sampling2d_3 (UpSampling2D)  (None, 128, 128, 64) 0           batch_normalization_14[0][0]     
__________________________________________________________________________________________________
conv2d_17 (Conv2D)              (None, 128, 128, 32) 8224        up_sampling2d_3[0][0]            
__________________________________________________________________________________________________
concatenate_3 (Concatenate)     (None, 128, 128, 64) 0           conv2d_4[0][0]                   
                                                                 conv2d_17[0][0]                  
__________________________________________________________________________________________________
conv2d_18 (Conv2D)              (None, 128, 128, 32) 18464       concatenate_3[0][0]              
__________________________________________________________________________________________________
batch_normalization_15 (BatchNo (None, 128, 128, 32) 128         conv2d_18[0][0]                  
__________________________________________________________________________________________________
conv2d_19 (Conv2D)              (None, 128, 128, 32) 9248        batch_normalization_15[0][0]     
__________________________________________________________________________________________________
batch_normalization_16 (BatchNo (None, 128, 128, 32) 128         conv2d_19[0][0]                  
__________________________________________________________________________________________________
up_sampling2d_4 (UpSampling2D)  (None, 256, 256, 32) 0           batch_normalization_16[0][0]     
__________________________________________________________________________________________________
conv2d_20 (Conv2D)              (None, 256, 256, 16) 2064        up_sampling2d_4[0][0]            
__________________________________________________________________________________________________
concatenate_4 (Concatenate)     (None, 256, 256, 32) 0           conv2d_2[0][0]                   
                                                                 conv2d_20[0][0]                  
__________________________________________________________________________________________________
conv2d_21 (Conv2D)              (None, 256, 256, 16) 4624        concatenate_4[0][0]              
__________________________________________________________________________________________________
batch_normalization_17 (BatchNo (None, 256, 256, 16) 64          conv2d_21[0][0]                  
__________________________________________________________________________________________________
conv2d_22 (Conv2D)              (None, 256, 256, 16) 2320        batch_normalization_17[0][0]     
__________________________________________________________________________________________________
batch_normalization_18 (BatchNo (None, 256, 256, 16) 64          conv2d_22[0][0]                  
__________________________________________________________________________________________________
conv2d_23 (Conv2D)              (None, 256, 256, 2)  290         batch_normalization_18[0][0]     
__________________________________________________________________________________________________
batch_normalization_19 (BatchNo (None, 256, 256, 2)  8           conv2d_23[0][0]                  
__________________________________________________________________________________________________
dropout_3 (Dropout)             (None, 256, 256, 2)  0           batch_normalization_19[0][0]     
__________________________________________________________________________________________________
MLP_layer (Conv2D)              (None, 256, 256, 1)  3           dropout_3[0][0]                  
==================================================================================================

However, convergence is very difficult and only works for a very restricted set of parameters:
- a learning rate no larger than 1e-3, while some articles use 1e-2 with decay
- only 16 filters in the first convolution (32 in the next block, and so on)
- a batch size of 8 or 16, while 32 and 64 do not work
- batch normalization is required, although the reference model does not use it. It should help the network learn with less restricted parameters, as described in https://towardsdatascience.com/batch-normalization-theory-and-how-to-use-it-with-tensorflow-1892ca0173ad and https://arxiv.org/pdf/1502.03167.pdf.
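To make the batch-normalization point concrete, here is a minimal numpy sketch (not the actual Keras layer) of what `BatchNormalization` computes at training time: it centers and rescales each channel over the batch, then applies the learnable `gamma`/`beta`; Keras uses `eps=1e-3` by default.

```python
import numpy as np

def batch_norm(x, gamma=1.0, beta=0.0, eps=1e-3):
    # Training-time batch norm: normalize each channel over the batch axis,
    # then rescale with the learnable gamma/beta parameters
    mean, var = x.mean(axis=0), x.var(axis=0)
    return gamma * (x - mean) / np.sqrt(var + eps) + beta

rng = np.random.default_rng(0)
acts = rng.normal(5.0, 3.0, size=(64, 4))   # activations far from zero mean
normed = batch_norm(acts)
print(normed.mean(axis=0), normed.std(axis=0))  # per-channel ~0 and ~1
```

Keeping activations near zero mean and unit variance at every depth is what lets the larger learning rates from the articles stay stable.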
Other details:
- I have checked that my inputs are np.float32 in the range 0 to 1
- I am trying to learn cadastre masks from satellite images
So my question is:
Why does my network not work with the same parameters as the reference articles?
-> I have to use "slow" parameters to make it work at all (lower learning rate, smaller batch size, fewer convolution filters...). Otherwise it outputs a gray image with a single pixel value everywhere.
The code used:
SHAPE=256
DIM=3
INITIALIZER='glorot_uniform'
BASE_SIZE=16
LR=0.001


def get_model(pretrained_model: str = None, input_size: tuple = (SHAPE, SHAPE, DIM)) -> Model:
"""
Machine-learning model for image learning; here the purpose is segmentation,
so the decoder must upsample back to the input resolution.

Parameters
----------
pretrained_model: str
    name of the .hdf5 file containing pretrained weights, syntax: 'dir:weight.hfd5'
input_size: tuple

Returns
-------
Model
"""
if pretrained_model:
    return read_model(pretrained_model)
else:
    inputs = Input(input_size)
    conv1 = Conv2D(BASE_SIZE, 3, activation='relu', padding='same', kernel_initializer=INITIALIZER)(inputs)
    batch_norm1 = BatchNormalization()(conv1)

    conv2 = Conv2D(BASE_SIZE, 3, activation='relu', padding='same', kernel_initializer=INITIALIZER)(batch_norm1)
    batch_norm2 = BatchNormalization()(conv2)
    pool1 = MaxPooling2D(pool_size=(2, 2))(batch_norm2)

    conv3 = Conv2D(BASE_SIZE * 2, 3, activation='relu', padding='same', kernel_initializer=INITIALIZER)(pool1)
    batch_norm3 = BatchNormalization()(conv3)

    conv4 = Conv2D(BASE_SIZE * 2, 3, activation='relu', padding='same', kernel_initializer=INITIALIZER)(batch_norm3)
    batch_norm4 = BatchNormalization()(conv4)
    pool2 = MaxPooling2D(pool_size=(2, 2))(batch_norm4)

    conv5 = Conv2D(BASE_SIZE * 4, 3, activation='relu', padding='same', kernel_initializer=INITIALIZER)(pool2)
    batch_norm5 = BatchNormalization()(conv5)

    conv6 = Conv2D(BASE_SIZE * 4, 3, activation='relu', padding='same', kernel_initializer=INITIALIZER)(batch_norm5)
    batch_norm6 = BatchNormalization()(conv6)
    pool3 = MaxPooling2D(pool_size=(2, 2))(batch_norm6)

    conv7 = Conv2D(BASE_SIZE * 8, 3, activation='relu', padding='same', kernel_initializer=INITIALIZER)(pool3)
    batch_norm7 = BatchNormalization()(conv7)

    conv8 = Conv2D(BASE_SIZE * 8, 3, activation='relu', padding='same', kernel_initializer=INITIALIZER)(batch_norm7)
    batch_norm8 = BatchNormalization()(conv8)

    drop4 = Dropout(0.2)(batch_norm8)
    pool4 = MaxPooling2D(pool_size=(2, 2))(drop4)

    conv9 = Conv2D(BASE_SIZE * 16, 3, activation='relu', padding='same', kernel_initializer=INITIALIZER)(pool4)
    batch_norm9 = BatchNormalization()(conv9)

    conv10 = Conv2D(BASE_SIZE * 16, 3, activation='relu', padding='same', kernel_initializer=INITIALIZER)(
        batch_norm9)
    batch_norm10 = BatchNormalization()(conv10)

    drop5 = Dropout(0.5)(batch_norm10)

    up6 = Conv2D(BASE_SIZE * 8, 2, activation='relu', padding='same', kernel_initializer=INITIALIZER)(
        UpSampling2D(size=(2, 2))(drop5))
    merge6 = concatenate([drop4, up6], axis=3)

    conv11 = Conv2D(BASE_SIZE * 8, 3, activation='relu', padding='same', kernel_initializer=INITIALIZER)(merge6)
    batch_norm11 = BatchNormalization()(conv11)

    conv12 = Conv2D(BASE_SIZE * 8, 3, activation='relu', padding='same', kernel_initializer=INITIALIZER)(
        batch_norm11)
    batch_norm12 = BatchNormalization()(conv12)

    up7 = Conv2D(BASE_SIZE * 4, 2, activation='relu', padding='same', kernel_initializer=INITIALIZER)(
        UpSampling2D(size=(2, 2))(batch_norm12))
    merge7 = concatenate([conv6, up7], axis=3)
    conv13 = Conv2D(BASE_SIZE * 4, 3, activation='relu', padding='same', kernel_initializer=INITIALIZER)(merge7)
    batch_norm13 = BatchNormalization()(conv13)

    conv14 = Conv2D(BASE_SIZE * 4, 3, activation='relu', padding='same', kernel_initializer=INITIALIZER)(
        batch_norm13)
    batch_norm14 = BatchNormalization()(conv14)

    up8 = Conv2D(BASE_SIZE * 2, 2, activation='relu', padding='same', kernel_initializer=INITIALIZER)(
        UpSampling2D(size=(2, 2))(batch_norm14))
    merge8 = concatenate([conv4, up8], axis=3)

    conv15 = Conv2D(BASE_SIZE * 2, 3, activation='relu', padding='same', kernel_initializer=INITIALIZER)(merge8)
    batch_norm15 = BatchNormalization()(conv15)

    conv16 = Conv2D(BASE_SIZE * 2, 3, activation='relu', padding='same', kernel_initializer=INITIALIZER)(
        batch_norm15)
    batch_norm16 = BatchNormalization()(conv16)

    up9 = Conv2D(BASE_SIZE, 2, activation='relu', padding='same', kernel_initializer=INITIALIZER)(
        UpSampling2D(size=(2, 2))(batch_norm16))
    merge9 = concatenate([conv2, up9], axis=3)

    conv17 = Conv2D(BASE_SIZE, 3, activation='relu', padding='same', kernel_initializer=INITIALIZER)(merge9)
    batch_norm17 = BatchNormalization()(conv17)

    conv18 = Conv2D(BASE_SIZE, 3, activation='relu', padding='same', kernel_initializer=INITIALIZER)(batch_norm17)
    batch_norm18 = BatchNormalization()(conv18)

    conv19 = Conv2D(2, 3, activation='relu', padding='same', kernel_initializer=INITIALIZER)(batch_norm18)
    batch_norm19 = BatchNormalization()(conv19)

    # personal addition: extra dropout before the output layer
    drop6 = Dropout(0.2)(batch_norm19)

    out = Conv2D(1, 1, activation='sigmoid', name='MLP_layer')(drop6)

    model = Model(inputs=inputs, outputs=out)

    model.compile(optimizer=Adam(lr=LR),
                  loss='binary_crossentropy',
                  metrics=['accuracy', iou_loss])

    return model
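The `iou_loss` metric passed to `compile` is not shown in the snippet. Assuming it reports the Jaccard index (intersection over union), here is a minimal numpy sketch of that computation; the real Keras metric would use backend tensor ops (`K.sum`) instead of numpy, and the function name is my guess at what the asker implemented.

```python
import numpy as np

def iou(y_true, y_pred, smooth=1.0):
    # Jaccard index: |A ∩ B| / |A ∪ B|, smoothed to avoid 0/0 on empty masks
    intersection = np.sum(y_true * y_pred)
    union = np.sum(y_true) + np.sum(y_pred) - intersection
    return (intersection + smooth) / (union + smooth)

y_true = np.array([1.0, 1.0, 0.0, 0.0])
y_pred = np.array([1.0, 0.0, 0.0, 0.0])
print(iou(y_true, y_pred))  # (1+1)/(2+1) = 0.666...
```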

Thanks


I ran into a similar problem with UNet. BatchNormalization is necessary, along with lower batch sizes and learning rates. Also interested in a proper answer. - Anakin
Could you post the code used to create the network? Are you sure the last layer uses a sigmoid? - Andrew Louw
Hi, I have just added the code, and yes, 'sigmoid' is used on the MLP layer. - sslloo
For me the learning rate is usually around 1e-4; -3 may be fine, but for this type of network -2 probably won't converge. - Simon Caby
As far as I know, dropout is not that interesting for a UNet; BatchNorm is. Using dropout may even have the opposite effect. - Simon Caby
1 Answer

Binary cross-entropy performs poorly on segmentation problems, especially with class imbalance. For example, if the mask contains far more black pixels than white ones, the network learns that predicting every pixel as black is acceptable. Consider using Dice loss or Jaccard loss as the objective, or combining Dice or Jaccard with binary cross-entropy or weighted binary cross-entropy. Finally, you can look at this library, https://segmentation-models.readthedocs.io/en/latest/install.html, which contains several segmentation models, including UNet with a choice of pretrained encoders, plus the most common metrics for this task (such as Dice and Jaccard).
