有人知道如何将Tensorflow中使用数据集API(tf.data.Dataset)创建的数据集拆分为测试集和训练集吗?
有人知道如何将Tensorflow中使用数据集API(tf.data.Dataset)创建的数据集拆分为测试集和训练集吗?
假设您有一个名为all_dataset
的变量,其类型为tf.data.Dataset
:
test_dataset = all_dataset.take(1000)
train_dataset = all_dataset.skip(1000)
现在测试数据集有前1000个元素,剩下的用于训练。
all_dataset.shuffle()
可实现乱序拆分。在回答中可能需要添加代码注释,例如:# all_dataset = all_dataset.shuffle() # 如果您想要乱序拆分
- Christian SteinmeyerDataset.take()
和 Dataset.skip()
:train_size = int(0.7 * DATASET_SIZE)
val_size = int(0.15 * DATASET_SIZE)
test_size = int(0.15 * DATASET_SIZE)
full_dataset = tf.data.TFRecordDataset(FLAGS.input_file)
full_dataset = full_dataset.shuffle()
train_dataset = full_dataset.take(train_size)
test_dataset = full_dataset.skip(train_size)
val_dataset = test_dataset.skip(val_size)
test_dataset = test_dataset.take(test_size)
为了更具普适性,我提供了一个70/15/15的训练/验证/测试集划分的示例,但是如果你不需要测试集或验证集,可以忽略最后两行。
Take:
创建一个包含最多count个元素的数据集。
Skip:
创建一个跳过此数据集中count个元素的数据集。
你也可以查看Dataset.shard()
:
创建一个仅包含此数据集1/num_shards的数据集。
免责声明:我在回答这个问题后偶然发现了这个问题,所以我想分享给大家。
reshuffle_each_iteration=False
,否则元素可能会在训练、测试和验证集中重复出现。 - xdolalist_files
时,应该使用shuffle=False
,然后使用.shuffle
和reshuffle_each_iteration=False
进行洗牌。 - Zaccharie Ramzi这里大部分答案使用take()
和skip()
,需要事先知道数据集的大小。这并非总是可能的,或者很难/需要大量计算。
相反,您可以将数据集切片,以便每N条记录中的1个成为验证记录。
为了实现这一点,让我们从一个简单的0-9数据集开始:
dataset = tf.data.Dataset.range(10)
# [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
现在我们以一个示例来说明,我们将对其进行分割,以获得3/1的训练/验证比例。这意味着3个记录将用于训练,然后1个记录用于验证,然后重复此过程。
split = 3
dataset_train = dataset.window(split, split + 1).flat_map(lambda ds: ds)
# [0, 1, 2, 4, 5, 6, 8, 9]
dataset_validation = dataset.skip(split).window(1, split + 1).flat_map(lambda ds: ds)
# [3, 7]
因此,第一个dataset.window(split, split + 1)
表示获取split
号 (3)元素,然后前进split + 1
个元素,并重复执行。加上+1
实际上跳过了我们将在验证数据集中使用的1个元素。
flat_map(lambda ds: ds)
是因为window()
返回批处理结果,而我们不需要这样,所以我们要将其展开。
对于验证数据,我们首先skip(split)
跳过在第一个训练窗口中抓取的前split
个 (3) 元素,因此我们的迭代从第四个元素开始。然后,window(1, split+1)
获取一个元素,在split+1
(4)个元素处前进并重复。
关于嵌套数据集的说明:
上面的示例适用于简单的数据集,但如果数据集嵌套,则flat_map()
会生成错误。为解决这个问题,您可以使用更复杂的版本来替换flat_map()
,以处理简单和嵌套的数据集:
.flat_map(lambda *ds: ds[0] if len(ds) == 1 else tf.data.Dataset.zip(ds))
dataset.prefetch()
,它可以在后台读取数据同时执行其他任务。区别在于节省了初始启动时间。 - phemmer@ted的答案可能会导致一些重叠。尝试这个。
train_ds_size = int(0.64 * full_ds_size)
valid_ds_size = int(0.16 * full_ds_size)
train_ds = full_ds.take(train_ds_size)
remaining = full_ds.skip(train_ds_size)
valid_ds = remaining.take(valid_ds_size)
test_ds = remaining.skip(valid_ds_size)
使用以下代码进行测试。
tf.enable_eager_execution()
dataset = tf.data.Dataset.range(100)
train_size = 20
valid_size = 30
test_size = 50
train = dataset.take(train_size)
remaining = dataset.skip(train_size)
valid = remaining.take(valid_size)
test = remaining.skip(valid_size)
for i in train:
print(i)
for i in valid:
print(i)
for i in test:
print(i)
full_ds_size
是什么,但没有人解释如何找到它。 - Bersandataset = dataset.shuffle() # optional
trainset = dataset.shard(2, 0)
testset = dataset.shard(2, 1)
See: https://www.tensorflow.org/api_docs/python/tf/data/Dataset#shard
tf.keras.utils.split_dataset function
的函数,详见rc3 release notes:
新增了
tf.keras.utils.split_dataset
工具函数,用于将一个Dataset
对象或者一个数组列表/元组分割成两个Dataset
对象(例如:训练集和测试集)。
split_dataset
函数可以使 image_dataset_from_directory
的洗牌重新迭代稳定,从而在后面正确地产生 Model.predict
的有序结果。请参见 https://discuss.tensorflow.org/t/tf-data-dataset-varies-at-re-iteration-manual-reset-possible/11695 - Robert Pollak目前TensorFlow中没有相关工具。
你可以使用sklearn.model_selection.train_test_split
生成训练集、验证集和测试集,然后分别创建相应的tf.data.Dataset
。
tf.strings.to_hash_bucket_fast
。然后可以通过按桶进行过滤来将数据集拆分为两个部分。如果将数据拆分为五个桶,则假定拆分是均匀的,就会得到80-20的拆分比例。filename
的字典。我们基于此键将数据分成了五个桶。使用这个add_fold
函数,我们在字典中添加了"fold"
键。def add_fold(buckets: int):
def add_(sample, label):
fold = tf.strings.to_hash_bucket(sample["filename"], num_buckets=buckets)
return {**sample, "fold": fold}, label
return add_
dataset = dataset.map(add_fold(buckets=5))
Dataset.filter
将数据集分成两个不相交的数据集:def pick_fold(fold: int):
def filter_fn(sample, _):
return tf.math.equal(sample["fold"], fold)
return filter_fn
def skip_fold(fold: int):
def filter_fn(sample, _):
return tf.math.not_equal(sample["fold"], fold)
return filter_fn
train_dataset = dataset.filter(skip_fold(0))
val_dataset = dataset.filter(pick_fold(0))
用于哈希的密钥应该捕捉数据集中的相关性。例如,如果由同一人收集的样本是相关的,并且您希望所有具有相同收集器的样本最终都进入同一个桶(和同一个拆分),则应将收集器名称或ID用作哈希列。
当然,您可以跳过dataset.map
部分,在一个filter
函数中完成哈希和过滤。以下是完整示例:
dataset = tf.data.Dataset.from_tensor_slices([f"value-{i}" for i in range(10000)])
def to_bucket(sample):
return tf.strings.to_hash_bucket_fast(sample, 5)
def filter_train_fn(sample):
return tf.math.not_equal(to_bucket(sample), 0)
def filter_val_fn(sample):
return tf.math.logical_not(filter_train_fn(sample))
train_ds = dataset.filter(filter_train_fn)
val_ds = dataset.filter(filter_val_fn)
print(f"Length of training set: {len(list(train_ds.as_numpy_iterator()))}")
print(f"Length of validation set: {len(list(val_ds.as_numpy_iterator()))}")
这将打印:
Length of training set: 7995
Length of validation set: 2005
shuffle+take
和shuffle+skip
。因此,一些高分答案会导致信息泄漏。以下是正确的方法,即在训练和测试数据集中重复和种子洗牌。import tensorflow as tf
def gen_data():
return iter(range(10))
ds = tf.data.Dataset.from_generator(
gen_data,
output_signature=tf.TensorSpec(shape=(),dtype=tf.int32,name="element"))
SEED = 42 # NOTE: with no seed, you overlap train and test!
ds_train = ds.shuffle(100,seed=SEED).take(8).shuffle(100)
ds_test = ds.shuffle(100,seed=SEED).skip(8)
A = set(ds_train.as_numpy_iterator())
B = set(ds_test.as_numpy_iterator())
assert A.intersection(B)==set()
print(list(A))
print(list(B))
from typing import Tuple
import tensorflow as tf
def split_dataset(dataset: tf.data.Dataset,
dataset_size: int,
train_ratio: float,
validation_ratio: float) -> Tuple[tf.data.Dataset, tf.data.Dataset, tf.data.Dataset]:
assert (train_ratio + validation_ratio) < 1
train_count = int(dataset_size * train_ratio)
validation_count = int(dataset_size * validation_ratio)
test_count = dataset_size - (train_count + validation_count)
dataset = dataset.shuffle(dataset_size)
train_dataset = dataset.take(train_count)
validation_dataset = dataset.skip(train_count).take(validation_count)
test_dataset = dataset.skip(validation_count + train_count).take(test_count)
return train_dataset, validation_dataset, test_dataset
例子:
size_of_ds = 1001
train_ratio = 0.6
val_ratio = 0.2
ds = tf.data.Dataset.from_tensor_slices(list(range(size_of_ds)))
train_ds, val_ds, test_ds = split_dataset(ds, size_of_ds, train_ratio, val_ratio)
take()
、skip()
和shard()
都有各自的问题。我刚刚在这里发布了我的答案。希望它能更好地回答你的问题。 - Nick Lee