Machine Learning 简明教程
Machine Learning - Agglomerative Clustering
凝聚层次聚类是一种分层聚类算法,它以每个数据点作为其自己的簇开始,并迭代地合并最接近的簇,直到达到停止标准。它是一种自下而上的方法,生成一个树状图,它是一个树状图,显示了簇之间的层次关系。该算法可以使用 Python 中的 scikit-learn 库实现。
Implementation in Python
我们将使用鸢尾花数据集进行演示。第一步是导入必要的库并加载数据集。
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from sklearn.datasets import load_iris
from sklearn.cluster import AgglomerativeClustering
from scipy.cluster.hierarchy import dendrogram, linkage
iris = load_iris()
X = iris.data
y = iris.target
下一步是创建包含每对簇之间距离的连结矩阵。我们可以使用 scipy.cluster.hierarchy 模块中的 linkage 函数创建连结矩阵。
Z = linkage(X, 'ward')
“ward”方法用于计算簇之间的距离。它最小化了正在合并的簇之间的距离的方差。
我们可以使用同一模块中的 dendrogram 函数可视化树状图。
plt.figure(figsize=(7.5, 3.5))
plt.title("Iris Dendrogram")
dendrogram(Z)
plt.show()
生成树状图(参见以下绘图)显示了簇之间的层次关系。我们可以看到,该算法首先合并了最接近的簇,并且当我们向上移动树时,簇之间的距离会增加。
最后一步是应用聚类算法并提取簇标签。我们可以使用 sklearn.cluster 模块中的 AgglomerativeClustering 类来应用算法。
model = AgglomerativeClustering(n_clusters=3)
model.fit(X)
labels = model.labels_
n_clusters 参数指定从数据中提取的簇的数量。在本例中,我们指定 n_clusters=3,因为我们知道鸢尾花数据集有三个类别。
我们可以使用散点图来可视化生成簇。
plt.figure(figsize=(7.5, 3.5))
plt.scatter(X[:, 0], X[:, 1], c=labels)
plt.xlabel("Sepal length")
plt.ylabel("Sepal width")
plt.title("Agglomerative Clustering Results")
plt.show()
生成的图显示了算法识别的三个簇。我们可以看到,算法已成功地将数据点分离到它们各自的类中。
Example
以下是 Agglomerative Clustering 在 Python 中的完整实现:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from sklearn.datasets import load_iris
from sklearn.cluster import AgglomerativeClustering
from scipy.cluster.hierarchy import dendrogram, linkage
# Load the Iris dataset
iris = load_iris()
X = iris.data
y = iris.target
Z = linkage(X, 'ward')
# Plot the dendogram
plt.figure(figsize=(7.5, 3.5))
plt.title("Iris Dendrogram")
dendrogram(Z)
plt.show()
# create an instance of the AgglomerativeClustering class
model = AgglomerativeClustering(n_clusters=3)
# fit the model to the dataset
model.fit(X)
labels = model.labels_
# Plot the results
plt.figure(figsize=(7.5, 3.5))
plt.scatter(X[:, 0], X[:, 1], c=labels)
plt.xlabel("Sepal length")
plt.ylabel("Sepal width")
plt.title("Agglomerative Clustering Results")
plt.show()