MXNet打印中间符号值

5

我该如何找到MXNet符号中保存的实际数值。

假设我有以下代码:

x = mx.sym.Variable('x')
y = mx.sym.Variable('y')
z = x + y, 

如果 x=[100,200],y=[300,400],我想打印出:z=[400,600],就像 TensorFlow 的 eval() 方法一样。
1个回答

8

在稍微查找了一下之后,我发现可以通过以下方式实现:

x = mx.sym.Variable('x')
y = mx.sym.Variable('y')
z = x + y
executor = z.bind(mx.cpu(), {'x': mx.nd.array([100,200]), 'y':mx.nd.array([300,400])})
output = executor.forward()

将会给您提供“输出”:

[<NDArray 2 @cpu(0)>]

打印实际的数字输出:

print output[0].asnumpy()
array([ 400.,  600.], dtype=float32)

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接