从Tensorflow PrefetchDataset中提取目标

32

我仍在学习tensorflow和keras,而且我怀疑这个问题有一个非常简单的答案,只是因为缺乏熟悉而错过了。

我有一个PrefetchDataset对象:

> print(tf_test)
$ <PrefetchDataset shapes: ((None, 99), (None,)), types: (tf.float32, tf.int64)>

...由特征和目标组成。我可以使用for循环迭代它:

> for example in tf_test:
>     print(example[0].numpy())
>     print(example[1].numpy())
>     exit()
$ [[-0.31 -0.94 -1.12 ... 0.18 -0.27]
   [-0.22 -0.54 -0.14 ... 0.33 -0.55]
   [-0.60 -0.02 -1.41 ... 0.21 -0.63]
   ...
   [-0.03 -0.91 -0.12 ... 0.77 -0.23]
   [-0.76 -1.48 -0.15 ... 0.38 -0.35]
   [-0.55 -0.08 -0.69 ... 0.44 -0.36]]
  [0 0 1 0 1 0 0 0 1 0 1 1 0 1 0 0 0
   ...
   0 1 1 0]

然而,这非常缓慢。我想要做的是访问对应于类标签的张量,并将其转换为numpy数组、列表或任何可以输入scikit-learn分类报告和/或混淆矩阵的迭代器:

> y_pred = model.predict(tf_test)
> print(y_pred)
$ [[0.01]
   [0.14]
   [0.00]
   ...
   [0.32]
   [0.03]
   [0.00]]
> y_pred_list = [int(x[0]) for x in y_pred]             # assumes value >= 0.5 is positive prediction
> y_true = []                                           # what I need help with
> print(sklearn.metrics.confusion_matrix(y_true, y_pred_list)

...或者访问数据,以便可以在TensorFlow的混淆矩阵中使用:

> labels = []                                           # what I need help with
> predictions = y_pred_list                             # could we just use a tensor?
> print(tf.math.confusion_matrix(labels, predictions)
在这两种情况下,能够以不消耗大量计算资源的方式从原始对象中获取目标数据的一般能力将非常有帮助(并可能有助于我的有关 TensorFlow 和 Keras 的基本直觉)。任何建议都将不胜感激。

3
这是答案:y = np.concatenate([y for x, y in ds], axis=0) - prashanth
9个回答

14

您可以使用list(ds)将其转换为列表,然后使用tf.data.Dataset.from_tensor_slices(list(ds))将其重新编译为普通Dataset。从那里开始,您的噩梦又开始了,但至少这是其他人曾经经历过的噩梦。

请注意,对于更复杂的数据集(例如嵌套字典),在调用list(ds)之后,您需要进行更多的预处理,但对于您所询问的示例,这应该可以工作。

这远非一个令人满意的答案,但不幸的是,这个类完全没有文档记录,并且没有任何标准的Dataset技巧可用。


1
我知道这个类完全没有文档,但是这个答案很好 :) - DevLoverUmar
@DevLoverUmar 很高兴能帮到你 :) - markemus

8

您可以使用map来选择每个(input, label)对中的输入或标签,并将其转换为列表:

import tensorflow as tf
import numpy as np

inputs = np.random.rand(100, 99)
targets = np.random.rand(100)

ds = tf.data.Dataset.from_tensor_slices((inputs, targets))

X_train = list(map(lambda x: x[0], ds))
y_train = list(map(lambda x: x[1], ds))

4

您可以通过循环遍历PrefetchDataset(在我的示例中是train_dataset)来生成列表;

train_data = [(example.numpy(), label.numpy()) for example, label in train_dataset]

因此,您可以通过使用索引单独访问每个示例和标签。
train_data[0][0]
train_data[0][1]

您也可以使用pandas将它们转换为具有2个列的数据框

import pandas as pd
pd.DataFrame(train_data, columns=['example', 'label'])

那么,如果您想将列表转换回 PrefetchDataset,您可以简单地使用 ;

dataset = tf.data.Dataset.from_generator(
lambda: train_data, ( tf.string, tf.int32)) # you should define dtypes of yours

您可以使用以下内容检查是否起作用:

list(dataset.as_numpy_iterator())

这很简单,但在我的情况下非常有效。 - panoet

4
如果您想保留批次或提取所有标签作为单个张量,可以使用以下函数:

def get_labels_from_tfdataset(tfdataset, batched=False):

    labels = list(map(lambda x: x[1], tfdataset)) # Get labels 

    if not batched:
        return tf.concat(labels, axis=0) # concat the list of batched labels

    return labels

2

我有一个类似的问题,对象如下:

type(train_ds)
>> tensorflow.python.data.ops.dataset_ops.PrefetchDataset

我成功地从一个批次中提取了特征和标签,方法如下:
[(train_features, label_batch)] = train_ds.take(1)
print(np.array(label_batch))

2

这是由Dataset.prefetch()方法返回的类,它是Dataset的一个子类。

如果您通过传递ReadConfig到构建器来设置skip_prefetch=Ture,则返回的类型将是_OptionsDataset。

read_config = tfds.ReadConfig(skip_prefetch = True)
dataset_builder.as_dataset(
    ......,
    read_config = read_config,
)

https://www.tensorflow.org/api_docs/python/tf/data/Dataset#prefetch


1
如果您使用批次创建了tf.data.Dataset并且希望获得两个单独的numpy数组,则将每个列表的列表连接成一个单独的数组。
train_data = list(train_ds)
features = np.concatenate([train_data[n][0] for n in range(0, len(train_data))])
targets = np.concatenate([train_data[n][1] for n in range(0, len(train_data))])


0
您可以使用map()函数一次迭代
ratings.map(lambda x: x["feature name"])

0

对于图像数据集,Tensorflow 2.12.0

加载数据

 dataset, dataset_info = tfds.load('malaria', with_info=True, as_supervised=True,shuffle_files=True,split["train"],data_dir="you_dir\tensorflow_datasets\\")

迭代遍历样本

  for i, (image, label) in enumerate(train_dataset.take(16)):
        ax = plt.subplot(4, 4, i+1)
        plt.imshow(image)
        plt.title(dataset_info.features['label'].int2str(label))
        plt.axis('off')

plt.show()

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接