如何从 PySpark DataFrame 中随机取一行?

51

如何从 PySpark DataFrame 中获取随机一行?我只看到了 sample() 方法,该方法需要一个分数作为参数。将该分数设置为 1/numberOfRows 将会导致随机结果,有时候我甚至不会得到任何行。

RDD 上有一个 takeSample() 方法,它带有一个参数,用于指定样本包含的元素数量。我理解这可能会很慢,因为您必须计算每个分区的数量,但是否有一种方法可以在 DataFrame 上实现这样的操作呢?

3个回答

87
你可以在RDD上简单地调用takeSample
df = sqlContext.createDataFrame(
    [(1, "a"), (2, "b"), (3, "c"), (4, "d")], ("k", "v"))
df.rdd.takeSample(False, 1, seed=0)
## [Row(k=3, v='c')]

如果您不想收集,只需提高分数并设置限制即可:

df.sample(False, 0.1, seed=0).limit(1)

不要传递seed,这样每次都会得到不同的DataFrame。


2
有没有获取随机值的方法。在上面的情况下,每次运行查询时都会生成相同的数据框架。 - Nikhil Baby
1
不错的提示,@LateCoder!(在Spark 2.3.1上,仅将种子保留为None似乎只适用于df.rdd.takeSample,而不适用于df.sample。) - Quentin Pradet
1
为什么有人不想要使用 collect - ijoseph
5
因为 collect 会将数据返回给驱动程序,可能导致数据无法适应驱动程序的内存限制。 - ijoseph
1
我认为第二个样本 -> 限制解决方案并不完全随机。sample() 部分是好的和随机的,但在取限制之前,结果似乎有些排序。如果你做 limit(10) 而不是 1,并且你的分数太大,这一点尤其明显。结果可能看起来相似。 - Paul Fornia

16

样本的不同类型

使用有放回或无放回抽样随机抽取百分之几的数据

import pyspark.sql.functions as F
#Randomly sample 50% of the data without replacement
sample1 = df.sample(False, 0.5, seed=0)

#Randomly sample 50% of the data with replacement
sample1 = df.sample(True, 0.5, seed=0)

#Take another sample exlcuding records from previous sample using Anti Join
sample2 = df.join(sample1, on='ID', how='left_anti').sample(False, 0.5, seed=0)

#Take another sample exlcuding records from previous sample using Where
sample1_ids = [row['ID'] for row in sample1.ID]
sample2 = df.where(~F.col('ID').isin(sample1_ids)).sample(False, 0.5, seed=0)

#Generate a startfied sample of the data across column(s)
#Sampling is probabilistic and thus cannot guarantee an exact number of rows
fractions = {
        'NJ': 0.5, #Take about 50% of records where state = NJ
    'NY': 0.25, #Take about 25% of records where state = NY
    'VA': 0.1, #Take about 10% of records where state = VA
}
stratified_sample = df.sampleBy(F.col('state'), fractions, seed=0)

0
这里有一个使用 Pandas DataFrame.Sample 方法的替代方案。它使用了 Spark 3.0.0 中提供的 applyInPandas 方法来分发组,这使您可以选择每个组的确切行数。
我已经向函数中添加了 argskwargs,以便您可以访问 DataFrame.Sample 的其他参数。
def sample_n_per_group(n, *args, **kwargs):
    def sample_per_group(pdf):
        return pdf.sample(n, *args, **kwargs)
    return sample_per_group

df = spark.createDataFrame(
    [
        (1, 1.0), 
        (1, 2.0), 
        (2, 3.0), 
        (2, 5.0), 
        (2, 10.0)
    ],
    ("id", "v")
)

(df.groupBy("id")
   .applyInPandas(
        sample_n_per_group(1, random_state=2), 
        schema=df.schema
   )
)

为了了解对于非常大的组的限制,请查看文档:

此函数需要进行完整的洗牌。一个组的所有数据都将被加载到内存中,因此如果数据失衡且某些组太大而无法适应内存,则用户应该注意潜在的OOM风险。


网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接