使用sklearn.AgglomerativeClustering绘制树状图谱

46

我尝试使用AgglomerativeClustering提供的children_属性构建一棵树状图,但是目前没有成功。我不能使用scipy.cluster,因为scipy中提供的凝聚聚类缺少对我来说很重要的一些选项(如指定聚类数量的选项)。如果有任何建议,我会非常感激。

    import sklearn.cluster
    clstr = cluster.AgglomerativeClustering(n_clusters=2)
    clusterer.children_

1
请发布一个代码示例以增加获得良好答案的机会。 - Fabian de Pabian
4
这回答解决了你的问题吗?链接 - rawkintrevo
5个回答

23
这是一个简单的函数,用于使用scipy的dendrogram函数绘制来自sklearn的分层聚类模型。似乎sklearn通常不直接支持绘图函数。您可以在此处找到与此plot_dendrogram代码片段相关的拉取请求的有趣讨论here
我想澄清一下,您描述的用例(定义簇数)在scipy中是可用的:在使用scipy的linkage执行层次聚类之后,可以使用fcluster将层次结构划分为任何您想要的簇数,并将簇数指定为t参数和criterion='maxclust'参数。

20

使用scipy实现的凝聚聚类算法。以下是一个示例。

from scipy.cluster.hierarchy import dendrogram, linkage

data = [[0., 0.], [0.1, -0.1], [1., 1.], [1.1, 1.1]]

Z = linkage(data)

dendrogram(Z)  
你可以在这里找到linkage的文档,以及在这里找到dendrogram的文档。

10
这个答案是有用的,因为它指出了利用scipy创建和可视化层次聚类的替代方法,所以我点了赞。但是,它并没有回答原始问题,即如何可视化由scikit-learn创建的聚类的树状图。如果您能添加一个函数,将scikit-learn的输出转换为类似于Z的数据结构,那就太好了。 - conradlee
@conradlee 实际上这就是 plot_dendrogram() 函数在这里所做的事情,除了最后一行:https://scikit-learn.org/stable/auto_examples/cluster/plot_agglomerative_dendrogram.html 而最后一行调用的 dendrogram 函数则是从 scipy.cluster.hierarchy 导入的。 - tozCSS
@tozCSS 感谢你指出这一点。现在得到最高投票的答案确实通过链接到现在包含在scikit-learn文档中的“plot_dendrogram”代码段来回答了这个问题。我很高兴看到文档已经改进了。我现在已经取消了对此处的赞成票。 - conradlee

3
我前段时间遇到了完全相同的问题。我使用软件包ete3,成功绘制了这个可恶的树状图。该软件包能够使用不同选项灵活地绘制树形结构。唯一的难点是将sklearnchildren_输出转换为可以被ete3读取和理解的Newick Tree格式。此外,我需要手动计算树枝的跨度,因为children_没有提供这些信息。以下是我使用的代码片段。它计算Newick树,然后显示ete3树数据结构。有关如何绘制的更多详细信息,请参见这里
import numpy as np
from sklearn.cluster import AgglomerativeClustering
import ete3

def build_Newick_tree(children,n_leaves,X,leaf_labels,spanner):
    """
    build_Newick_tree(children,n_leaves,X,leaf_labels,spanner)

    Get a string representation (Newick tree) from the sklearn
    AgglomerativeClustering.fit output.

    Input:
        children: AgglomerativeClustering.children_
        n_leaves: AgglomerativeClustering.n_leaves_
        X: parameters supplied to AgglomerativeClustering.fit
        leaf_labels: The label of each parameter array in X
        spanner: Callable that computes the dendrite's span

    Output:
        ntree: A str with the Newick tree representation

    """
    return go_down_tree(children,n_leaves,X,leaf_labels,len(children)+n_leaves-1,spanner)[0]+';'

