在TensorFlow中,Kmeans聚类是如何工作的?

3
我看到tensorflow的contrib库中有Kmeans聚类的实现。然而,我无法简单地估计2D点的簇中心。
代码:
## Generate synthetic data
N,D = 1000, 2 # number of points and dimenstinality

means = np.array([[0.5, 0.0],
                  [0, 0],
                  [-0.5, -0.5],
                  [-0.8, 0.3]])
covs = np.array([np.diag([0.01, 0.01]),
                 np.diag([0.01, 0.01]),
                 np.diag([0.01, 0.01]),
                 np.diag([0.01, 0.01])])
n_clusters = means.shape[0]

points = []
for i in range(n_clusters):
    x = np.random.multivariate_normal(means[i], covs[i], N )
    points.append(x)
points = np.concatenate(points)

## construct model
kmeans = tf.contrib.learn.KMeansClustering(num_clusters = n_clusters)
kmeans.fit(points.astype(np.float32))

我收到了以下错误信息:
InvalidArgumentError (see above for traceback): Shape [-1,2] has negative dimensions
     [[Node: input = Placeholder[dtype=DT_FLOAT, shape=[?,2], _device="/job:localhost/replica:0/task:0/cpu:0"]()]]

我猜自己做错了一些事情,但从文件中没找出来。

编辑:

我使用input_fn解决了这个问题,但是速度非常慢(我不得不将每个集群中的点数减少到10才能看到结果)。 为什么会这样,我该如何使它更快?

 def input_fn():
    return tf.constant(points, dtype=tf.float32), None

## construct model
kmeans = tf.contrib.learn.KMeansClustering(num_clusters = n_clusters, relative_tolerance=0.0001)
kmeans.fit(input_fn=input_fn)
centers = kmeans.clusters()
print(centers)

已解决:

看来需要设置相对容忍度。我只改了一行代码,现在它可以正常工作了。 kmeans = tf.contrib.learn.KMeansClustering(num_clusters = n_clusters, relative_tolerance=0.0001)


你正在运行哪个版本的TF? - Dan Salo
1个回答

0

你的原始代码在Tensorflow 1.2中返回以下错误:

    WARNING:tensorflow:From <stdin>:1: calling BaseEstimator.fit (from         
    tensorflow.contrib.learn.python.learn.estimators.estimator) with x 
    is deprecated and will be removed after 2016-12-01.
    Instructions for updating:
    Estimator is decoupled from Scikit Learn interface by moving into
    separate class SKCompat. Arguments x, y and batch_size are only
    available in the SKCompat class, Estimator will only accept input_fn.

根据您的编辑,似乎您已经发现input_fn是唯一可接受的输入。如果您真的想使用TF,我建议升级到r1.2并将Estimator包装在SKCompat类中,就像错误消息所建议的那样。否则,我会使用SKLearn包。您也可以手动在TF中实现自己的聚类算法,如this blog中所示。

谢谢。我解决了。不过还有一个问题 - 如果我的点在一个tf变量中,那么它会起作用吗?还是我需要做一些不同的事情?(比如在输入到kmeans聚类之前进行评估) - itzik Ben Shabat
估算器(不带包装器)不接受TF张量作为输入,因此占位符和变量被排除在外。因此,在输入之前进行评估应该是可行的! - Dan Salo

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接