属性错误:'Tensor'对象没有属性'reshape'。

3
我希望能编写一个去噪自编码器,为了可视化目的,我想打印出受损图像。
这是测试部分,我想展示受损图像:
def corrupt(x):
    noise = tf.random_normal(shape=tf.shape(x), mean=0.0, stddev=0.2, dtype=tf.float32) 
    return x + noise

# Testing
# Encode and decode images from test set and visualize their reconstruction
n = 10
canvas_orig = np.empty((28, 28 * n))
canvas_corrupt = np.empty((28, 28 * n))
canvas_recon = np.empty((28, 28 * n))

# MNIST test set
batch_x, _ = mnist.test.next_batch(n)

# Encode and decode the digit image and determine the test loss
g, l = sess.run([Y, loss], feed_dict={X: batch_x})

# Draw the generated digits
for i in range(n):
    # Original images
    canvas_orig[0: 28, i * 28: (i + 1) * 28] = batch_x[i].reshape([28, 28])

    # Corrupted images
    canvas_corrupt[0: 28, i * 28: (i + 1) * 28] = corrupt(batch_x[i]).reshape([28, 28]) 

    # Reconstructed images
    canvas_recon[0: 28, i * 28: (i + 1) * 28] = g[i].reshape([28, 28])    

print("Original Images")     
plt.figure(figsize=(n, 1))
plt.imshow(canvas_orig, origin="upper", cmap="gray")
plt.show()

print("Corrupted Images")     
plt.figure(figsize=(n, 1))
plt.imshow(canvas_corrupt, origin="upper", cmap="gray")
plt.show()

print("Reconstructed Images")
plt.figure(figsize=(n, 1))
plt.imshow(canvas_recon, origin="upper", cmap="gray")
plt.show()

错误出现在以下行:
canvas_corrupt[0: 28, i * 28: (i + 1) * 28] = corrupt(batch_x[i]).reshape([28, 28])

我真的不明白为什么它不能工作,因为它上下的语句看起来非常相似,而且完美地工作。 而“reshape”是一个函数而不是属性的事实,非常让我困惑。

1个回答

1
区别在于batch_x[i]是一个numpy数组(具有reshape方法),而corrupt(...)的结果是一个Tensor对象。在tf 1.5中,它没有reshape方法。这不会引发错误:tf.reshape(corrupt(batch_x[i]), [28, 28])
但由于您的目标是可视化该值,因此最好避免混合使用tensorflow和numpy操作,并仅使用numpy重写corrupt函数:
def corrupt(x):
    noise = np.random.normal(size=x.shape, loc=0.0, scale=0.2)
    return x + noise

谢谢,这对我很有帮助。正如你所说,现在我只需使用 numpy 生成噪声,使用 np.random.normal(0, 0.2, 784) 即可,因此再也没有 reshape() 函数的问题了。 - wagnrd

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接