본문 바로가기
Learning-driven Methodology/ML (Machine Learning)

[Machine Learning] 그리드 탐색 (GridSearchCV)

by goatlab 2023. 7. 10.
728x90
반응형
SMALL

그리드 탐색 (GridSearchCV)

 

머신 러닝에서 하이퍼파라미터란 간단하게 말해 사용자의 입력값 또는 설정 가능한 입력값이라고 이해할 수 있다. 사용할 데이터에 따라 가장 적합한 모델과 모델의 하이퍼파라미터값이 다르다.

 

sklearn의 모듈 GridSearchCV는 머신 러닝 알고리즘에 사용되는 하이퍼 파라미터를 입력해 학습하고 검증하면서 가장 좋은 파라미터를 알려준다. 따라서, 학습하려는 하이퍼파라미터와 값 범위를 지정하기만 하면 GridSearchCV는 교차 검증을 사용하여 하이퍼파라미터 값의 가능한 모든 조합을 수행한다.

 

매개 변수

 

 

estimator 모델 객체 지정
param_grid 하이퍼파라미터 목록을 dictionary로 전달
scoring 평가 지표
cv 교차 검증시 fold 개수
n_jobs 사용할 CPU 코어 개수 (1 : 기본값, -1 : 모든 코어 사용)

 

메서드

 

fit(X, y) 학습
predict(X) 베스트 모델 예측
predict_proba(X) 베스트 모델호출

 

결과 조회 변수

 

cv_results_ 파라미터 조합별 결과 조회
best_params_ 베스트 parameter 조합 조회
best_estimator_ 베스트 모델 반환

 

예제

 

import sklearn
import numpy as np
import pandas as pd
from sklearn.datasets import load_iris
from sklearn.model_selection import GridSearchCV, train_test_split
from sklearn.tree import DecisionTreeClassifier
from sklearn.metrics import accuracy_score

iris = load_iris()
label = iris.target
data = iris.data

X_train, X_test, y_train, y_test = train_test_split(iris.data, iris.target, test_size=0.3, random_state=1)

dt_clf = DecisionTreeClassifier(random_state=1)
dt_clf.fit(X_train, y_train)
pred = dt_clf.predict(X_test)
accuracy = accuracy_score(y_test, pred)

print('예측 정확도 : {0:.4f}'.format(accuracy))
예측 정확도 : 0.9556
print('DecisionTreeClassifier 하이퍼파라미터:\n', dt_clf.get_params())
DecisionTreeClassifier 하이퍼 파라미터:
 {'ccp_alpha': 0.0, 'class_weight': None, 'criterion': 'gini', 'max_depth': None, 'max_features': None, 'max_leaf_nodes': None, 'min_impurity_decrease': 0.0, 'min_samples_leaf': 1, 'min_samples_split': 2, 'min_weight_fraction_leaf': 0.0, 'random_state': 1, 'splitter': 'best'}
param_grid = {
    'criterion':['gini','entropy'], 
    'max_depth':[None,2,3,4,5,6], 
    'max_leaf_nodes':[None,2,3,4,5,6,7], 
    'min_samples_split':[2,3,4,5,6], 
    'min_samples_leaf':[1,2,3], 
    'max_features':[None,'sqrt','log2',3,4,5]
}
grid_search = GridSearchCV(dt_clf, param_grid = param_grid, cv = 5, scoring = 'accuracy', refit=True)
grid_search.fit(X_train, y_train)

print('best parameters : ', grid_search.best_params_)
print('best score : ', round(grid_search.best_score_, 4))
best parameters :  {'criterion': 'gini', 'max_depth': None, 'max_features': None, 'max_leaf_nodes': None, 'min_samples_leaf': 1, 'min_samples_split': 2}
best score :  0.9524
df = pd.DataFrame(grid_search.cv_results_)
df
estimator = grid_search.best_estimator_

pred = estimator.predict(X_test)
print('score: ', round(accuracy_score(y_test,pred), 4))
728x90
반응형
LIST