def go_down_tree(children,n_leaves,X,leaf_labels,nodename,spanner):
    """
    go_down_tree(children,n_leaves,X,leaf_labels,nodename,spanner)

    Iterative function that traverses the subtree that descends from
    nodename and returns the Newick representation of the subtree.

    Input:
        children: AgglomerativeClustering.children_
        n_leaves: AgglomerativeClustering.n_leaves_
        X: parameters supplied to AgglomerativeClustering.fit
        leaf_labels: The label of each parameter array in X
        nodename: An int that is the intermediate node name whos
            children are located in children[nodename-n_leaves].
        spanner: Callable that computes the dendrite's span

    Output:
        ntree: A str with the Newick tree representation

    """
    nodeindex = nodename-n_leaves
    if nodename<n_leaves:
        return leaf_labels[nodeindex],np.array([X[nodeindex]])
    else:
        node_children = children[nodeindex]
        branch0,branch0samples = go_down_tree(children,n_leaves,X,leaf_labels,node_children[0])
        branch1,branch1samples = go_down_tree(children,n_leaves,X,leaf_labels,node_children[1])
        node = np.vstack((branch0samples,branch1samples))
        branch0span = spanner(branch0samples)
        branch1span = spanner(branch1samples)
        nodespan = spanner(node)
        branch0distance = nodespan-branch0span
        branch1distance = nodespan-branch1span
        nodename = '({branch0}:{branch0distance},{branch1}:{branch1distance})'.format(branch0=branch0,branch0distance=branch0distance,branch1=branch1,branch1distance=branch1distance)
        return nodename,node

def get_cluster_spanner(aggClusterer):
    """
    spanner = get_cluster_spanner(aggClusterer)

    Input:
        aggClusterer: sklearn.cluster.AgglomerativeClustering instance

    Get a callable that computes a given cluster's span. To compute
    a cluster's span, call spanner(cluster)

    The cluster must be a 2D numpy array, where the axis=0 holds
    separate cluster members and the axis=1 holds the different
    variables.

    """
    if aggClusterer.linkage=='ward':
        if aggClusterer.affinity=='euclidean':
            spanner = lambda x:np.sum((x-aggClusterer.pooling_func(x,axis=0))**2)
    elif aggClusterer.linkage=='complete':
        if aggClusterer.affinity=='euclidean':
            spanner = lambda x:np.max(np.sum((x[:,None,:]-x[None,:,:])**2,axis=2))
        elif aggClusterer.affinity=='l1' or aggClusterer.affinity=='manhattan':
            spanner = lambda x:np.max(np.sum(np.abs(x[:,None,:]-x[None,:,:]),axis=2))
        elif aggClusterer.affinity=='l2':
            spanner = lambda x:np.max(np.sqrt(np.sum((x[:,None,:]-x[None,:,:])**2,axis=2)))
        elif aggClusterer.affinity=='cosine':
            spanner = lambda x:np.max(np.sum((x[:,None,:]*x[None,:,:]))/(np.sqrt(np.sum(x[:,None,:]*x[:,None,:],axis=2,keepdims=True))*np.sqrt(np.sum(x[None,:,:]*x[None,:,:],axis=2,keepdims=True))))
        else:
            raise AttributeError('Unknown affinity attribute value {0}.'.format(aggClusterer.affinity))
    elif aggClusterer.linkage=='average':
        if aggClusterer.affinity=='euclidean':
            spanner = lambda x:np.mean(np.sum((x[:,None,:]-x[None,:,:])**2,axis=2))
        elif aggClusterer.affinity=='l1' or aggClusterer.affinity=='manhattan':
            spanner = lambda x:np.mean(np.sum(np.abs(x[:,None,:]-x[None,:,:]),axis=2))
        elif aggClusterer.affinity=='l2':
            spanner = lambda x:np.mean(np.sqrt(np.sum((x[:,None,:]-x[None,:,:])**2,axis=2)))
        elif aggClusterer.affinity=='cosine':
            spanner = lambda x:np.mean(np.sum((x[:,None,:]*x[None,:,:]))/(np.sqrt(np.sum(x[:,None,:]*x[:,None,:],axis=2,keepdims=True))*np.sqrt(np.sum(x[None,:,:]*x[None,:,:],axis=2,keepdims=True))))
        else:
            raise AttributeError('Unknown affinity attribute value {0}.'.format(aggClusterer.affinity))
    else:
        raise AttributeError('Unknown linkage attribute value {0}.'.format(aggClusterer.linkage))
    return spanner

