如何将由Tensorflow数据集API创建的数据集分为训练集和测试集?

72

有人知道如何将Tensorflow中使用数据集API(tf.data.Dataset)创建的数据集拆分为测试集和训练集吗?


1
take()skip()shard()都有各自的问题。我刚刚在这里发布了我的答案。希望它能更好地回答你的问题。 - Nick Lee
使用Keras - model.fit(dataset,.., validation.split=0.7, ...) 查看其所有可能的参数。 - JeeyCi
11个回答

95

假设您有一个名为all_dataset的变量,其类型为tf.data.Dataset

test_dataset = all_dataset.take(1000) 
train_dataset = all_dataset.skip(1000)

现在测试数据集有前1000个元素,剩下的用于训练。


4
正如ted的回答中提到的,添加all_dataset.shuffle()可实现乱序拆分。在回答中可能需要添加代码注释,例如:# all_dataset = all_dataset.shuffle() # 如果您想要乱序拆分 - Christian Steinmeyer
1
TensorFlow 2.10.0将拥有一个用于拆分的实用函数,参见我的答案:https://dev59.com/klYM5IYBdhLWcg3w5zUg#73591823 - Robert Pollak
take和skip返回的TfTakeDatasets / SkipDatasets比TfDatasets的功能要少。有人知道如何将它们映射到tfDatasets或将其拆分为训练测试拆分并获取TfDataset对象吗? - Nikos H.

59
你可以使用 Dataset.take()Dataset.skip()
train_size = int(0.7 * DATASET_SIZE)
val_size = int(0.15 * DATASET_SIZE)
test_size = int(0.15 * DATASET_SIZE)

full_dataset = tf.data.TFRecordDataset(FLAGS.input_file)
full_dataset = full_dataset.shuffle()
train_dataset = full_dataset.take(train_size)
test_dataset = full_dataset.skip(train_size)
val_dataset = test_dataset.skip(val_size)
test_dataset = test_dataset.take(test_size)

为了更具普适性,我提供了一个70/15/15的训练/验证/测试集划分的示例,但是如果你不需要测试集或验证集,可以忽略最后两行。

Take

创建一个包含最多count个元素的数据集。

Skip

创建一个跳过此数据集中count个元素的数据集。

你也可以查看Dataset.shard()

创建一个仅包含此数据集1/num_shards的数据集。


免责声明:我在回答这个问题后偶然发现了这个问题,所以我想分享给大家。


4
非常感谢@ted!有没有一种分层的方式来分割数据集?或者,我们如何在训练/验证/测试划分之后对类别比例(假设是二元问题)有一个大致的了解?非常感谢您提前的帮助! - Tommaso Di Noto
6
这导致我的训练、验证和测试数据集之间存在重叠。这样做是否正常或者无关紧要?我认为在模型训练中使用验证和测试数据集不是一个好主意。 - bw0248
6
@c_student 我曾经遇到过同样的问题,后来我发现了自己的失误:当你进行洗牌操作时,请使用选项 reshuffle_each_iteration=False,否则元素可能会在训练、测试和验证集中重复出现。 - xdola
2
非常正确,@xdola。特别是在使用list_files时,应该使用shuffle=False,然后使用.shufflereshuffle_each_iteration=False进行洗牌。 - Zaccharie Ramzi
1
通过这个答案,我们得到了一个TakeDataSet,它并没有与Dataset完全相同的属性和方法。 - Paul
显示剩余2条评论

32

这里大部分答案使用take()skip(),需要事先知道数据集的大小。这并非总是可能的,或者很难/需要大量计算。

相反,您可以将数据集切片,以便每N条记录中的1个成为验证记录。

为了实现这一点,让我们从一个简单的0-9数据集开始:

dataset = tf.data.Dataset.range(10)
# [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]

现在我们以一个示例来说明,我们将对其进行分割,以获得3/1的训练/验证比例。这意味着3个记录将用于训练,然后1个记录用于验证,然后重复此过程。

split = 3
dataset_train = dataset.window(split, split + 1).flat_map(lambda ds: ds)
# [0, 1, 2, 4, 5, 6, 8, 9]
dataset_validation = dataset.skip(split).window(1, split + 1).flat_map(lambda ds: ds)
# [3, 7]

因此,第一个dataset.window(split, split + 1)表示获取split(3)元素,然后前进split + 1个元素,并重复执行。加上+1实际上跳过了我们将在验证数据集中使用的1个元素。
flat_map(lambda ds: ds)是因为window()返回批处理结果,而我们不需要这样,所以我们要将其展开。

