728x90
반응형
SMALL
라이브러리 설치
ONNX 포맷으로 모델을 저장하고 Fire 기반 CLI를 구성하기 위해 다음 패키지를 설치한다. Fire 라이브러리를 활용하면 task 별로 필요한 인자를 설정하여 CLI 기반 프로그램을 쉽고 빠르게 만들 수 있다. task를 분리하면 다양한 장점이 있다 (필요한 태스크만 수행, 트러블슈팅 및 디버깅 용이, 유연한 자원 할당, 유지보수성 향상, 워크플로우 관리 등).
pip install onnx onnxruntime fire
src/model/movie_predictor.py
torch(pth) 또는 ONNX 포맷으로 모델을 저장하기 위해 아래 코드를 추가한다.
import os
import datetime
import torch
from src.utils.utils import model_dir
import torch.nn as nn
class MoviePredictor(nn.Module):
    """Three-layer MLP classifier: input_dim -> 64 -> 32 -> num_classes.

    Each hidden layer is followed by ReLU and dropout (p=0.2). The last
    layer has no activation, so ``forward`` returns raw logits (main.py
    trains it with ``nn.CrossEntropyLoss``).
    """

    # Used by model_save to choose the save directory.
    name = "movie_predictor"

    def __init__(self, input_dim, num_classes):
        super().__init__()
        # Stored so ONNX export can build a dummy input of the right width.
        self.input_dim = input_dim
        # Registration order is kept stable: it fixes state_dict keys and the
        # order in which seeded random initialisation is consumed.
        self.layer1 = nn.Linear(input_dim, 64)
        self.layer2 = nn.Linear(64, 32)
        self.layer3 = nn.Linear(32, num_classes)
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(0.2)

    def forward(self, x):
        """Map a batch of feature vectors to unnormalised class logits."""
        for hidden in (self.layer1, self.layer2):
            x = self.dropout(self.relu(hidden(x)))
        return self.layer3(x)
def model_save(model, model_params, epoch, optimizer, loss, scaler, contents_id_map, ext="pth"):
    """Save a trained model under ``model_dir(model.name)``.

    The file is named ``E{epoch}_T{yymmddHHMMSS}.{ext}``.

    Args:
        model: Trained module. It must expose ``name``, which selects the save
            directory. For ONNX export it must also expose ``input_dim``.
        model_params: Constructor kwargs stored in the checkpoint (pth only).
        epoch: Epoch number recorded in the file name (and checkpoint for pth).
        optimizer: Optimizer whose ``state_dict()`` is stored (pth only).
        loss: Loss value stored in the checkpoint (pth only).
        scaler: Feature scaler object stored in the checkpoint (pth only).
        contents_id_map: Mapping object stored in the checkpoint (pth only).
        ext: ``"pth"`` writes a ``torch.save`` checkpoint dict. ``"onnx"``
            exports only the model graph and weights. The ONNX graph has a
            single input named ``"input"`` and output named ``"output"``, and
            both have a dynamic batch dimension.

    Returns:
        Path of the file that was written.

    Raises:
        ValueError: If ``ext`` is not ``"pth"`` or ``"onnx"``. The check
            runs before anything is created on disk.
    """
    # Fail fast so an invalid extension does not leave an empty directory behind.
    if ext not in ("pth", "onnx"):
        raise ValueError(f"Invalid model export extension : {ext}")

    save_dir = model_dir(model.name)
    os.makedirs(save_dir, exist_ok=True)
    current_time = datetime.datetime.now().strftime("%y%m%d%H%M%S")
    dst = os.path.join(save_dir, f"E{epoch}_T{current_time}.{ext}")

    if ext == "pth":
        torch.save({
            "epoch": epoch,
            "model_params": model_params,
            "model_state_dict": model.state_dict(),
            "optimizer_state_dict": optimizer.state_dict(),
            "loss": loss,
            "scaler": scaler,
            "contents_id_map": contents_id_map,
        }, dst)
    else:
        # A single sample is enough for tracing. The batch dimension is then
        # marked dynamic below, because otherwise the exported graph only
        # accepts batch size 1.
        dummy_input = torch.randn(1, model.input_dim)
        torch.onnx.export(
            model,
            dummy_input,
            dst,
            export_params=True,
            input_names=["input"],
            output_names=["output"],
            dynamic_axes={"input": {0: "batch_size"}, "output": {0: "batch_size"}},
        )
    return dst
