Machine Learning - Grid Search
Grid search is a hyperparameter tuning technique in machine learning that finds the best combination of hyperparameters for a given model. You define a grid of candidate values for each hyperparameter, and the model is then trained and evaluated on every combination in that grid to identify the best-performing set.
In other words, grid search is an exhaustive search: a finite set of candidate values is specified for each hyperparameter, and every possible combination of those values is evaluated. The result is the best combination within the grid, which is not necessarily the best possible setting overall. Because the number of combinations is the product of the number of candidate values per hyperparameter, the cost grows quickly as hyperparameters are added.
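To make "all possible combinations" concrete, the following minimal sketch enumerates a grid by hand using only the standard library. The grid values are illustrative assumptions (two hyperparameters with three candidate values each); no model is trained here.

```python
from itertools import product

# Illustrative grid (assumed values): two hyperparameters, three candidates each
hyperparameters = {'n_estimators': [10, 50, 100], 'max_depth': [None, 5, 10]}

# Build one dictionary per combination by taking the Cartesian product of the value lists
names = list(hyperparameters)
combinations = [dict(zip(names, values)) for values in product(*hyperparameters.values())]

for combo in combinations:
    print(combo)

# 3 candidates x 3 candidates = 9 combinations
print("Total combinations:", len(combinations))
```

Adding a third hyperparameter with four candidate values would raise the total to 3 × 3 × 4 = 36, which is why grids are usually kept small.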
Implementation in Python
In Python, grid search can be implemented with the GridSearchCV class from the sklearn.model_selection module of scikit-learn. GridSearchCV takes the model (estimator), a dictionary mapping each hyperparameter name to the list of values to try, and optionally a scoring metric and a cross-validation setting. It evaluates every combination in the grid using cross-validation and exposes the combination with the best mean score through the best_params_ attribute.
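The procedure described above can be approximated by hand. The sketch below is not scikit-learn's internal implementation. It only illustrates the idea with ParameterGrid and cross_val_score. The dataset size, the grid values, and random_state=0 are assumptions chosen to keep the run short and repeatable.

```python
from sklearn.datasets import make_classification
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import ParameterGrid, cross_val_score

# Small synthetic dataset (assumed size and seed, for a quick illustration)
X, y = make_classification(n_samples=300, n_features=10, n_classes=2, random_state=0)

# Illustrative grid: 2 x 2 = 4 combinations
param_grid = {'n_estimators': [10, 50], 'max_depth': [None, 5]}

best_score, best_params = -1.0, None
for params in ParameterGrid(param_grid):
    # Train and evaluate one candidate with 5-fold cross-validation
    model = RandomForestClassifier(random_state=0, **params)
    scores = cross_val_score(model, X, y, scoring='accuracy', cv=5)
    mean_score = scores.mean()
    # Keep the combination with the highest mean accuracy seen so far
    if mean_score > best_score:
        best_score, best_params = mean_score, params

print("Best hyperparameters:", best_params)
print("Best score:", best_score)
```

GridSearchCV wraps this loop in a single object and, by default (refit=True), also retrains the model on the full dataset with the winning combination.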
Here is an example implementation of grid search in Python using the GridSearchCV class:
Example
from sklearn.model_selection import GridSearchCV
from sklearn.ensemble import RandomForestClassifier
from sklearn.datasets import make_classification

# Generate a sample binary classification dataset
X, y = make_classification(n_samples=1000, n_features=10, n_classes=2)

# Define the model and the grid of hyperparameter values to try
model = RandomForestClassifier()
hyperparameters = {
    'n_estimators': [10, 50, 100],
    'max_depth': [None, 5, 10],
}

# Define the grid search object (5-fold cross-validation, accuracy metric) and fit the data
grid_search = GridSearchCV(model, hyperparameters, scoring='accuracy', cv=5)
grid_search.fit(X, y)

# Print the best hyperparameters and the corresponding mean cross-validated score
print("Best hyperparameters:", grid_search.best_params_)
print("Best score:", grid_search.best_score_)
In this example, we define a RandomForestClassifier model and a grid of hyperparameters to tune: the number of trees (n_estimators) and the maximum depth of each tree (max_depth). With three candidate values each, the grid contains 3 × 3 = 9 combinations, and with cv=5 each one is evaluated by 5-fold cross-validation. We then create a GridSearchCV object and run the search with the fit() method. Finally, we print the best set of hyperparameters (best_params_) and the corresponding mean cross-validated accuracy (best_score_).
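Running this code produces output similar to the following. The exact values vary between runs because neither make_classification nor RandomForestClassifier is given a fixed random_state: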
Best hyperparameters: {'max_depth': None, 'n_estimators': 10}
Best score: 0.953
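Besides best_params_ and best_score_, a fitted GridSearchCV object keeps the scores for every combination in its cv_results_ attribute. The sketch below ranks all candidates by mean score. The smaller dataset, the reduced grid, and random_state=0 are assumptions made to keep the run fast and repeatable, so its numbers will not match the output above.

```python
from sklearn.datasets import make_classification
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import GridSearchCV

# Assumed small dataset and fixed seeds so that repeated runs give the same ranking
X, y = make_classification(n_samples=300, n_features=10, n_classes=2, random_state=0)
hyperparameters = {'n_estimators': [10, 50], 'max_depth': [None, 5]}

grid_search = GridSearchCV(
    RandomForestClassifier(random_state=0),
    hyperparameters,
    scoring='accuracy',
    cv=5,
)
grid_search.fit(X, y)

# cv_results_ is a dict of arrays with one entry per hyperparameter combination
results = grid_search.cv_results_
ranked = sorted(
    zip(results['rank_test_score'], results['mean_test_score'],
        results['std_test_score'], results['params']),
    key=lambda row: row[0],
)

# Print every combination, best first, with mean and standard deviation across folds
for rank, mean, std, params in ranked:
    print(f"#{rank}  mean={mean:.3f}  std={std:.3f}  {params}")
```

Looking at the full ranking, rather than only the winner, shows whether the best combination is clearly ahead or whether several settings perform about equally well.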