如何使用TensorFlow 2.0洗牌两个NumPy数据集?

11

我希望能够在TensorFlow 2.0中编写一个函数,在每个训练迭代之前对数据和目标标签进行洗牌。

假设我有两个numpy数据集X和y,分别表示分类的数据和标签。如何同时随机打乱它们

使用sklearn很容易:

from sklearn.utils import shuffle
X, y = shuffle(X, y)

TensorFlow 2.0中我该如何做到相同的功能?文档中唯一找到的工具是tf.random.shuffle,但它每次只能处理一个对象,而我需要同时处理两个对象。


看看这个能否帮到你 - https://dev59.com/v1gQ5IYBdhLWcg3wJwkv. - Divakar
3个回答

6

与其混淆x和y,更容易混淆它们的索引,因此首先生成索引列表

indices = tf.range(start=0, limit=tf.shape(x_data)[0], dtype=tf.int32)

然后混淆这些索引

idx = tf.random.shuffle(indices)

并使用这些索引来混淆数据

x_data = tf.gather(x_data, idx)
y_data = tf.gather(y_data, idx)

然后您将获得洗牌后的数据


3

首先将它们转换为tf.data.Dataset类型。

x_train = tf.data.Dataset.from_tensor_slices(x)
y_train = tf.data.Dataset.from_tensor_slices(y)

完成这一步后,您可以轻松地对它们进行随机排列:

x_train, y_train = x_train.shuffle(buffer_size=2, seed=2), y_train.shuffle(buffer_size=2, seed=2)
dataset = tf.data.Dataset.from_tensor_slices((x_train, y_train))

在训练变量中使用相同的 seed,这样您就可以在不失去特征-目标关系的情况下对数据进行洗牌。 您甚至可以创建一个用于洗牌的函数:

BF = 2
SEED = 2
def shuffling(dataset, bf, seed_number):
   return dataset.shuffle(buffer_size=bf, seed=seed_number)

x_train, y_train = shuffling(x_train, BF, SEED), shuffling(y_train, BF, SEED)
dataset = tf.data.Dataset.from_tensor_slices((x_train, y_train))

当我尝试合并两个数据集 dataset = tf.data.Dataset.from_tensor_slices((x_train, y_train)) 时,出现以下错误:ValueError: Slicing dataset elements is not supported for rank 0 - Mykola Zotko

2
如果您只想以相同的方式随机两个数组,可以使用以下方法:
import tensorflow as tf

# Assuming X and y are initially NumPy arrays
X = tf.convert_to_tensor(X)
y = tf.convert_to_tensor(y)
# Make random permutation
perm = tf.random.shuffle(tf.range(tf.shape(X)[0]))
# Reorder according to permutation
X = tf.gather(X, perm, axis=0)
y = tf.gather(y, perm, axis=0)

然而,您可以考虑使用tf.data.Dataset,它已经提供了一个shuffle方法。

import tensorflow as tf

# You may use a placeholder if in graph mode
# (see https://www.tensorflow.org/guide/datasets#consuming_numpy_arrays)
ds = tf.data.Dataset.from_tensor_slices((X, y))
# Shuffle with some buffer size (len(X) will use a buffer as big as X)
ds = ds.shuffle(buffer_size=len(X))

1
我们如何从数据集对象中检索形状初始化的混洗张量X和Y?谢谢。 - Dhiraj Dhakal

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接