我有一个由超过60亿行数据组成的Spark RDD,我想用它来训练深度学习模型,使用train_on_batch。我无法将所有行都放入内存中,因此我想每次获取大约10K行数据,然后批处理成64或128个数据块(取决于模型大小)。我目前正在使用rdd.sample(),但我认为这并不能保证我能够获取到所有行。是否有更好的方法来分区数据,以使其更易于管理,从而编写生成器函数以获取批处理数据?我的代码如下:
data_df = spark.read.parquet(PARQUET_FILE)
print(f'RDD Count: {data_df.count()}') # 6B+
data_sample = data_df.sample(True, 0.0000015).take(6400)
sample_df = data_sample.toPandas()
def get_batch():
for row in sample_df.itertuples():
# TODO: put together a batch size of BATCH_SIZE
yield row
for i in range(10):
print(next(get_batch()))