将Keras模型.summary()对象转换为字符串。

73

我想用神经网络超参数和模型架构编写一个*.txt文件。是否可以将对象model.summary()写入我的输出文件中?

(...)
summary = str(model.summary())
(...)
out = open(filename + 'report.txt','w')
out.write(summary)
out.close

我得到了 "None",就像你下面看到的那样。

Hyperparameters
=========================

learning_rate: 0.01
momentum: 0.8
decay: 0.0
batch size: 128
no. epochs: 3
dropout: 0.5
-------------------------

None
val_acc: 0.232323229313
val_loss: 3.88496732712
train_acc: 0.0965207634216
train_loss: 4.07161939425
train/val loss ratio: 1.04804469418

有什么办法来处理这个问题吗?


如何以字典格式获取输出:https://stackoverflow.com/a/68128858/10375049 - Marco Cerliani
9个回答

78

我使用的Keras版本是(2.0.6)和Python版本是(3.5.0),对我来说这个可以工作:

# Create an empty model
from keras.models import Sequential
model = Sequential()

# Open the file
with open(filename + 'report.txt','w') as fh:
    # Pass the file handle in as a lambda function to make it callable
    model.summary(print_fn=lambda x: fh.write(x + '\n'))

这将向文件输出以下行:

_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
Total params: 0
Trainable params: 0
Non-trainable params: 0
_________________________________________________________________

4
版本 2.0.6 新增功能。 - Huo
1
仅作为补充说明:您还可以使用line_length作为额外参数来调整输出行的长度。这很有用,因为有时层名称太长,因此会被截断:model.summary(line_length=80, print_fn=lambda x: fh.write(x + '\n')) - iipr

36

对我来说,这个方法可以将模型摘要作为字符串获取:

stringlist = []
model.summary(print_fn=lambda x: stringlist.append(x))
short_model_summary = "\n".join(stringlist)
print(short_model_summary)

24

如果您想要写入日志:

import logging
logger = logging.getLogger(__name__)

model.summary(print_fn=logger.info)

12

我知道OP已经接受了winni2k的答案,但是由于问题标题实际上暗示着将model.summary()的输出保存到一个字符串中,而不是文件,所以下面的代码可能会帮助其他人找到这个页面(就像我一样)。

下面的代码是在TensorFlow 1.12.0上运行的,它自带了Python 3.6.2上的Keras 2.1.6-tf

from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense, Activation
import io

# Example model
model = Sequential([
    Dense(32, input_shape=(784,)),
    Activation('relu'),
    Dense(10),
    Activation('softmax'),
])

def get_model_summary(model):
    stream = io.StringIO()
    model.summary(print_fn=lambda x: stream.write(x + '\n'))
    summary_string = stream.getvalue()
    stream.close()
    return summary_string

model_summary_string = get_model_summary(model)

print(model_summary_string)

生成的字符串为:

_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
dense (Dense)                (None, 32)                25120     
_________________________________________________________________
activation (Activation)      (None, 32)                0         
_________________________________________________________________
dense_1 (Dense)              (None, 10)                330       
_________________________________________________________________
activation_1 (Activation)    (None, 10)                0         
=================================================================
Total params: 25,450
Trainable params: 25,450
Non-trainable params: 0
_________________________________________________________________

10

我也遇到了同样的问题!有两种可能的解决方法:

使用模型的 to_json() 方法

summary = str(model.to_json())

这是您上面的案例。

否则,请使用keras_diagram中的ascii方法。

from keras_diagram import ascii
summary = ascii(model)

8

6

有一个选项,虽然不是 model.summary 的完全替代品,但可以使用 model.get_config() 导出模型的配置。参考文档

model.get_config(): returns a dictionary containing the configuration of the model. The model can be reinstantiated from its config via:

config = model.get_config()
model = Model.from_config(config)
# or, for Sequential:
model = Sequential.from_config(config)

2

我来这里是为了找到一种记录摘要的方法,我想分享一下对@ajb答案的小改动,以避免在日志文件中每行都出现INFO:,具体方法可以参考@FAnders的答案:

def get_model_summary(model: tf.keras.Model) -> str:
    string_list = []
    model.summary(line_length=80, print_fn=lambda x: string_list.append(x))
    return "\n".join(string_list)

# some code
logging.info(get_model_summary(model)

产生一个日志文件如下: enter image description here


1
我曾经遇到过同样的问题。@Pasa的回答非常有用,但我想发表一个更简化的例子:在这一点上,您已经拥有了一个Keras模型,这是一个合理的假设。
import io

s = io.StringIO()
model.summary(print_fn=lambda x: s.write(x + '\n'))
model_summary = s.getvalue()
s.close()

print("The model summary is:\n\n{}".format(model_summary))

当拥有这个字符串时的一个例子:如果你有一个 matplotlib 绘图。你可以使用以下代码:
plt.text(0, 0.25, model_summary)

为了快速参考,将您的模型摘要写入性能图表中: 输入图像描述

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接