在Pandas DataFrame中,根据每行的条件计算分组的滚动均值。

3
我正在尝试创建一个名为rolling_median的列,具体如下:
对于每一行: - 过滤DataFrame,只保留与当前行的Date_A之前且Category相同的行 - 对这个过滤后的DataFrame按Date_B进行排序 - 使用最后N个符合过滤条件的行来计算Value的中位数。例如,如果N=2,则使用最后2行来估计当前行的平均值
使用iterrows方法可以实现,但不具有可扩展性。 您有没有想过以更"向量化"的方式来降低复杂度?
以下是生成类似DataFrame的代码片段:
import pandas as pd
import numpy as np

# Sample DataFrame (replace this with your actual DataFrame)
data = {
    'Category': ['A', 'A', 'A', 'A', 'A', 'B', 'B', 'B', 'B', 'B'],
    'Date_A': ['2023-07-08', '2023-07-09', '2023-07-11', '2023-07-12', '2023-07-13', '2023-07-08', '2023-07-09', '2023-07-11', '2023-07-12', '2023-07-13'],
    'Date_B': ['2023-07-08', '2023-07-10', '2023-07-12', '2023-07-12', '2023-07-13', '2023-07-08', '2023-07-10', '2023-07-12', '2023-07-12', '2023-07-13'],
    'Value': [10, 15, 20, 25, 30, 35, 40, 45, 50, 55]
}


df = pd.DataFrame(data)

# Convert 'Date_A' and 'Date_B' columns to datetime type
df['Date_A'] = pd.to_datetime(df['Date_A'])
df['Date_B'] = pd.to_datetime(df['Date_B'])
df['rolling_mean'] = np.nan

# Last 2 values
N=2

for category in df.Category.unique():
    df_cat = df[df.Category==category]
    for idx, row in df_cat.iterrows():
        rm = df_cat[df_cat.Date_B < row.Date_A][:N].Value.mean()
        df.at[idx, 'rolling_mean'] = rm

df

  Category     Date_A     Date_B  Value  rolling_mean
0        A 2023-07-08 2023-07-08     10           NaN
1        A 2023-07-09 2023-07-10     15          10.0
2        A 2023-07-11 2023-07-12     20          12.5
3        A 2023-07-12 2023-07-12     25          12.5
4        A 2023-07-13 2023-07-13     30          12.5
5        B 2023-07-08 2023-07-08     35           NaN
6        B 2023-07-09 2023-07-10     40          35.0
7        B 2023-07-11 2023-07-12     45          37.5
8        B 2023-07-12 2023-07-12     50          37.5
9        B 2023-07-13 2023-07-13     55          37.5

可以提供精确的预期输出吗?别忘了定义 N - mozway
你能提供准确的预期输出吗?别忘了定义N - mozway
你能提供具体的预期输出吗?别忘了定义N - undefined
我添加了N的含义,并且提供了一个更详细的例子,谢谢! - JBSH
我添加了N的含义,并且提供了一个更详细的例子,谢谢! - JBSH
2个回答

1
可能不是最高效的方法,但似乎比双重循环要快一点。
import pandas as pd
import numpy as np

# Sample DataFrame (replace this with your actual DataFrame)
data = {
    'Category': ['A', 'A', 'A', 'A', 'A', 'B', 'B', 'B', 'B', 'B'],
    'Date_A': ['2023-07-08', '2023-07-09', '2023-07-11', '2023-07-12', '2023-07-13', '2023-07-08', '2023-07-09', '2023-07-11', '2023-07-12', '2023-07-13'],
    'Date_B': ['2023-07-08', '2023-07-10', '2023-07-12', '2023-07-12', '2023-07-13', '2023-07-08', '2023-07-10', '2023-07-12', '2023-07-12', '2023-07-13'],
    'Value': [10, 15, 20, 25, 30, 35, 40, 45, 50, 55]
}


df = pd.DataFrame(data)

