Machine Learning 简明教程
Machine Learning - K-Nearest Neighbors (KNN)
KNN 是一种有监督学习算法,可用于分类和回归问题。KNN 背后的主要思想是找到给定测试数据点的 k 个最近数据点,并使用这些最近邻域来进行预测。k 的值是一个需要调整的超参数,它表示要考虑的邻域数。
对于分类问题,KNN 算法会将测试数据点分配给在 k 个最近邻域中出现频率最高的类别。换句话说,邻域最多的类别就是预测类别。
对于回归问题,KNN 算法会将测试数据点分配给 k 个最近邻域的值的平均值。
用于度量两个数据点之间相似性的距离指标是一个影响 KNN 算法性能的重要因素。最常用的距离指标是欧几里得距离、曼哈顿距离和闵可夫斯基距离。
Working of KNN Algorithm
KNN 算法可以概括为以下步骤:
-
Load the data - 第一步是将数据集加载到内存中。这可以通过各种库(如 pandas 或 numpy)来完成。
-
Split the data - 下一步是将数据划分为训练集和测试集。训练集用于训练 KNN 算法,而测试集用于评估其性能。
-
Normalize the data - 在训练 KNN 算法之前,必须对数据进行标准化,以确保每个特征对距离指标计算的贡献相同。
-
Calculate distances - 一旦对数据进行标准化,KNN 算法就会计算测试数据点与训练集中每个数据点之间的距离。
-
Select k-nearest neighbors - KNN 算法根据前一步中计算的距离选择 k 个最近邻域。
-
Make a prediction - 对于分类问题,KNN 算法会将测试数据点分配给在 k 个最近邻域中出现频率最高的类别。对于回归问题,KNN 算法会将测试数据点分配给 k 个最近邻域的值的平均值。
-
Evaluate performance - 最后,使用各种指标(例如准确率、精确率、召回率和 F1 值)评估 KNN 算法的性能。
Implementation in Python
现在我们已经讨论了 KNN 算法的理论,让我们使用 scikit-learn 在 Python 中实现它。Scikit-learn 是一个流行的 Python 机器学习库,它提供了用于分类和回归问题的各种算法。
我们将使用鸢尾花卉数据集,这是一个流行的机器学习数据集,其中包含有关三种不同鸢尾花卉物种的信息。该数据集有四个特征,包括萼片长度、萼片宽度、花瓣长度和花瓣宽度,以及一个目标变量,即花卉种类。
要在 Python 中实现 KNN,我们需要遵循前面提到的步骤。以下是在鸢尾花卉数据集上实现 KNN 的 Python 代码:
Example
# import libraries
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from sklearn.neighbors import KNeighborsClassifier
from sklearn.metrics import accuracy_score
# load the Iris dataset
iris = load_iris()
#split the data into training and test sets
X_train, X_test, y_train, y_test = train_test_split(iris.data,
iris.target, test_size=0.35, random_state=42)
#normalize the data
scaler = StandardScaler()
X_train = scaler.fit_transform(X_train)
X_test = scaler.transform(X_test)
#initialize the KNN algorithm
knn = KNeighborsClassifier(n_neighbors=5)
#train the KNN algorithm
knn.fit(X_train, y_train)
#make predictions on the test set
y_pred = knn.predict(X_test)
#evaluate the performance of the KNN algorithm
accuracy = accuracy_score(y_test, y_pred)
print("Accuracy: {:.2f}%".format(accuracy*100))
执行此代码时,将生成以下输出 −
Accuracy: 98.11%