如何在交叉验证和网格搜索中实现SMOTE

11

我对Python比较陌生,您能帮我把SMOTE的实现改进成正式的管道吗?我想要在每个k-fold迭代的训练集上应用过采样和欠采样,以便模型在平衡数据集上进行训练,并在不平衡的剩余部分上进行评估。问题是这样做时,我无法使用熟悉的sklearn接口进行评估和网格搜索。

是否有可能制作类似于model_selection.RandomizedSearchCV的东西。我对此的看法:

df = pd.read_csv("Imbalanced_data.csv") #Load the data set
X = df.iloc[:,0:64]
X = X.values
y = df.iloc[:,64]
y = y.values
n_splits = 2
n_measures = 2 #Recall and AUC
kf = StratifiedKFold(n_splits=n_splits) #Stratified because we need balanced samples
kf.get_n_splits(X)
clf_rf = RandomForestClassifier(n_estimators=25, random_state=1)
s =(n_splits,n_measures)
scores = np.zeros(s)
for train_index, test_index in kf.split(X,y):
   print("TRAIN:", train_index, "TEST:", test_index)
   X_train, X_test = X[train_index], X[test_index]
   y_train, y_test = y[train_index], y[test_index]
   sm = SMOTE(ratio = 'auto',k_neighbors = 5, n_jobs = -1)
   smote_enn = SMOTEENN(smote = sm)
   x_train_res, y_train_res = smote_enn.fit_sample(X_train, y_train)
   clf_rf.fit(x_train_res, y_train_res)
   y_pred = clf_rf.predict(X_test,y_test)
   scores[test_index,1] = recall_score(y_test, y_pred)
   scores[test_index,2] = auc(y_test, y_pred)

你解决问题了吗? - Vivek Kumar
是的,实际上您的评论对我帮助很大。非常感谢! - MLearner
你好 @VivekKumar,这个方法是否可以确保在运行K-Fold CV时,验证集不包含过采样的观察结果?我正在尝试找到一种方法,在对训练集进行训练/测试拆分并对其进行过采样之后,我的每个CV折叠的验证集都不包含来自过采样的偏差。谢谢! - thePurplePython
@thePurplePython 是的,你说得对。imblearn管道只会在训练数据上调用sample()方法,而不会在测试数据上调用。测试数据将不会发生任何改变而直接传递。 - Vivek Kumar
2个回答

17
你需要查看管道对象。imbalanced-learn拥有一个Pipeline,它扩展了scikit-learn Pipeline的功能,以适应fit_sample()和sample()方法,除了适用于scikit-learn的fit_predict()、fit_transform()和predict()方法之外。请参考这个例子:

针对您的代码,您需要这样做:

from imblearn.pipeline import make_pipeline, Pipeline

smote_enn = SMOTEENN(smote = sm)
clf_rf = RandomForestClassifier(n_estimators=25, random_state=1)

pipeline = make_pipeline(smote_enn, clf_rf)
    OR
pipeline = Pipeline([('smote_enn', smote_enn),
                     ('clf_rf', clf_rf)])

然后,您可以将此 pipeline 对象传递给GridSearchCV、RandomizedSearchCV或其他scikit-learn中的交叉验证工具作为常规对象。

kf = StratifiedKFold(n_splits=n_splits)
random_search = RandomizedSearchCV(pipeline, param_distributions=param_dist,
                                   n_iter=1000, 
                                   cv = kf)

1
我尝试从这个答案中访问链接,但是出现了404错误。 - Mariane Reis
@MarianeReis 感谢您的通知。我已经更新了链接。 - Vivek Kumar
这两个链接仍然带我到404页面。 - agent18
@agent18 链接已经再次更新,请现在检查。 - Vivek Kumar

3

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接