对于验证数据,我们首先skip(split)跳过在第一个训练窗口中抓取的前split(3) 元素,因此我们的迭代从第四个元素开始。然后,window(1, split+1)获取一个元素,在split+1 (4)个元素处前进并重复。

 

关于嵌套数据集的说明:
上面的示例适用于简单的数据集,但如果数据集嵌套,则flat_map()会生成错误。为解决这个问题,您可以使用更复杂的版本来替换flat_map(),以处理简单和嵌套的数据集:

.flat_map(lambda *ds: ds[0] if len(ds) == 1 else tf.data.Dataset.zip(ds))

2
如果您有一个包含1000条记录的数据集,并且想要10%用于验证,那么在发出单个验证记录之前,您必须跳过前900条记录。使用这种解决方案,只需要跳过9条记录。它最终会跳过相同数量的记录,但是如果您使用dataset.prefetch(),它可以在后台读取数据同时执行其他任务。区别在于节省了初始启动时间。 - phemmer
5
你可能需要将“without knowing the dataset size beforehand”设置为加粗或者像标题一样突出显示,因为这一点非常重要。这应该是被接受的答案,因为它符合“tf.data.Dataset”将数据视为无限流的前提。 - Frederik Bode
当我尝试使用这种方法时,遇到的一个问题是RAM消耗比@ted描述的方法要高得多。消耗高得多,以至于我根本无法在我的机器上运行它。也许我做错了什么,但如果我不知道数据集的大小并且有数据无法适应内存,那么可行的方法是什么? - witsyke
@witsyke 这个解决方案本身不会导致与其他解决方案相比内存使用量的显著差异。如果您遇到这样的情况,那么原因可能是其余部分与之交互的方式。我建议在新问题中发布完整细节。我在数百GB的数据集上使用此方法而没有任何问题。 - phemmer
不确定这是否是在StackOverflow上的良好行为,但这是我为参考而创建的问题:https://stackoverflow.com/questions/68274353/ram-issues-when-trying-to-create-tensorflow-dataset-pippline-that-loads-from-mul如果不是,请告诉我,我会删除我的评论。 - witsyke
显示剩余2条评论

8

@ted的答案可能会导致一些重叠。尝试这个。

train_ds_size = int(0.64 * full_ds_size)
valid_ds_size = int(0.16 * full_ds_size)

train_ds = full_ds.take(train_ds_size)
remaining = full_ds.skip(train_ds_size)  
valid_ds = remaining.take(valid_ds_size)
test_ds = remaining.skip(valid_ds_size)

使用以下代码进行测试。

tf.enable_eager_execution()

dataset = tf.data.Dataset.range(100)

train_size = 20
valid_size = 30
test_size = 50

train = dataset.take(train_size)
remaining = dataset.skip(train_size)
valid = remaining.take(valid_size)
test = remaining.skip(valid_size)

for i in train:
    print(i)

for i in valid:
    print(i)

for i in test:
    print(i)

7
我喜欢每个人都默认你知道full_ds_size是什么,但没有人解释如何找到它。 - Bersan
1
@Bersan len(list(dataset)) 是最简单直接的方法。 https://dev59.com/2lUL5IYBdhLWcg3wK1YU ...但是...我的理解是,数据集可能非常大(可能无法适应内存),因此对它们进行迭代可能需要很长时间。 最好根据对数据集的外部了解来确定数据集的大小。 - BobtheMagicMoose

4

6
"shard is deprecated" 可以翻译为 "shard 已被弃用"。 - vgoklani
5
@vgoklani你确定吗?我没有看到任何说明它已被弃用。 翻译:你确定吗?我没看到任何标明它已被废弃的信息。 - BobtheMagicMoose

4
即将发布的TensorFlow 2.10.0版本将会有一个名为tf.keras.utils.split_dataset function的函数,详见rc3 release notes

新增了tf.keras.utils.split_dataset工具函数,用于将一个Dataset对象或者一个数组列表/元组分割成两个Dataset对象(例如:训练集和测试集)。


顺便说一下,我发现使用这个单独的 split_dataset 函数可以使 image_dataset_from_directory 的洗牌重新迭代稳定,从而在后面正确地产生 Model.predict 的有序结果。请参见 https://discuss.tensorflow.org/t/tf-data-dataset-varies-at-re-iteration-manual-reset-possible/11695 - Robert Pollak

4

目前TensorFlow中没有相关工具。
你可以使用sklearn.model_selection.train_test_split生成训练集、验证集和测试集,然后分别创建相应的tf.data.Dataset


