TensorFlow 1.7+Keras和数据集:对象没有属性'ndim'

3
调用 keras 的 model.fit() 方法时,我遇到了以下错误: AttributeError: 'RepeatDataset' object has no attribute 'ndim' 我正在使用 TensorFlow 1.7 和 Keras。不幸的是,我必须使用 TF 1.7。你有什么想法吗?这段代码是从一个 TensorFlow 演示中 改编 而来的。
import tensorflow as tf
from IPython import embed
from tensorflow.python import keras
from tensorflow.python.keras import layers

model = tf.keras.Sequential()
model.add(layers.Dense(64, input_shape=(32,), activation='relu'))
model.add(layers.Dense(64, activation='relu'))
model.add(layers.Dense(10, activation='softmax'))

model.compile(
    optimizer=tf.train.AdamOptimizer(0.001),
    loss='categorical_crossentropy',
    metrics=['accuracy'])

import numpy as np

# Generate random data using numpy
def random_one_hot_labels(shape):
    n, n_class = shape
    classes = np.random.randint(0, n_class, n)
    labels = np.zeros((n, n_class))
    labels[np.arange(n), classes] = 1
    return labels

data = np.random.random((1000, 32))
labels = random_one_hot_labels((1000, 10))

datasetA = tf.data.Dataset.from_tensor_slices((data, labels))
datasetB = datasetA.batch(32)
dataset = datasetB.repeat()

model.fit(
    dataset, 
    epochs=10,
    steps_per_epoch=30
)

你可以尝试把参数传递到 repeat(),例如repeat(count=2) - Ashwin Geet D'Sa
另外,在创建repeat()之后,您可以将批次分组。 - Ashwin Geet D'Sa
删除 datasetB = datasetA.batch(32) 并添加 dataset = datasetA.repeat(<epochs>).batch(32) - Ashwin Geet D'Sa
@AshwinGeetD'Sa 感谢您的想法。不,它们会产生相同的错误(再次使用TF 1.7)。 - Robert Lugg
1个回答

1
这个错误是因为repeat()返回了一个生成器,你将其传递给了fitfit期望一个已定义ndim的numpy数组。后来添加了对带有fit的生成器的支持。尝试使用现在已被弃用的fit_generator代替:
model.fit_generator(
    dataset, 
    epochs=10,
    steps_per_epoch=30
)

请注意,如果没有任何参数,repeat() 将使用 -1,这可能是您要寻找的行为,也可能不是。像 repeat(1)repeat(2) 这样的内容可能是您要寻找的。截至版本 1.7 发布时的 RepeatDataset 来源:
class RepeatDataset(Dataset):
  """A `Dataset` that repeats its input several times."""

  def __init__(self, input_dataset, count):
    """See `Dataset.repeat()` for details."""
    super(RepeatDataset, self).__init__()
    self._input_dataset = input_dataset
    if count is None:
      self._count = constant_op.constant(-1, dtype=dtypes.int64, name="count")
    else:
      self._count = ops.convert_to_tensor(
          count, dtype=dtypes.int64, name="count")

我尝试复制它,但需要更多的努力才能安装正确版本。

如果这样不起作用,可能值得尝试手动迭代数据集生成器并首先从中创建一个numpy数组,然后将其传递给fit。 我不确定在1.7中是否有一种Keras方法可以做到这一点,但是如果你必须走这条路,this answer可能会有用。


网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接