clusterer = AgglomerativeClustering(n_clusters=2,compute_full_tree=True) # You can set compute_full_tree to 'auto', but I left it this way to get the entire tree plotted
clusterer.fit(X) # X for whatever you want to fit
spanner = get_cluster_spanner(clusterer)
newick_tree = build_Newick_tree(clusterer.children_,clusterer.n_leaves_,X,leaf_labels,spanner) # leaf_labels is a list of labels for each entry in X
tree = ete3.Tree(newick_tree)
tree.show()

1

对于愿意跳出Python并使用强大的D3库的人来说,使用d3.cluster()(或者,我想说,d3.tree())API实现一个漂亮且可定制的结果并不是非常困难。

请参见jsfiddle以获取演示。

children_数组幸运地作为JS数组很容易使用,唯一的中间步骤是使用d3.stratify()将其转换为分层表示。具体来说,我们需要每个节点都有一个id和一个parentId

var N = 272;  // Your n_samples/corpus size.
var root = d3.stratify()
  .id((d,i) => i + N)
  .parentId((d, i) => {
    var parIndex = data.findIndex(e => e.includes(i + N));
    if (parIndex < 0) {
      return; // The root should have an undefined parentId.
    }
    return parIndex + N;
  })(data); // Your children_

由于findIndex行,这里至少会出现O(n^2)的行为,但直到n_samples变得非常大时,这可能并不重要,在这种情况下,您可以预计算更有效的索引。

除此之外,这基本上是使用d3.cluster()的即插即用。请参阅mbostock的canonical block或我的JSFiddle。

N.B. 对于我的用例,仅显示非叶节点就足够了;对于可视化样本/叶子,稍微有些棘手,因为这些可能并不都明确地在children_数组中。


0

来自官方文档

import numpy as np

from matplotlib import pyplot as plt
from scipy.cluster.hierarchy import dendrogram
from sklearn.datasets import load_iris
from sklearn.cluster import AgglomerativeClustering


def plot_dendrogram(model, **kwargs):
    # Create linkage matrix and then plot the dendrogram

    # create the counts of samples under each node
    counts = np.zeros(model.children_.shape[0])
    n_samples = len(model.labels_)
    for i, merge in enumerate(model.children_):
        current_count = 0
        for child_idx in merge:
            if child_idx < n_samples:
                current_count += 1  # leaf node
            else:
                current_count += counts[child_idx - n_samples]
        counts[i] = current_count

    linkage_matrix = np.column_stack([model.children_, model.distances_,
                                      counts]).astype(float)

    # Plot the corresponding dendrogram
    dendrogram(linkage_matrix, **kwargs)


iris = load_iris()
X = iris.data

# setting distance_threshold=0 ensures we compute the full tree.
model = AgglomerativeClustering(distance_threshold=0, n_clusters=None)

model = model.fit(X)
plt.title('Hierarchical Clustering Dendrogram')
# plot the top three levels of the dendrogram
plot_dendrogram(model, truncate_mode='level', p=3)
plt.xlabel("Number of points in node (or index of point if no parenthesis).")
plt.show()

请注意,目前(截至scikit-learn v0.23),只有在调用AgglomerativeClustering时使用distance_threshold参数才会起作用,但是从v0.24开始,您可以通过将compute_distances设置为true来强制计算距离(请参见夜间构建文档)。

数值错误:必须设置n_clusters和distance_threshold中的一个,另一个必须为None。scikit-learn版本为0.24.2。 - rafine

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接