如何在Python中对以下代码进行优化?

5
我有一个用户自定义指标需要实现,内容如下:
def metric(pred:pd.DataFrame(), valid:pd.DataFrame()):
    date_begin = valid.dt.min()
    date_end = valid.dt.max()
    x = valid[valid.label == 1].dt.min()

    # p
    p_n_tpp_df = valid[(valid.dt >= x) &\
                       (valid.dt <= x + timedelta(days=30)) &\
                       (p_n_tpp_df.label == 1)]
    p_n_pp_df =  valid[(valid.dt >= date_begin + timedelta(days=30)) &\ 
                       (valid.dt <= date_end + timedelta(days=30)) &\
                       (p_n_tpp_df.label == 1)]


    p_n_tpp = len([x for x in pred.serial_number.values\ 
                     if x in p_n_tpp_df.serial_number.unique()])
    p_n_pp = len([x for x in pred.serial_number.values\ 
                    if x in p_n_pp_df.serial_number.unique()])

    p = p_n_tpp / p_n_pp
    print('p: ', p)

    # r
    p_n_tpr_df = valid[(valid.dt >= date_begin - timedelta(days=30)) &\ 
                      (valid.dt <= date_end - timedelta(days=30)) &\
                      (p_n_tpr_df.label == 1)]
    p_n_pr_df = valid[(valid.dt >= date_begin) &\ 
                      (valid.dt <= date_end) &\ 
                      (p_n_pr_df.label == 1)]


    p_n_tpr = len([x for x in pred.serial_number.values\
                     if x in p_n_tpr_df.serial_number.unique()])
    p_n_pr = len([x for x in pred.serial_number.values\
                    if x in p_n_pr_df.serial_number.unique()])

    r = p_n_tpr / p_n_pr
    print('p: ', r)

    m = 2 * p * r / (p + r)

    return m

predvalidpd.DataFrame() 拥有相同的列,dt 没有交集。同时,valid 中所有的 serial_number 值都是 predserial_number 值的子集。
label 列只有两个值: 0 或 1。
下面是 predvalid 的样例:


print(pred.head(3))
    serial_number  dt          label  
0   123            2011-03-21  1
1   52             2011-03-22  0
2   12             2011-03-01  1
..., ...


print(pred.info())
Int64Index: 10000000 entries,
Data columns (total 3 columns):
serial_number  int32
dt             datetimes64[ns]
label          int8
..., ...

print(valid.head(3))
    serial_number  dt          label  
0   324            2011-04-22  1
1   52             2011-04-22  0
2   14             2011-04-01  1
..., ...


print(valid.info())
Int64Index: 10000000 entries,
Data columns (total 3 columns):
serial_number  int32
dt             datetimes64[ns]
label          int8

输入的 pd.DataFrame 大约有 10,000,000 组样本和 3 个特征。
当我使用它来计算这个度量时,速度非常慢,在 Intel 9600KF 上花费的时间超过了 2 小时。
因此,我想知道如何优化这段代码以节省时间成本。
提前感谢。


3
你能提供一个示例数据集吗? - Itamar Mushkin
应该将 number 改为 serial_number 吗? - Itamar Mushkin
为了帮助您,我们需要predvalid的样例。 - Itamar Mushkin
@ItamarMushkin,我已经更新了它的详细信息。感谢您的建议。 - Bowen Peng
虽然有点不相关,但是你使用的类型注释(pred:pd.DataFrame())是不正确的。类型注释应该是类型,但我相信 pd.DataFrame() 创建了一个实例。 - Zecong Hu
显示剩余5条评论
1个回答

6

以下是您代码中最重要的性能优化点:

Numpy集合逻辑

len([x for x in pred.serial_number.values\
                     if x in p_n_tpr_df.serial_number.unique()])

下面这一行代码是获取 pred.serial_numberp_n_tpr_df.serial_number 两个集合的交集大小。使用numpy代替列表推导和unique函数可以节省大量运算时间:

intersect_size = np.intersect1d(pred.serial_number.values,
                                p_n_tpr_df.serial_number.values).shape[0]

@bowen-peng,这个答案对您有帮助吗?如果是的话,您可以接受这个答案。 - hume

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接