让你的代码在TF 2.0中发挥作用
以下是一个示例代码,可以在TF 2.0中使用。
它依赖于兼容性API,
该API可作为tensorflow.compat.v1
访问,并需要禁用v2行为。
如果它不能按照您的预期运行,请向我们提供更多说明以便更好地帮助您。
import tensorflow.compat.v1 as tf
tf.disable_v2_behavior()
@tf.function
def construct_graph(graph_dict, inputs, outputs):
queue = inputs[:]
make_dict = {}
for key, val in graph_dict.items():
if key in inputs:
make_dict[key] = tf.placeholder(tf.float32, name=key)
else:
make_dict[key] = None
while len(queue) != 0:
cur = graph_dict[queue[0]]
for outg in cur["outgoing"]:
if make_dict[outg[0]]:
make_dict[outg[0]] = tf.add(make_dict[outg[0]], tf.multiply(outg[1], make_dict[queue[0]]))
else:
make_dict[outg[0]] = tf.multiply(make_dict[queue[0]], outg[1])
for outgo in graph_dict[outg[0]]["outgoing"]:
queue.append(outgo[0])
queue.pop(0)
return [make_dict[x] for x in outputs]
def main():
graph_def = {
"B": {
"incoming": [],
"outgoing": [("A", 1.0)]
},
"C": {
"incoming": [],
"outgoing": [("A", 1.0)]
},
"A": {
"incoming": [("B", 2.0), ("C", -1.0)],
"outgoing": [("D", 3.0)]
},
"D": {
"incoming": [("A", 2.0)],
"outgoing": []
}
}
outputs = construct_graph(graph_def, ["B", "C"], ["A"])
print(outputs)
if __name__ == "__main__":
main()
[<tf.Tensor 'PartitionedCall:0' shape=<unknown> dtype=float32>]
将代码迁移到TF 2.0
虽然上面的代码片段是有效的,但它仍然与TF 1.0相关联。
要将其迁移到TF 2.0,您需要稍微重构一下代码。
我建议您不要返回可调用的张量列表,而是返回一个keras.layers.Model
列表。
下面是一个有效的示例:
import tensorflow as tf
def construct_graph(graph_dict, inputs, outputs):
queue = inputs[:]
make_dict = {}
for key, val in graph_dict.items():
if key in inputs:
make_dict[key] = tf.keras.Input(name=key, shape=(), dtype=tf.dtypes.float32)
else:
make_dict[key] = None
while len(queue) != 0:
cur = graph_dict[queue[0]]
for outg in cur["outgoing"]:
if make_dict[outg[0]] is not None:
make_dict[outg[0]] = tf.keras.layers.add([
make_dict[outg[0]],
tf.keras.layers.multiply(
[[outg[1]], make_dict[queue[0]]],
)],
)
else:
make_dict[outg[0]] = tf.keras.layers.multiply(
[make_dict[queue[0]], [outg[1]]]
)
for outgo in graph_dict[outg[0]]["outgoing"]:
queue.append(outgo[0])
queue.pop(0)
model_inputs = [make_dict[key] for key in inputs]
model_outputs = [make_dict[key] for key in outputs]
return [tf.keras.Model(inputs=model_inputs, outputs=o) for o in model_outputs]
def main():
graph_def = {
"B": {
"incoming": [],
"outgoing": [("A", 1.0)]
},
"C": {
"incoming": [],
"outgoing": [("A", 1.0)]
},
"A": {
"incoming": [("B", 2.0), ("C", -1.0)],
"outgoing": [("D", 3.0)]
},
"D": {
"incoming": [("A", 2.0)],
"outgoing": []
}
}
outputs = construct_graph(graph_def, ["B", "C"], ["A"])
print("Builded models:", outputs)
for o in outputs:
o.summary(120)
print("Output:", o((1.0, 1.0)))
if __name__ == "__main__":
main()
需要注意什么?
- 将
placeholder
更改为keras.Input
,需要设置输入的形状。 - 使用
keras.layers.[add|multiply]
进行计算。这可能不是必需的,但要坚持一个接口。然而,它需要在列表中包装因子(以处理批处理) - 构建
keras.Model
并返回 - 使用值元组调用模型(不再是字典)
以下是代码的输出。
Builded models: [<tensorflow.python.keras.engine.training.Model object at 0x7fa0b49f0f50>]
Model: "model"
________________________________________________________________________________________________________________________
Layer (type) Output Shape Param # Connected to
========================================================================================================================
B (InputLayer) [(None,)] 0
________________________________________________________________________________________________________________________
C (InputLayer) [(None,)] 0
________________________________________________________________________________________________________________________
tf_op_layer_mul (TensorFlowOpLayer) [(None,)] 0 B[0][0]
________________________________________________________________________________________________________________________
tf_op_layer_mul_1 (TensorFlowOpLayer) [(None,)] 0 C[0][0]
________________________________________________________________________________________________________________________
add (Add) (None,) 0 tf_op_layer_mul[0][0]
tf_op_layer_mul_1[0][0]
========================================================================================================================
Total params: 0
Trainable params: 0
Non-trainable params: 0
________________________________________________________________________________________________________________________
Output: tf.Tensor([2.], shape=(1,), dtype=float32)
tf.function
中无法转换吗? - jdehesainputs
和outputs
里面有什么? - thushv89graph_dict
,inputs
和outputs
的结果,那将非常有帮助。 - thushv89