6
sklearn要求数据适合内存,TF Data则不需要。 - Denziloe

0
一种将数据集拆分为两个部分的可靠方法是首先通过确定性映射将数据集中的每个项目映射到一个桶中,例如使用 tf.strings.to_hash_bucket_fast。然后可以通过按桶进行过滤来将数据集拆分为两个部分。如果将数据拆分为五个桶,则假定拆分是均匀的,就会得到80-20的拆分比例。
例如,假设您的数据集包含具有键值filename的字典。我们基于此键将数据分成了五个桶。使用这个add_fold函数,我们在字典中添加了"fold"键。
def add_fold(buckets: int):
    def add_(sample, label):
        fold = tf.strings.to_hash_bucket(sample["filename"], num_buckets=buckets)
        return {**sample, "fold": fold}, label

    return add_

dataset = dataset.map(add_fold(buckets=5))

现在我们可以使用Dataset.filter将数据集分成两个不相交的数据集:
def pick_fold(fold: int):
    def filter_fn(sample, _):
        return tf.math.equal(sample["fold"], fold)

    return filter_fn


def skip_fold(fold: int):
    def filter_fn(sample, _):
        return tf.math.not_equal(sample["fold"], fold)

    return filter_fn

train_dataset = dataset.filter(skip_fold(0))
val_dataset = dataset.filter(pick_fold(0))

用于哈希的密钥应该捕捉数据集中的相关性。例如,如果由同一人收集的样本是相关的,并且您希望所有具有相同收集器的样本最终都进入同一个桶(和同一个拆分),则应将收集器名称或ID用作哈希列。

当然,您可以跳过dataset.map部分,在一个filter函数中完成哈希和过滤。以下是完整示例:

dataset = tf.data.Dataset.from_tensor_slices([f"value-{i}" for i in range(10000)])

def to_bucket(sample):
    return tf.strings.to_hash_bucket_fast(sample, 5)

def filter_train_fn(sample):
    return tf.math.not_equal(to_bucket(sample), 0)

def filter_val_fn(sample):
    return tf.math.logical_not(filter_train_fn(sample))

train_ds = dataset.filter(filter_train_fn)
val_ds = dataset.filter(filter_val_fn)

print(f"Length of training set: {len(list(train_ds.as_numpy_iterator()))}")
print(f"Length of validation set: {len(list(val_ds.as_numpy_iterator()))}")

这将打印:

Length of training set: 7995
Length of validation set: 2005

0
小心懒惰评估,它会产生两个重叠的流水线shuffle+takeshuffle+skip。因此,一些高分答案会导致信息泄漏。以下是正确的方法,即在训练和测试数据集中重复和种子洗牌。
import tensorflow as tf

def gen_data():
    return iter(range(10))

ds = tf.data.Dataset.from_generator(
    gen_data, 
    output_signature=tf.TensorSpec(shape=(),dtype=tf.int32,name="element"))

SEED = 42 # NOTE: with no seed, you overlap train and test!

ds_train = ds.shuffle(100,seed=SEED).take(8).shuffle(100)
ds_test  = ds.shuffle(100,seed=SEED).skip(8)

A = set(ds_train.as_numpy_iterator())
B = set(ds_test.as_numpy_iterator())

assert A.intersection(B)==set()

print(list(A))
print(list(B))

注意:这适用于任何确定顺序的迭代器。

0
如果数据集的大小已知:
from typing import Tuple
import tensorflow as tf

def split_dataset(dataset: tf.data.Dataset, 
                  dataset_size: int, 
                  train_ratio: float, 
                  validation_ratio: float) -> Tuple[tf.data.Dataset, tf.data.Dataset, tf.data.Dataset]:
    assert (train_ratio + validation_ratio) < 1

    train_count = int(dataset_size * train_ratio)
    validation_count = int(dataset_size * validation_ratio)
    test_count = dataset_size - (train_count + validation_count)

    dataset = dataset.shuffle(dataset_size)

    train_dataset = dataset.take(train_count)
    validation_dataset = dataset.skip(train_count).take(validation_count)
    test_dataset = dataset.skip(validation_count + train_count).take(test_count)

    return train_dataset, validation_dataset, test_dataset

例子:

size_of_ds = 1001
train_ratio = 0.6
val_ratio = 0.2

ds = tf.data.Dataset.from_tensor_slices(list(range(size_of_ds)))
train_ds, val_ds, test_ds = split_dataset(ds, size_of_ds, train_ratio, val_ratio)

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接