从pandas数据框中删除高度相关的列

6

我有一个名为 data 的数据框,我使用以下方法计算了它的相关矩阵:

corr = data.corr()

如果两列之间的相关系数大于0.75,我希望从数据框data中删除其中一列。我尝试了一些选项。
raw =corr[(corr.abs()>0.75) & (corr.abs() < 1.0)]

但是这并没有帮助我,我需要从原始数据中获取值为非零的列号。基本上,需要一些Python等效于以下R命令(使用函数findCorrelation)。
{hc=findCorrelation(corr,cutoff = 0.75)

hc = sort(hc)

data <- data[,-c(hc)]}

如果有人能帮我在Python Pandas中获取类似于上述R命令的命令,那将会很有帮助。
2个回答

16

使用np.eye将对角线值忽略,并查找所有具有某个绝对值大于阈值的值的列。使用逻辑否定作为索引和列的掩码。

m = ~(corr.mask(np.eye(len(corr), dtype=bool)).abs() > 0.75).any()

raw = corr.loc[m, m]

工作示例

np.random.seed([3,1415])
data = pd.DataFrame(
    np.random.randint(10, size=(10, 10)),
    columns=list('ABCDEFGHIJ'))
data

   A  B  C  D  E  F  G  H  I  J
0  0  2  7  3  8  7  0  6  8  6
1  0  2  0  4  9  7  3  2  4  3
2  3  6  7  7  4  5  3  7  5  9
3  8  7  6  4  7  6  2  6  6  5
4  2  8  7  5  8  4  7  6  1  5
5  2  8  2  4  7  6  9  4  2  4
6  6  3  8  3  9  8  0  4  3  0
7  4  1  5  8  6  0  8  7  4  6
8  3  5  8  5  1  5  1  4  3  9
9  5  5  7  0  3  2  5  8  8  9

corr = data.corr()
corr

      A     B     C     D     E     F     G     H     I     J
A  1.00  0.22  0.42 -0.12 -0.17 -0.16 -0.11  0.35  0.13 -0.06
B  0.22  1.00  0.10 -0.08 -0.18  0.07  0.33  0.12 -0.34  0.17
C  0.42  0.10  1.00 -0.08 -0.41 -0.12 -0.42  0.55  0.20  0.34
D -0.12 -0.08 -0.08  1.00 -0.05 -0.29  0.27  0.02 -0.45  0.11
E -0.17 -0.18 -0.41 -0.05  1.00  0.47  0.00 -0.38 -0.19 -0.86
F -0.16  0.07 -0.12 -0.29  0.47  1.00 -0.62 -0.67 -0.08 -0.54
G -0.11  0.33 -0.42  0.27  0.00 -0.62  1.00  0.22 -0.40  0.07
H  0.35  0.12  0.55  0.02 -0.38 -0.67  0.22  1.00  0.50  0.59
I  0.13 -0.34  0.20 -0.45 -0.19 -0.08 -0.40  0.50  1.00  0.40
J -0.06  0.17  0.34  0.11 -0.86 -0.54  0.07  0.59  0.40  1.00

m = ~(corr.mask(np.eye(len(corr), dtype=bool)).abs() > 0.5).any()
m

A     True
B     True
C    False
D     True
E    False
F    False
G    False
H    False
I     True
J    False
dtype: bool

raw = corr.loc[m, m]
raw

      A     B     D     I
A  1.00  0.22 -0.12  0.13
B  0.22  1.00 -0.08 -0.34
D -0.12 -0.08  1.00 -0.45
I  0.13 -0.34 -0.45  1.00

11
从机器学习的角度来看,应该除去除一个高度相关的列。例如,取列“C”和“H”。您应该删除其中一个,而不是两个都删除。如果您能考虑我的评论并修改代码,我将不胜感激。 - Sergey Bushmanov
@SergeyBushmanov,虽然有点晚了,但我写了一个函数,在这个答案中实现了R的findCorrelation功能,可以删除除一个高度相关的列之外的所有列。干杯! - cottontail

4
piRSquared的回答非常好,但它会删除所有与截断值以上相关性的列,这与R中的findCorrelation的行为相比有些过头了。假设这些是机器学习模型中的特征,我们需要删除足够多的列,以使列之间的成对相关系数小于某个截断点(可能存在多重共线性等问题)。删除太多可能会对建立在这些数据上的任何模型造成损害。正如Sergey Bushmanov在评论中提到的那样,在列CH之间,只应该删除一个。
R的caret::findCorrelation的Python实现

R的caret::findCorrelation查看每个变量的平均绝对相关性,并删除每对列中具有最大平均绝对相关性的变量。下面的函数(名为findCorrelation)实现了完全相同的逻辑。