src/main.py
import os
import sys
sys.path.append(
os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
)
import fire
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from tqdm import tqdm
from src.dataset.watch_log import get_datasets
from src.model.movie_predictor import MoviePredictor
from src.utils.utils import init_seed
from src.train.train import train
from src.evaluate.evaluate import evaluate
from src.model.movie_predictor import MoviePredictor, model_save
from utils.constant import Optimizers, Models
# Runs at import time, i.e. before any dataset, model or optimizer is built.
# NOTE(review): init_seed presumably fixes RNG seeds for reproducible runs;
# confirm in src/utils/utils.py.
init_seed()
def run_train(model_name, optimizer, num_epochs=10, lr=0.001, model_ext="pth"):
    """Train a model, save it, and report the loss on the test split.

    Exposed as the ``train`` Fire sub-command, e.g.::

        python src/main.py train --model_name movie_predictor --optimizer adam --num_epochs 20 --lr 0.002

    Args:
        model_name: Checked by ``Models.validation`` and then looked up in
            ``Models`` by its upper-cased name.
        optimizer: Optimizer name. It is checked by ``Optimizers.validation``
            and then looked up in ``Optimizers`` by its upper-cased name.
        num_epochs: Number of training epochs. The value passed in is the
            one used.
        lr: Learning rate passed to the optimizer constructor.
        model_ext: Save format forwarded to ``model_save`` (``"pth"`` or
            ``"onnx"``).
    """
    Models.validation(model_name)
    Optimizers.validation(optimizer)

    # Datasets and DataLoaders; only the training split is shuffled.
    train_dataset, val_dataset, test_dataset = get_datasets()
    loader_kwargs = {"batch_size": 64, "num_workers": 0, "pin_memory": False}
    train_loader = DataLoader(train_dataset, shuffle=True, **loader_kwargs)
    val_loader = DataLoader(val_dataset, shuffle=False, **loader_kwargs)
    test_loader = DataLoader(test_dataset, shuffle=False, **loader_kwargs)

    # Model initialisation. model_params is also handed to model_save, which
    # stores it in pth checkpoints.
    model_params = {
        "input_dim": train_dataset.features_dim,
        "num_classes": train_dataset.num_classes,
    }
    model_class = Models[model_name.upper()].value
    model = model_class(**model_params)

    # Loss and optimizer. The instance gets its own name so the CLI string
    # argument `optimizer` is not shadowed.
    criterion = nn.CrossEntropyLoss()
    optimizer_class = Optimizers[optimizer.upper()].value
    optimizer_instance = optimizer_class(model.parameters(), lr=lr)

    # Training loop. train_loss is pre-set so saving still works when
    # num_epochs == 0.
    train_loss = 0
    for epoch in tqdm(range(num_epochs)):
        train_loss = train(model, train_loader, criterion, optimizer_instance)
        val_loss, _ = evaluate(model, val_loader, criterion)
        print(f"Epoch {epoch + 1}/{num_epochs}, "
              f"Train Loss: {train_loss:.4f}, "
              f"Val Loss: {val_loss:.4f}, "
              f"Val-Train Loss : {val_loss-train_loss:.4f}")

    model_save(
        model=model,
        model_params=model_params,
        epoch=num_epochs,
        optimizer=optimizer_instance,
        loss=train_loss,
        scaler=train_dataset.scaler,
        contents_id_map=train_dataset.contents_id_map,
        ext=model_ext,
    )

    # Test
    model.eval()
    test_loss, _ = evaluate(model, test_loader, criterion)
    print(f"Test Loss : {test_loss:.4f}")
if __name__ == '__main__':
    # Each key becomes a CLI sub-command (`python src/main.py train ...`).
    # Fire turns the remaining `--flag value` pairs into keyword arguments
    # for the selected function.
    tasks = {"train": run_train}
    fire.Fire(tasks)
# CLI
python src/main.py train --model_name movie_predictor --optimizer adam --num_epochs 20 --lr 0.002
728x90
반응형
LIST
'App Programming > MLops' 카테고리의 다른 글
[MLops] 모델 추론 (0) | 2024.08.13 |
---|---|
[MLops] 학습 결과 기록하기 (0) | 2024.08.12 |
[MLops] 모델 학습 및 평가 (0) | 2024.08.12 |
[MLops] 모델 훈련 (0) | 2024.08.09 |
[MLops] TMDB API 데이터 수집 및 전처리 (0) | 2024.08.09 |