如何使用K.clear_session()解决Keras中的内存泄漏问题?

3
我有一个网络正在使用model.train_on_batch()逐批次训练数据,如果我只运行此部分,我会看到我的网络在40个以上的时代(至今)中以3%的RAM利用率进行训练,每个时代大约有2000次迭代。当我尝试在每个时代之后进行验证(也是分批次进行的)时,就会出现非常严重的内存泄漏,导致90%的RAM利用率和我的代码被终止。所以在过去的几天里,我尝试了一些方法,似乎在循环中使用model.predict()会导致内存泄漏Tensorflow GitHub上的开放问题。我尝试了predict_on_batch(),结果相同。 model(inputs, training=False)似乎可以减缓内存泄漏,而不是从3% - 7% - 13% - 40% - 80% - 90%(60秒间隔)的突然跳变,它每分钟增加1%。但是某个时候它也达到了90%。这个Github线程中我唯一剩下要尝试的事情是使用K.clear_session()
我尝试阅读有关K.clear_session()的文档和一些SO帖子,所有这些都建议在创建多个模型时使用它,而我没有这样做。所以我的问题是,如果我有一个单独的正在训练和评估循环中的模型,在每个时代之后应该在哪里使用K.clear_session(),并在每个时代之前重新加载保存的模型吗?这样正确吗?
除此之外,我还遇到了拓扑排序错误另一个开放性问题,所以我想知道它是否是因为我正在循环中训练,因为我的代码否则没有循环,并且这种方式也会导致内存泄漏,K.clear_session()会在某种程度上有所帮助吗?
代码结构的最小示例:
from tensorflow.keras.models import Model
K = tf.keras.backend

def myModel():
    **some architecture**

ip = Input(shape=(h, w, 3))
op = myModel(ip)
model = Model(ip, op)
model.compile(optimizer=Adam(lr=1e-6), loss=custom_mean_squared_error)

for e in range(numEpochs):
    for batch in range(0, num_train_batches):
        x = readImages()
        y = readLabels()
        loss = model.train_on_batch(x, y)
        

    for batch in range(0, num_val_batches):
        x = readImages()
        y = model.predict(x)
        val_loss = K.get_value(custom_mean_squared_error(x,y))
        # save predictions

# plot training vs validation loss

Tensorflow-gpu-1.14,Python3.6。如果我做错了什么,希望能得到建议。

1个回答

0

这对我来说似乎有效,但会使过程变慢:

from tensorflow.keras.models import Model
K = tf.keras.backend

def myModel():
    **some architecture**

ip = Input(shape=(h, w, 3))
op = myModel(ip)
model = Model(ip, op)
model.compile(optimizer=Adam(lr=1e-6), loss=custom_mse)

for e in range(numEpochs):
    for batch in range(0, num_train_batches):
        x = readImages()
        y = readLabels()
        # define appropriate flags for first loop
        model = tf.keras.models.load_model(model_path,custom_objects={ 'custom_mse': custom_mse } )
        loss = model.train_on_batch(x, y)

    model.save(model_path)

    for batch in range(0, num_val_batches):
        x = readImages()
        model = tf.keras.models.load_model(model_path,custom_objects={ 'custom_mse': custom_mse } )
        y = model.predict(x)
        K.clear_session()
        val_loss = K.get_value(custom_mean_squared_error(x,y))
        # save predictions

# plot training vs validation loss

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接