如何在tf.data.Dataset.map()中使用Keras的predict_on_batch()?

3
我希望找到一种方法,在TF2.0中使用Keras的predict_on_batch函数放在tf.data.Dataset.map()函数内。假设我有一个numpy数据集。
n_data = 10**5
my_data    = np.random.random((n_data,10,1))
my_targets = np.random.randint(0,2,(n_data,1))

data = ({'x_input':my_data}, {'target':my_targets})

以及一个tf.keras模型

x_input = Input((None,1), name = 'x_input')
RNN     = SimpleRNN(100,  name = 'RNN')(x_input)
dense   = Dense(1, name = 'target')(RNN)

my_model = Model(inputs = [x_input], outputs = [dense])
my_model.compile(optimizer='SGD', loss = 'binary_crossentropy')

我可以创建具有批处理功能的数据集
dataset = tf.data.Dataset.from_tensor_slices(data)
dataset = dataset.batch(10)
prediction_dataset = dataset.map(transform_predictions)

其中,transform_predictions 是一个用户定义的函数,用于从 predict_on_batch 中获取预测结果。

def transform_predictions(inputs, outputs):
    predictions = my_model.predict_on_batch(inputs)
    # predictions = do_transformations_here(predictions)
    return predictions

当使用predict_on_batch函数时会提示以下错误:

AttributeError: 'Tensor' object has no attribute 'numpy'

据我所知,predict_on_batch函数需要一个numpy数组作为输入参数,而数据集给出了一个张量对象。

似乎解决的一个可能方法是将predict_on_batch函数放在`tf.py_function`中,但我也无法使其正常工作。

有人知道如何解决吗?


这里有一个类似的问题是用R写的链接,但没有解决方案。 - Anders
1个回答

5

Dataset.map() 返回的是 <class 'tensorflow.python.framework.ops.Tensor'>,它没有 numpy() 方法。

遍历 Dataset 返回的是 <class 'tensorflow.python.framework.ops.EagerTensor'>,它有一个 numpy() 方法。

将 EagerTensor 作为预测方法的输入是可以正常工作的。

您可以尝试以下代码:

dataset = tf.data.Dataset.from_tensor_slices(data)
dataset = dataset.batch(10)

for x,y in dataset:
    predictions = my_model.predict_on_batch(x['x_input'])
    #or 
    predictions = my_model.predict_on_batch(x)

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接