“nlargest” 返回奇怪的结果。

3

我正在尝试使用“largest”获取前n个结果,但是行为有点奇怪。如果有人能帮助我理解为什么会这样,那就太好了。

filter = pd.DataFrame([['user1','item2',2,1],
                   ['user1','item1',2,0.666667],
                   ['user1','item3',2,0.500000]],
                  columns=['user_id','item_id','num_transactions','RCP'])

sort_RCP_df = (
        filter.set_index("item_id")
        .groupby(["user_id"])["RCP"]
        .nlargest(2)
        .reset_index()
)
print(sort_RCP_df)

user_id item_id RCP
user1   item2   1.000000
user1   item1   0.666667

如果我保留 nlargest(2),那么会得到正确的输出,但如果我将其更改为 3,那么只会得到 item_id 和 RCP 两列。
filter = pd.DataFrame([['user1','item2',2,1],
                   ['user1','item1',2,0.666667],
                   ['user1','item3',2,0.500000]],
                  columns=['user_id','item_id','num_transactions','RCP'])

sort_RCP_df = (
        filter.set_index("item_id")
        .groupby(["user_id"])["RCP"]
        .nlargest(3)
        .reset_index()
)
print(sort_RCP_df)

item_id RCP
item2   1.000000
item1   0.666667
item3   0.500000

为什么使用 nlargest = 3 后列 'user_id' 没有出现?
如果这是期望的行为,是否有办法让 'user_id' 也成为输出的一部分?

有趣。看着代码,你确实告诉它只提取“RCP”列,所以在那之后,你应该只有一个包含“RCP”和索引(“item_id”)的系列。当你要求最大的2个时,我不确定为什么会得到“user_id”。这对我来说是个谜。 - Tim Roberts
2
如果您请求的项目数量少于数据框中的行数,则返回的Series将使用由“groupby”指定的索引。如果您请求的项目数量大于或等于df中的行数,则跳过“groupby”索引。 - user17242583
作为一种解决方法,您可以将"user_id"也分配给索引。所以:df.set_index(["item_id", "user_id"]).groupby("user_id")["RCP"].nlargest(3).reset_index() - not_speshal
@not_speshal 这段代码在我的环境中无法工作。我得到了一个“ValueError:无法插入user_id,因为已存在”的错误信息。 - Prince Modi
你使用的是哪个版本的pandas?它可以在pandas 1.3.4上运行。 - not_speshal
1个回答

1
文件中的注释明确指出了性能方面的考虑,因此暗示了问题的原因:

相对于 Series 对象的大小而言,对于小的 n ,比 .sort_values(ascending=False).head(n) 更快。

如果你深入研究代码,会发现 Series.nlargest/Series.nsmallest 由 pandas/core/algorithms 中的 SelectNSeries 类处理。这个类的行为取决于 n 相对于 Series 长度的不同情况:
# slow method
if n >= len(self.obj):
    ascending = method == "nsmallest"
    return dropped.sort_values(ascending=ascending).head(n)

# fast method
arr, new_dtype = _ensure_data(dropped.values)
if method == "nlargest":
    arr = -arr
    if is_integer_dtype(new_dtype):
        # GH 21426: ensure reverse ordering at boundaries
        arr -= 1

...

这里的关键点是,当 n >= Series 的长度 时,调用不使用正常算法来计算最大/最小值,而是使用 sort_values + head 来计算。如果我们用这个逻辑替换你的 nlargest 调用,就可以手动匹配你的输出。
sort_RCP_df = (
        filter.set_index("item_id")
        .groupby(["user_id"])["RCP"]
        .apply(lambda s: s.sort_values(ascending=False).head(2))
        .reset_index()
)

#  user_id item_id       RCP
#0   user1   item2  1.000000
#1   user1   item1  0.666667

那么我们该如何更改代码以在 n = 序列长度时获得输出呢? - Prince Modi

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接