predict()方法有什么用途?它是否应该指出未知数据的最接近聚类?
是的,正是如此。
那么如果您执行诸如SVD之类的降维措施,如何处理新数据点?
在将其传递给.predict()之前,对未知数据应用相同的降维方法。以下是典型的工作流程:
# prerequisites:
# x_train: training data
# x_test: "unseen" testing data
# km: initialized `KMeans()` instance
# dr: initialized dimensionality reduction instance (such as `TruncatedSVD()`)
# fitting
x_dr = dr.fit_transform(x_train)
y = km.fit_predict(x_dr)
# ...
# working with unseen data (models have been fitted before)
x_dr = dr.transform(x_test)
y = km.predict(x_dr)
# ...
实际上,像fit_transform
和fit_predict
这样的方法是为了方便而存在的。y = km.fit_predict(x)
等同于y = km.fit(x).predict(x)
。
如果我们将拟合部分写成以下形式,我认为更容易理解:
# fitting
dr.fit(x_train)
x_dr = dr.transform(x_train)
km.fit(x_dr)
y = km.predict(x_dr)
.fit()
之外,这些模型在拟合和处理未见数据时使用的方法相同。.fit()
的目的是使用数据来训练模型。
- .predict()
或.transform()
的目的是将训练好的模型应用于数据。
- 如果您想在训练期间对模型进行拟合并将其应用于相同的数据,则可以使用.fit_predict()
或.fit_transform()
以方便操作。
- 在链接多个模型(例如降维和聚类)时,请在拟合和测试期间按照相同的顺序应用它们。