实现支持向量机 - 高效计算Gram矩阵K

5
我正在使用Python为mnist数据实现SVM,目前我正在使用cvxopt来解决QP并获取回归系数。
但我的问题是如何高效地计算K-gram矩阵。我从仅有两类(数字6和0)开始,训练样本数量少于1000个,接下来是10000个。
为了更快地计算整个1k x 1k的矩阵,我正在使用Process,并给出不同的原始数据进行计算。但仍然需要大约2分钟,这是一个rbf - 高斯核函数。(10k的还在运行中!)
如果有人已经处理过或者是Python爱好者可以帮我解决这个问题那就太好了!
附:如果有人不知道如何计算格拉姆矩阵,这里有详细说明:
它很简单:
for i in range(1k):
    for j in range(1k):
         for K[i,j] = some_fun(x[i], x[j])

其中 some_fun - 是点乘或花哨高斯。

我正在使用 Python 2.7、NumPy 和 Mac Air 4G RAM,128G 固态硬盘。

[编辑] 如果有人来到这里!是的,SVM确实需要更长时间...如果您正在进行多分类,那么您必须再次计算k-gram矩阵...所以它会花费很长时间,因此我建议实现算法并检查两次,让它在晚上运行!但你第二天肯定会看到好结果! :)


你是为了练习而实现它,还是只需要使用支持向量机?scikit learn 提供了一个支持向量机库。 - BrenBarn
1
实现SVM是一项艰巨的任务,如果你想要效率,我不会使用Python。即使你想从Python中使用它,我也会将其作为C或C++扩展来完成。 - Pedrom
@johnthexiii,你的意思是同样的for循环在C语言中会更快吗?我从未尝试过在Python中使用自定义的C语言,但我想看看它是否比这个简单的计算更快。 - code muncher
1
@codemuncher,在你开始编写任何C代码之前,你应该先查看scikits svm源代码,也许你能找出是什么导致了你的实现变慢。 - John
1
很可能是点积,您可以使用BLAS库来加速此过程,或者使用scikit,它附带了这些优化。 - Thomas Jungblut
显示剩余4条评论
1个回答

6

你正在使用numpy,对吧?使用numpy的矩阵运算一次计算整个矩阵,而不是使用缓慢的Python循环来寻找每个成对评估,可以获得大幅加速。例如,如果我们假设x是一个行实例数据矩阵(每个数据点一行,每个维度一列):

# get a matrix where the (i, j)th element is |x[i] - x[j]|^2
# using the identity (x - y)^T (x - y) = x^T x + y^T y - 2 x^T y
pt_sq_norms = (x ** 2).sum(axis=1)
dists_sq = np.dot(x, x.T)
dists_sq *= -2
dists_sq += pt_sq_norms.reshape(-1, 1)
dists_sq += pt_sq_norms

# turn into an RBF gram matrix
km = dists_sq; del dists_sq
km /= -2 * sigma**2
np.exp(km, km)  # exponentiates in-place

使用np.random.normal(size=(1000, 784))生成数据,这在我的四核i5 iMac上需要70毫秒。增加到10k数据点,则需要不到7秒钟。

sklearn.metrics.pairwise.rbf_kernel的工作方式类似,不过它还有一些额外的输入检查和对稀疏矩阵等的支持。

值得注意的是,在Python 2中,你应该循环遍历xrange(1000),而不是range(1000)。因为range会构造一个列表对象来进行循环,这需要一些时间,更重要的是内存。对于10000个数据点来说可能还好,但如果你的循环太大,就会引起严重的问题。


@yauheni_selivonchyk 我之前在这里使用了 sigma,但应该使用 sigma ** 2;这可能是问题所在。 (我还进行了编辑,使其更加高效。)如果仍然存在问题,您将需要更详细地描述问题,或者只需使用 sklearn 的 rbf_kernel(它基本上执行相同的操作)。 - Danica
我不太理解 km *= (-sigma**2 / 2) 这句话的意思。难道不应该是 km *= (-1/2*sigma**2) 吗?RBF核函数是 exp(-||x_i - x_j||^2 / 2*sigma^2),对吧? - Seb
1
@Seb 你是对的(除了你提出的修复中括号有误);我想我一开始使用了不同的参数化 exp(- gamma ||x - y||^2),但搞砸了。现在已经修复了。 - Danica

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接