如何从tflite模型中获取权重?

3
2个回答

8
创建一个tflite解释器并(可选)进行推断。tflite_interpreter.get_tensor_details()将返回一个字典列表,其中包含权重、偏差、它们的缩放比例、零点等信息。
'''
Create interpreter, allocate tensors
'''
tflite_interpreter = tf.lite.Interpreter(model_path='model_file.tflite')
tflite_interpreter.allocate_tensors()

'''
Check input/output details
'''
input_details = tflite_interpreter.get_input_details()
output_details = tflite_interpreter.get_output_details()

print("== Input details ==")
print("name:", input_details[0]['name'])
print("shape:", input_details[0]['shape'])
print("type:", input_details[0]['dtype'])
print("\n== Output details ==")
print("name:", output_details[0]['name'])
print("shape:", output_details[0]['shape'])
print("type:", output_details[0]['dtype'])

'''
Run prediction (optional), input_array has input's shape and dtype
'''
tflite_interpreter.set_tensor(input_details[0]['index'], input_array)
tflite_interpreter.invoke()
output_array = tflite_interpreter.get_tensor(output_details[0]['index'])

'''
This gives a list of dictionaries. 
'''
tensor_details = tflite_interpreter.get_tensor_details()

for dict in tensor_details:
    i = dict['index']
    tensor_name = dict['name']
    scales = dict['quantization_parameters']['scales']
    zero_points = dict['quantization_parameters']['zero_points']
    tensor = tflite_interpreter.tensor(i)()

    print(i, type, name, scales.shape, zero_points.shape, tensor.shape)

    '''
    See note below
    '''
  • Conv2D层将有三个与之关联的字典:kernel(卷积核)、bias(偏置)和conv_output(卷积输出),每个字典都有它自己的scales(比例尺)、zero_points(零点)和张量(tensors)。
  • 张量(tensors)- 是np数组,包含卷积核权重或偏置。对于conv_output或激活函数,这并不代表任何意义(不是中间输出)。
  • 对于卷积核的字典,张量的形状为(cout, k, k, cin)

1

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接