如何确定SVM中非线性核的特征重要性

6
我正在使用以下代码进行特征重要性计算。
from matplotlib import pyplot as plt
from sklearn import svm

def features_importances(coef, names):
    imp = coef
    imp,names = zip(*sorted(zip(imp,names)))
    plt.barh(range(len(names)), imp, align='center')
    plt.yticks(range(len(names)), names)
    plt.show()

features_names = ['input1', 'input2']
svm = svm.SVC(kernel='linear')
svm.fit(X, Y)
feature_importances(svm.coef_, features_names)

我如何计算非线性内核的特征重要性,该内核在给定示例中未给出预期结果。


请查看此主题:https://dev59.com/fVgR5IYBdhLWcg3wM7Hi#41601281。 - Jakub Macina
3个回答

1
简短回答:目前的库无法实现此功能。可以找出线性SVM的特征重要性,但不能找出非线性SVM的特征重要性。原因是当SVM为非线性时,数据集被映射到一个高维空间中,这与父数据集非常不同,得到的超平面也是在这个高维数据上获得的,因此属性与父数据集不同,因此无法找出此SVM相对于父数据集特征的重要性。

0
一个 N x N 的卷积核结果是不可逆的,只能被追踪!请检查是否可以使用梯度。这些通常应该追踪计算。我猜你需要脉冲响应后的迹线来衡量重要性。因此,如果您输入一堆1,则需要追踪这些。
我对SciKit-Learn的实现不是很深入,也不知道尝试获取轨迹是否有意义。但是,在这一点上,您将响应追溯到特征后,它应该为您提供重要性。
然而,任何梯度下降都不是专门用于直接跟踪输入而不是导致特定输出的参数的。
您必须找到与响应相关的核心w.r.t.回传参数(给定响应本身的核心参数梯度)。
因为这可能甚至是不可能或绝对复杂的,所以我建议使用任何可以替代带来良好结果的方法。例如,在样本的不同维度之间使用核函数,而不是在每个单独样本之间使用核函数。或者一些响应函数,可以很好地缩放您的特征。

另一方面,已经有几个库可以做这样的事情。例如:sklearn.inspection.permutation_importance 或者 SHAP 包。 - MARKUS Meister

0

你不能直接提取SVM的特征重要性。但是,你可以使用sklearn中的permutation_importance来获取它。

这里有一个例子:

from sklearn.svm import SVC
from sklearn.inspection import permutation_importance
import numpy as np
import matplotlib.pyplot as plt


svm = SVC(kernel='poly')
svm.fit(X, Y)

perm_importance = permutation_importance(svm, X, Y)

# Making the sum of feature importance being equal to 1.0,
# so feature importance can be understood as percentage
perm_importance_normalized = perm_importance.importances_mean/perm_importance.importances_mean.sum()

# Feature's name (considering your X a DataFrame)
feature_names = X.columns
features = np.array(feature_names)

# Sort to plot in order of importance
sorted_idx = perm_importance_normalized.argsort()

# Plotting
plt.figure(figsize=(13,5))
plt.title('Feature Importance',fontsize=20)
plt.barh(features[sorted_idx], perm_importance_normalized[sorted_idx], color='b', align='center')
plt.xlabel('Relative Importance', fontsize=15)
plt.xticks(fontsize=15)
plt.yticks(fontsize=15)

for index, value in enumerate(perm_importance_normalized[sorted_idx]):
    plt.text(value, index,
             str(round(value,2)), fontsize=15)

plt.show()


网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接