根据相关矩阵的大小,caret::findCorrelation调用两个函数之一:完全向量化的findCorrelation_fast或循环的findCorrelation_exact(无论数据框大小如何,您都可以适当使用exact=参数调用其中任何一个)。下面的函数执行完全相同的操作。
caret::findCorrelation唯一不同的行为是它返回列名的列表,而caret::findCorrelation返回列的索引。我认为返回列名更自然,我们可以在以后传递给drop函数。
import numpy as np
import pandas as pd

def findCorrelation(corr, cutoff=0.9, exact=None):
    """
    This function is the Python implementation of the R function 
    `findCorrelation()`.
    
    Relies on numpy and pandas, so must have them pre-installed.
    
    It searches through a correlation matrix and returns a list of column names 
    to remove to reduce pairwise correlations.
    
    For the documentation of the R function, see 
    https://www.rdocumentation.org/packages/caret/topics/findCorrelation
    and for the source code of `findCorrelation()`, see
    https://github.com/topepo/caret/blob/master/pkg/caret/R/findCorrelation.R
    
    -----------------------------------------------------------------------------

    Parameters:
    -----------
    corr: pandas dataframe.
        A correlation matrix as a pandas dataframe.
    cutoff: float, default: 0.9.
        A numeric value for the pairwise absolute correlation cutoff
    exact: bool, default: None
        A boolean value that determines whether the average correlations be 
        recomputed at each step
    -----------------------------------------------------------------------------
    Returns:
    --------
    list of column names
    -----------------------------------------------------------------------------
    Example:
    --------
    R1 = pd.DataFrame({
        'x1': [1.0, 0.86, 0.56, 0.32, 0.85],
        'x2': [0.86, 1.0, 0.01, 0.74, 0.32],
        'x3': [0.56, 0.01, 1.0, 0.65, 0.91],
        'x4': [0.32, 0.74, 0.65, 1.0, 0.36],
        'x5': [0.85, 0.32, 0.91, 0.36, 1.0]
    }, index=['x1', 'x2', 'x3', 'x4', 'x5'])

    findCorrelation(R1, cutoff=0.6, exact=False)  # ['x4', 'x5', 'x1', 'x3']
    findCorrelation(R1, cutoff=0.6, exact=True)   # ['x1', 'x5', 'x4'] 
    """
    
    def _findCorrelation_fast(corr, avg, cutoff):

        combsAboveCutoff = corr.where(lambda x: (np.tril(x)==0) & (x > cutoff)).stack().index

        rowsToCheck = combsAboveCutoff.get_level_values(0)
        colsToCheck = combsAboveCutoff.get_level_values(1)

        msk = avg[colsToCheck] > avg[rowsToCheck].values
        deletecol = pd.unique(np.r_[colsToCheck[msk], rowsToCheck[~msk]]).tolist()

        return deletecol


    def _findCorrelation_exact(corr, avg, cutoff):

        x = corr.loc[(*[avg.sort_values(ascending=False).index]*2,)]

        if (x.dtypes.values[:, None] == ['int64', 'int32', 'int16', 'int8']).any():
            x = x.astype(float)

        x.values[(*[np.arange(len(x))]*2,)] = np.nan

        deletecol = []
        for ix, i in enumerate(x.columns[:-1]):
            for j in x.columns[ix+1:]:
                if x.loc[i, j] > cutoff:
                    if x[i].mean() > x[j].mean():
                        deletecol.append(i)
                        x.loc[i] = x[i] = np.nan
                    else:
                        deletecol.append(j)
                        x.loc[j] = x[j] = np.nan
        return deletecol

    
    if not np.allclose(corr, corr.T) or any(corr.columns!=corr.index):
        raise ValueError("correlation matrix is not symmetric.")
        
    acorr = corr.abs()
    avg = acorr.mean()
        
    if exact or exact is None and corr.shape[1]<100:
        return _findCorrelation_exact(acorr, avg, cutoff)
    else:
        return _findCorrelation_fast(acorr, avg, cutoff)

你可以调用findCorrelation来找到需要删除的列,并在数据框上调用drop()来删除这些列(就像在R中使用这个函数一样)。
使用piRSquared的设置,它返回以下输出。
corr = df.corr()
hc = findCorrelation(corr, cutoff=0.5)
trimmed_df = df.drop(columns=hc)

res


1
应该将这一行代码:if x[i].mean() > np.nanmean(x.drop(j)) 实际上改为:if x[i].mean() > x[j].mean()。将列i的均值与“除了列j之外的所有内容(包括i本身)”进行比较似乎有些奇怪。我在GitHub上找到了这个帖子:https://github.com/topepo/caret/issues/967 - Raimo Haikari
@RaimoHaikari,谢谢你告诉我这件事。这个未解决的问题似乎很有说服力,但奇怪的是它从未被重新讨论过。与此同时,我已经根据你的提醒改变了我的答案。 - cottontail

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接