如何计算tensorflow模型中可训练参数的总数?

74

有没有一种函数调用或其他方式来计算 TensorFlow 模型中参数的总数?

在这里,参数指的是:可训练变量的 N 维向量具有 N 个参数,一个 NxM 矩阵具有 N*M 个参数,依此类推。因此,我想对 TensorFlow 会话中所有可训练变量的形状尺寸的积求和。


你的问题描述和标题不匹配(除非我混淆了图形和模型的术语)。在问题中,你问到了一个图形,而在标题中,你问到了一个模型。如果你有两个不同的模型怎么办?我建议在问题中澄清这一点。 - Charlie Parker
如果您正在使用Keras,请参考以下链接:https://dev59.com/m1cO5IYBdhLWcg3w7lrq - bers
9个回答

93

循环遍历tf.trainable_variables()中每个变量的形状。

total_parameters = 0
for variable in tf.trainable_variables():
    # shape is an array of tf.Dimension
    shape = variable.get_shape()
    print(shape)
    print(len(shape))
    variable_parameters = 1
    for dim in shape:
        print(dim)
        variable_parameters *= dim.value
    print(variable_parameters)
    total_parameters += variable_parameters
print(total_parameters)

更新:由于这个答案,我写了一篇文章来澄清Tensorflow中动态/静态形状的问题:https://pgaleone.eu/tensorflow/2018/07/28/understanding-tensorflow-tensors-shape-static-dynamic/


5
如果你有多个模型,tf.trainable_variables() 如何知道要使用哪一个? - Charlie Parker
2
tf.trainable_variables() 返回当前图中标记为可训练的所有变量。如果在当前图中有多个模型,则必须使用它们的名称手动过滤变量。类似于 if variable.name.startswith("model2"): ... - nessuno
这个解决方案给了我一个错误提示:“异常发生:无法隐式地将'int'对象转换为'str'”。你需要按照下面的答案建议显式地将“dim”转换为“int”(我认为这是正确的答案)。 - whiletrue
非常有帮助。 - Sudip Das
在TF2中,似乎已经改为tf.compat.v1.trainable_variables()!但是这会返回0个参数! - chikitin
3
在TensorFlow 2中,这个答案已经被弃用了。你必须使用你的Keras模型的.trainable_variables - 不再有全局图! - nessuno

50

我有一个更短的版本,使用numpy实现一行代码解决:

np.sum([np.prod(v.get_shape().as_list()) for v in tf.trainable_variables()])

在我的版本中,v 没有 shape_as_list() 函数,只有 get_shape() 函数。 - mustafa
我认为早期版本没有.shape而是使用get_shape()。我更新了我的回答。不管怎样,我写的是v.shape.as_list()而不是v.shape_as_list()。 - Michael Gygli
14
在TensorFlow 1.2中,np.sum([np.prod(v.shape) for v in tf.trainable_variables()])同样有效。 说明:该代码用于计算TensorFlow模型中可训练参数的数量。 - Julius Kunze
1
np.sum([np.prod(v.shape) for v in model.trainable_variables]) 这行代码在我这儿能运行,末尾不需要加上函数调用 "()". - Manute

9

不确定给出的答案是否实际运行(我发现您需要将dim对象转换为int才能使其正常工作)。这里是一个可行的解决方案,您可以直接复制粘贴这些函数并调用它们(还添加了一些注释):

def count_number_trainable_params():
    '''
    Counts the number of trainable variables.
    '''
    tot_nb_params = 0
    for trainable_variable in tf.trainable_variables():
        shape = trainable_variable.get_shape() # e.g [D,F] or [W,H,C]
        current_nb_params = get_nb_params_shape(shape)
        tot_nb_params = tot_nb_params + current_nb_params
    return tot_nb_params

def get_nb_params_shape(shape):
    '''
    Computes the total number of params for a given shap.
    Works for any number of shapes etc [D,F] or [W,H,C] computes D*F and W*H*C.
    '''
    nb_params = 1
    for dim in shape:
        nb_params = nb_params*int(dim)
    return nb_params 

