전이 학습 (Transfer Learning)
전이 학습은 신경망을 처음부터 훈련하는 대신 미리 로드된 가중치 세트로 훈련을 시작한다. 일반적으로 미리 훈련된 신경망의 최상위 레이어를 제거하고 새로운 최상위 레이어로 다시 훈련한다. 이전 신경망의 레이어는 훈련으로 인해 가중치가 변경되지 않도록 잠긴다. 새로 추가된 레이어만 학습된다.
대규모 이미지 데이터 세트에 대한 신경망을 훈련하려면 많은 컴퓨팅 성능이 필요할 수 있다. Google, Facebook, Microsoft 및 기타 기술 기업들은 다양한 애플리케이션을 위한 고품질 신경망을 훈련하기 위해 GPU 어레이를 활용하고 있다. 이러한 가중치를 신경망으로 전송하면 상당한 노력과 계산 시간을 절약할 수 있다. 사전 학습된 모델이 구현하려는 애플리케이션에 정확히 맞을 가능성은 거의 없다. 가장 가까운 사전 학습 모델을 찾고 전이 학습을 사용하는 것은 딥러닝 엔지니어에게 필수적이다.
전이 학습을 사용하여 이미지넷 신경망을 구축하는 간단한 예에서 네트워크는 네 가지 측정값을 가지고 각 관측값을 세 가지 종으로 분류한다. 하지만 나중에 네 가지 측정값과 비용을 대상으로 포함하는 데이터 집합을 받으면 이 데이터 세트는 종을 포함하지 않으므로 방금 학습한 기본 모델과 동일한 4개의 입력을 사용한다.
import pandas as pd
import io
import requests
import numpy as np
from sklearn import metrics
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense, Activation
from tensorflow.keras.callbacks import EarlyStopping
df = pd.read_csv("https://data.heatonresearch.com/data/t81-558/iris.csv", na_values=['NA', '?'])
# Convert to numpy - Classification
x = df[['sepal_l', 'sepal_w', 'petal_l', 'petal_w']].values
dummies = pd.get_dummies(df['species']) # Classification
species = dummies.columns
y = dummies.values
# Build neural network
model = Sequential()
model.add(Dense(50, input_dim=x.shape[1], activation='relu')) # Hidden 1
model.add(Dense(25, activation='relu')) # Hidden 2
model.add(Dense(y.shape[1], activation='softmax')) # Output
model.compile(loss='categorical_crossentropy', optimizer='adam')
model.fit(x, y, verbose=2, epochs=100)
이전에 훈련된 네트워크를 사용하여 가중치를 전이 학습을 통해 비용을 예측하는 방법을 학습할 새로운 신경망으로 옮길 수 있다. 또한, 주목할 점은 원래 신경망은 분류 네트워크였지만 이제는 회귀 신경망을 구축하는 데 사용한다는 것다. 이러한 변형은 전이 학습에서 흔히 볼 수 있다.
그 다음, 학습 세트에서 네트워크의 정확도를 평가한다.
from sklearn.metrics import accuracy_score
pred = model.predict(x)
predict_classes = np.argmax(pred, axis=1)
expected_classes = np.argmax(y, axis=1)
correct = accuracy_score(expected_classes, predict_classes)
print(f"Training Accuracy: {correct}")
5/5 [==============================] - 0s 3ms/step
Training Accuracy: 0.98
모델 요약을 보면 예상대로 이전에 정의한 세 개의 레이어를 볼 수 있다.
model.summary()
Model: "sequential"
_________________________________________________________________
Layer (type) Output Shape Param #
=================================================================
dense (Dense) (None, 50) 250
dense_1 (Dense) (None, 25) 1275
dense_2 (Dense) (None, 3) 78
=================================================================
Total params: 1603 (6.26 KB)
Trainable params: 1603 (6.26 KB)
Non-trainable params: 0 (0.00 Byte)
_________________________________________________________________
새 네트워크 만들기
데이터 세트에 대해 신경망을 학습시켰으므로 이 신경망의 지식을 다른 신경망으로 전송할 수 있다. 이 신경망의 일부 또는 전체 레이어에서 새로운 신경망을 만들 수 있다. 이 기술을 시연하기 위해 첫 번째 신경망의 복제품인 새로운 신경망을 만든다. 이제 원래 신경망의 모든 레이어를 새 신경망으로 옮긴다.
model2 = Sequential()
for layer in model.layers:
model2.add(layer)
model2.summary()
Model: "sequential_1"
_________________________________________________________________
Layer (type) Output Shape Param #
=================================================================
dense (Dense) (None, 50) 250
dense_1 (Dense) (None, 25) 1275
dense_2 (Dense) (None, 3) 78
=================================================================
Total params: 1603 (6.26 KB)
Trainable params: 1603 (6.26 KB)
Non-trainable params: 0 (0.00 Byte)
_________________________________________________________________
sanity 검사로 새로 생성된 모델의 정확도를 계산하고자 한다. 샘플 내 정확도는 새 모델이 전송한 이전 모델과 동일해야 한다.
from sklearn.metrics import accuracy_score
pred = model2.predict(x)
predict_classes = np.argmax(pred, axis=1)
expected_classes = np.argmax(y, axis=1)
correct = accuracy_score(expected_classes, predict_classes)
print(f"Training Accuracy: {correct}")
5/5 [==============================] - 0s 4ms/step
Training Accuracy: 0.98
새로 생성된 신경망의 샘플 내 정확도는 첫 번째 신경망과 동일하다. 원래 신경망에서 모든 레이어를 성공적으로 전송했다.
'DNN with Keras > Transfer Learning' 카테고리의 다른 글
조기 중지의 이점 (0) | 2024.02.13 |
---|---|
Transfer Learning for NLP with Keras (0) | 2024.02.13 |
네트워크 생성 및 가중치 전송 (0) | 2024.02.13 |
Keras Transfer Learning for Computer Vision (0) | 2024.02.13 |
회귀 네트워크 전송 (0) | 2024.02.13 |