可训练参数的数量 - Python / TensorFlow 的卷积神经网络

6
在TensorFlow中,是否有任何函数或方法可以帮助我找出网络中的学习参数数量?
2个回答

6

我不知道有任何函数可以实现此功能,但您仍然可以使用 tf.trainable_variables() 上的 for 循环来计算自己:

total_parameters = 0
for variable in tf.trainable_variables():
    variable_parameters = 1
    for dim in variable.get_shape():
        variable_parameters *= dim.value
    total_parameters += variable_parameters

print("Total number of trainable parameters: %d" % total_parameters)

知道了!谢谢!那么把这段代码放在哪个位置最好呢?在初始化TensorFlow会话之后吗? - whoisraibolt
1
你可以将它放在会话初始化之前。它只读取图变量,不需要 tf 会话。 - Pop

2
您可以使用简单的一行代码来完成这个任务:
np.sum([np.prod(v.get_shape().as_list()) for v in tf.trainable_variables()])

如果您需要更详细的信息,这里是我使用的一个辅助函数,可以查看所有可训练参数:
def show_params():
  total = 0
  for v in tf.trainable_variables():
    dims = v.get_shape().as_list()
    num  = int(np.prod(dims))
    total += num
    print('  %s \t\t Num: %d \t\t Shape %s ' % (v.name, num, dims))
  print('\nTotal number of params: %d' % total)

它会输出像这样的信息:
  params/weights/W1:0        Num: 34992      Shape [3, 3, 18, 216] 
  params/weights/W2:0        Num: 839808     Shape [3, 3, 216, 432] 
  params/weights/W3:0        Num: 839808     Shape [3, 3, 432, 216] 
  params/weights/W4:0        Num: 57856      Shape [226, 256] 
  params/weights/W5:0        Num: 32768      Shape [256, 128] 
  params/weights/W6:0        Num: 8192       Shape [128, 64] 
  params/weights/W7:0        Num: 64         Shape [64, 1] 
  params/biases/b1:0         Num: 216        Shape [216] 
  params/biases/b2:0         Num: 432        Shape [432] 
  params/biases/b3:0         Num: 216        Shape [216] 
  params/biases/b4:0         Num: 256        Shape [256] 
  params/biases/b5:0         Num: 128        Shape [128] 
  params/biases/b6:0         Num: 64         Shape [64] 
  params/biases/b7:0         Num: 1          Shape [1]

Total number of params: 1814801

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接