在Tensorflow中获取数据集的长度

10
source_dataset = tf.data.TextLineDataset('primary.csv')
target_dataset = tf.data.TextLineDataset('secondary.csv')
dataset = tf.data.Dataset.zip((source_dataset, target_dataset))
dataset = dataset.shard(10000, 0)
dataset = dataset.map(lambda source, target: (tf.string_to_number(tf.string_split([source], delimiter=',').values, tf.int32),
                                              tf.string_to_number(tf.string_split([target], delimiter=',').values, tf.int32)))
dataset = dataset.map(lambda source, target: (source, tf.concat(([start_token], target), axis=0), tf.concat((target, [end_token]), axis=0)))
dataset = dataset.map(lambda source, target_in, target_out: (source, tf.size(source), target_in, target_out, tf.size(target_in)))

dataset = dataset.shuffle(NUM_SAMPLES)  #This is the important line of code

我想完全打乱我的数据集,但是shuffle()需要指定抽取样本的数量,而且tf.Size()不适用于tf.data.Dataset

我该如何正确地进行打乱操作?


1
它应该与您较小的CSV文件的大小相同。我不知道Tensorflow中是否有一个返回数据集长度的函数或属性。 - Lescurel
1
文档中得知:生成的数据集元素数量与最小数据集的大小相同 - Lescurel
zip() 的工作方式相同;当最短的对象引发 StopIteration 时,迭代结束。 - markemus
2个回答

2

我正在使用tf.data.FixedLengthRecordDataset(),遇到了类似的问题。 在我的情况下,我想要只取原始数据的一定百分比。 由于我知道所有记录都有固定的长度,所以对我来说一个解决方法是:

totalBytes = sum([os.path.getsize(os.path.join(filepath, filename)) for filename in os.listdir(filepath)])
numRecordsToTake = tf.cast(0.01 * percentage * totalBytes / bytesPerRecord, tf.int64)
dataset = tf.data.FixedLengthRecordDataset(filenames, recordBytes).take(numRecordsToTake)
在您的情况下,我的建议是直接在Python中计算“primary.csv”和“secondary.csv”中记录的数量。或者,我认为为了您的目的,设置buffer_size参数不需要实际计算文件数。根据“关于缓冲区大小含义的被接受答案”的回答,一个大于数据集元素数量的数字将确保整个数据集的均匀洗牌。因此,只需输入一个非常大的数字(您认为将超过数据集大小的数字)即可。

你是如何使用shuffle和split来处理你的数据集的?(使用totalBytes / bytesPerRecord) - JeeyCi

1

从TensorFlow 2开始,可以通过cardinality()函数轻松获取数据集的长度。

dataset = tf.data.Dataset.range(42)
#both print 42 
dataset_length_v1 = tf.data.experimental.cardinality(dataset).numpy())
dataset_length_v2 = dataset.cardinality().numpy()

注意:当使用谓词(如filter)时,长度的返回值可能为-2。您可以在此处查看解释,否则请阅读以下段落:

如果您使用了filter谓词,则基数可能会返回值-2,因此未知;如果您确实在数据集上使用了filter谓词,请确保已以另一种方式计算出数据集的长度(例如,在对其应用.from_tensor_slices()之前计算Pandas DataFrame的长度)。


2
我已经尝试了两个数据集,结果都是-2。 - Toby
是的,这里是说明原因 - Timbus Calin

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,