Machine Learning 简明教程

网格搜索是机器学习中的一种超参数调整技术,有助于找到给定模型的超参数的最佳组合。它的工作原理是定义一个超参数网格,然后使用超参数的所有可能的组合训练模型,以找到性能最佳的集合。

换句话说,网格搜索是一种穷举搜索方法,其中定义一组超参数,并在所有可能的超参数组合上执行搜索,以找到提供最佳性能的最优值。

Implementation in Python

在 Python 中,可以使用 sklearn 模块中的 GridSearchCV 类实现网格搜索。GridSearchCV 类接收模型、要调整的超参数和得分函数作为输入。然后,它在所有可能的超参数组合上执行穷举搜索,并返回提供最佳得分的最优超参数集。

以下是在 Python 中使用 GridSearchCV 类实现网格搜索的一个示例 −

Example

from sklearn.model_selection import GridSearchCV
from sklearn.ensemble import RandomForestClassifier
from sklearn.datasets import make_classification

# Generate a sample dataset
X, y = make_classification(n_samples=1000, n_features=10, n_classes=2)

# Define the model and the hyperparameters to tune
model = RandomForestClassifier()
hyperparameters = {'n_estimators': [10, 50, 100], 'max_depth': [None, 5, 10]}

# Define the Grid Search object and fit the data
grid_search = GridSearchCV(model, hyperparameters, scoring='accuracy', cv=5)
grid_search.fit(X, y)

# Print the best hyperparameters and the corresponding score
print("Best hyperparameters: ", grid_search.best_params_)
print("Best score: ", grid_search.best_score_)

在这个示例中,我们定义了一个 RandomForestClassifier 模型和一组要调整的超参数,即树的数量 (n_estimators) 和每个树的最大深度 (max_depth)。然后,我们创建一个 GridSearchCV 对象,并使用 fit() 方法拟合数据。最后,我们打印最优超参数集和对应的得分。

执行此代码时,将生成以下输出 −

Best hyperparameters: {'max_depth': None, 'n_estimators': 10}
Best score: 0.953