# Convert 'Date_A' and 'Date_B' columns to datetime type
df['Date_A'] = pd.to_datetime(df['Date_A'])
df['Date_B'] = pd.to_datetime(df['Date_B'])
df['rolling_mean'] = np.nan

# Last 2 values
N=2

df["rolling_mean"] = df.groupby("Category").apply(
    lambda grp: grp["Date_A"].map(
        lambda a: grp.loc[grp.Date_B < a, "Value"].head(N).mean()
    ),
).droplevel(level="Category")

print(df)

  Category     Date_A     Date_B  Value  rolling_mean
0        A 2023-07-08 2023-07-08     10           NaN
1        A 2023-07-09 2023-07-10     15          10.0
2        A 2023-07-11 2023-07-12     20          12.5
3        A 2023-07-12 2023-07-12     25          12.5
4        A 2023-07-13 2023-07-13     30          12.5
5        B 2023-07-08 2023-07-08     35           NaN
6        B 2023-07-09 2023-07-10     40          35.0
7        B 2023-07-11 2023-07-12     45          37.5
8        B 2023-07-12 2023-07-12     50          37.5
9        B 2023-07-13 2023-07-13     55          37.5

通过%%timeit进行时间比较

原始内容

%%timeit
for category in df.Category.unique():
    df_cat = df[df.Category==category]
    for idx, row in df_cat.iterrows():
        rm = df_cat[df_cat.Date_B < row.Date_A][:N].Value.mean()
        df.at[idx, 'rolling_mean'] = rm

7.42 ms ± 277 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)

建议

%%timeit
df["rolling_mean"] = df.groupby("Category").apply(
    lambda grp: grp["Date_A"].map(
        lambda a: grp.loc[grp.Date_B < a, "Value"].head(N).mean()
    ),
).droplevel(level="Category")

4.55 ms ± 214 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)

一些解释:

# this handles the initial for-loop over "Categories"
# it returns an Iterator of tuples with the first 
# item being the group name and the second being the dataframe group
df.groupby("Category")

# `grp` is the dataframe part returned by the `.groupby` above.
# it has the same columns as the original, but filtered down 
# to the rows in the selected group.
.apply(
    lambda grp: grp["Date_A"]...
)

# `map` iterates over the values of the declared series (i.e. "Date_A")
# `a` is one of the values in "Date_A" and it is compared to the 
# series "Date_B".
# `grp` is filtered to the rows where "Date_B" is < `a`.
# we keep the first `N` rows of the returned dataframe and take the mean
.map(
        lambda a: grp.loc[grp.Date_B < a, "Value"].head(N).mean()
    )

# we drop the newly generated "Category" index so we can join back 
# with the original `df`.
.droplevel(level="Category")

参考文献


0

我假设 Date_B 总是等于或大于 Date_A。然后你可以尝试使用 .groupby() + .expanding

df['rolling_mean_2'] = (
    df.groupby("Category")
    .expanding()["Value"]
    .apply(
        lambda x: x[df.loc[x.index, "Date_B"] < df.loc[x.index[-1], "Date_A"]][:2].mean()
    ).values
)

print(df)

输出:

  Category     Date_A     Date_B  Value  rolling_mean  rolling_mean_2
0        A 2023-07-08 2023-07-08     10           NaN             NaN
1        A 2023-07-09 2023-07-10     15          10.0            10.0
2        A 2023-07-11 2023-07-12     20          12.5            12.5
3        A 2023-07-12 2023-07-12     25          12.5            12.5
4        A 2023-07-13 2023-07-13     30          12.5            12.5
5        B 2023-07-08 2023-07-08     35           NaN             NaN
6        B 2023-07-09 2023-07-10     40          35.0            35.0
7        B 2023-07-11 2023-07-12     45          37.5            37.5
8        B 2023-07-12 2023-07-12     50          37.5            37.5
9        B 2023-07-13 2023-07-13     55          37.5            37.5

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接