如何可视化TensorFlow Estimator的权重?

7
2个回答

9

估算器(Estimator)有一个方法叫做get_variable_value。因此,一旦您生成了检查点(或从检查点中加载变量值),并且如果您知道密集层的名称,您可以使用matplotlib执行以下操作:

import matplotlib.pyplot as plt

weights = estimator.get_variable_value('dense/kernel')
plt.imshow(weights, cmap='gray')
plt.show()

你怎么知道变量名?例如,当使用带有FTRL优化器的LinearRegressor Estimator时,使用您链接的页面上的get_variable_names()方法,我得到了如下变量名:'global_step', 'linear/linear_model/X0/weights', 'linear/linear_model/X0/weights/part_0/Ftrl', 'linear/linear_model/X0/weights/part_0/Ftrl_1', 'linear/linear_model/X1/weights', 'linear/linear_model/X1/weights/part_0/Ftrl', 'linear/linear_model/X1/weights/part_0/Ftrl_1', 'linear/linear_model/X2/weights', 'linear/linear_model/X2/weights/part_0/Ftrl', 'linear/linear_model/X2/weights/part_0/Ftrl_1'等。 - ac2051
@ac2051 我能修改这些权重吗?我想要一个像set_variable_value的方法。 - Yashu Seth
@natkinis 我能修改这些权重吗?我想要一个类似于 set_variable_value 的方法。 - Yashu Seth

3

我刚刚使用了预编译的Estimator进行测试,结果正常。

import matplotlib.pyplot as plt

names = classifier.get_variable_names()        
print("name:", names)
for i in names:
    print(classifier.get_variable_value(i)

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接