答案确实可行(r0.11.0)。你的更加即插即用 :) - f4.
@f4. 这似乎存在一个错误,因为 y 似乎没有被使用。 - Charlie Parker
1
@CharlieParker 我几秒钟前已经修复了它 ;) - f4.
@f4。它仍然没有真正解决我想做的问题(或者原作者本意,因为他给了y作为输入),因为我正在寻找一个依赖于所给模型(即“y”)的函数。现在,如所给,我不知道它到底计算什么。我的怀疑是它只计算所有模型(我有两个单独的模型)。 - Charlie Parker
@CharlieParker 它会计算所有可训练的变量,默认情况下是所有变量。您可以使用变量属性(如图形或名称)来解决一些问题。 - f4.

7

2020年4月更新:tfprof和Profiler UI已被淘汰,推荐使用TensorBoard中的分析器支持

如果你想自行计算参数数量,两个现有答案都很好。如果你的问题更多地是“有没有一种简单的方法来对我的TensorFlow模型进行性能剖析?”,我强烈建议你看看tfprof。它可以对你的模型进行性能剖析,包括计算参数数量。


tfprof的链接已经失效。由于编辑队列已满,这里提供有效链接。另外,tfprof已被弃用。 - dsalaj

3
我会提供我等价但更短的实现:

我会提供我等价但更短的实现:

def count_params():
    "print number of trainable variables"
    size = lambda v: reduce(lambda x, y: x*y, v.get_shape().as_list())
    n = sum(size(v) for v in tf.trainable_variables())
    print "Model size: %dK" % (n/1000,)

3
如果想要避免使用numpy(在许多项目中可以省略它),则可以这样做:
all_trainable_vars = tf.reduce_sum([tf.reduce_prod(v.shape) for v in tf.trainable_variables()])

这是对Julius Kunze之前回答的TF翻译。
与任何TF操作一样,它需要一个会话运行来进行评估:
print(sess.run(all_trainable_vars))

3

TF v2.9上对我有效。感谢这个答案

import numpy as np

trainable_params = np.sum([np.prod(v.get_shape()) for v in model.trainable_weights])
non_trainable_params = np.sum([np.prod(v.get_shape()) for v in model.non_trainable_weights])
total_params = trainable_params + non_trainable_params
    
print(trainable_params)
print(non_trainable_params)
print(total_params)

3
现在,您可以使用以下内容:
from keras.utils.layer_utils import count_params  

count_params(model.trainable_weights)

-1
model.summary()

Model: "sequential_32"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
conv2d_88 (Conv2D)           (None, 240, 240, 16)      448       
_________________________________________________________________
max_pooling2d_87 (MaxPooling (None, 120, 120, 16)      0         
_________________________________________________________________
conv2d_89 (Conv2D)           (None, 120, 120, 32)      4640      
_________________________________________________________________
max_pooling2d_88 (MaxPooling (None, 60, 60, 32)        0         
_________________________________________________________________
conv2d_90 (Conv2D)           (None, 60, 60, 64)        18496     
_________________________________________________________________
max_pooling2d_89 (MaxPooling (None, 30, 30, 64)        0         
_________________________________________________________________
flatten_29 (Flatten)         (None, 57600)             0         
_________________________________________________________________
dropout_48 (Dropout)         (None, 57600)             0         
_________________________________________________________________
dense_150 (Dense)            (None, 24)                1382424   
_________________________________________________________________
dense_151 (Dense)            (None, 9)                 225       
_________________________________________________________________
dense_152 (Dense)            (None, 3)                 30        
_________________________________________________________________
dense_153 (Dense)            (None, 1)                 4         
=================================================================
Total params: 1,406,267
Trainable params: 1,406,267
Non-trainable params: 0
_________________________________________________________________

2
这是Keras,不是Tensorflow;问题显然是关于Tensorflow模型而不是Keras模型的。 - desertnaut

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接