본문 바로가기
AI-driven Methodology/Artificial Intelligence

[AI] 콜백 (Callback)

by goatlab 2022. 7. 31.
728x90
반응형
SMALL

콜백 (Callback)

 

https://towardsdatascience.com/a-practical-introduction-to-keras-callbacks-in-tensorflow-2-705d0c584966

 

TensorFlow 콜백(callback)은 모델의 학습 방향, 저장 시점, 학습 정지 시점 등에 관한 상황을 모니터링 하기 위해 주로 사용된다.

 

  • 모델이 학습을 시작하면 학습이 완료될 때까지 수행할 것이 없다. 따라서, 이를 해결하고자 존재하는 것이 콜백 함수이다.
  • 예를 들어, 학습 도중에 학습율 (learning rate)을 변화시키거나 val_loss가 개선되지 않으면 학습 도중에 학습을 멈추게 하는 등의 작업을 할 수 있다.
  • TensorFlow에서 사용되는 대표적인 콜백 함수는 ReduceLROnPlateau, ModelCheckpoint, EarlyStopping 등이 있다.

 

ReduceLROnPlateau

 

모델의 성능 개선이 없을 경우, 학습율 (Learning Rate)를 조절해 모델의 개선을 유도하는 콜 백함수. factor 파라미터를 통해서 학습율을 조정한다 (factor < 1.0).

 

from tensorflow.keras.callbacks import ReduceLROnPlateau

reduceLR = ReduceLROnPlateau(monitor = 'val_loss', # val_loss 기준으로 callback 호출
                             factor = 0.5, # callback 호출시 학습률을 1/2로 줄임
                             patience = 5, # epoch 5동안 개선되지 않으면 callback 호출
                             verbose = 1) # 로그 출력

hist = model.fit(x_train, t_train, 
                 epochs = 50, validation_split = 0.2,
                 callbacks = [reduceLR])

 

ModelCheckpoint

 

모델이 학습하면서 정의한 조건을 만족했을 때 Model의 weight 값을 중간 저장한다. 학습시간이 오래 걸린다면, 모델이 개선된 validation score를 도출해낼 때마다 weight를 중간 저장한다. 도중에 memory overflow나 crash가 나더라도 다시 weight를 불러와서 학습을 이어나갈 수 있기 때문에, 시간을 save할 수 있다.

 

from tensorflow.keras.callbacks import ModelCheckpoint

file_path = './modelchpoint_test.h5' # 파일 저장 경로

checkpoint = ModelCheckpoint(file_path,
                             monitor = 'val_loss', # val_loss 값이 개선되었을 때 호출
                             verbose = 1, # log 출력
                             save_best_only = True, # best 값만 저장
                             mode = 'auto') # 자동으로 best를 찾음

hist = model.fit(x_train, t_train, 
                 epochs = 50, validation_split = 0.2,
                 callbacks = [checkpoint])

 

EarlyStopping

 

모델 성능 지표가 설정한 epoch동안 개선되지 않을 때 조기 종료할 수 있다. 일반적으로 EarlyStopping과 ModelCheckpoint 조합하여, 개선되지 않는 학습에 대한 조기 종료를 실행한다. ModelCheckpoint로부터 가장 best model을 다시 로드하여 학습을 재게할 수 있다.

 

from tensorflow.keras.callbacks import ModelCheckpoint, EarlyStopping

file_path = './modelchpoint_test.h5' # 파일 저장 경로

checkpoint = ModelCheckpoint(file_path,
                             monitor = 'val_loss', # val_loss 값이 개선되었을 때 호출
                             verbose = 1, # log 출력
                             save_best_only = True, # best 값만 저장
                             mode = 'auto') # 자동으로 best를 찾음

stopping = EarlyStopping(monitor = 'val_loss', # val_loss를 관찰
                         patience = 5) # 5 epoch동안 개선되지 않으면 조기종료

hist = model.fit(x_train, t_train, 
                 epochs = 50, validation_split = 0.2,
                 callbacks = [checkpoint])
728x90
반응형
LIST