goatlab
2024. 4. 21. 13:34
TabNet
TabNet is an architecture designed for learning on tabular data; it builds the feature-selection behavior of tree-based models into the network structure itself. TabNet uses sequential attention to choose which features to reason over at each decision step, so that learning capacity is concentrated on the most salient features. This makes the model both interpretable and more efficient to train.
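The sparse feature-selection masks at each decision step come from a sparsemax-style projection (the `mask_type` option in the code below lets you pick `sparsemax` or `entmax`). As a rough illustration of why this yields interpretable selection, here is a minimal NumPy sketch of sparsemax (Martins & Astudillo, 2016), which, unlike softmax, can assign exactly zero weight to unimportant features. This is an illustration of the idea, not pytorch-tabnet's internal code:

```python
import numpy as np

def sparsemax(z):
    """Project scores z onto the probability simplex.
    Like softmax, the output is non-negative and sums to 1,
    but low-scoring entries get exactly zero weight."""
    z = np.asarray(z, dtype=float)
    z_sorted = np.sort(z)[::-1]              # scores in descending order
    cumsum = np.cumsum(z_sorted)
    k = np.arange(1, len(z) + 1)
    support = 1 + k * z_sorted > cumsum      # entries kept in the support
    k_z = k[support][-1]                     # size of the support
    tau = (cumsum[support][-1] - 1) / k_z    # threshold
    return np.maximum(z - tau, 0.0)

print(sparsemax([3.0, 1.0, 0.2]))  # only the dominant feature survives
```

With a clearly dominant score, sparsemax puts all mass on that entry and zeros out the rest, which is exactly the behavior that makes TabNet's attention masks readable as feature selections.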
Example
pip install pytorch_tabnet
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import LabelEncoder
from sklearn.metrics import accuracy_score
from pytorch_tabnet.tab_model import TabNetClassifier
import numpy as np
from sklearn.datasets import load_iris
import torch
# Load the Iris dataset
data = load_iris()
X = data.data
y = data.target

# Encode labels (Iris targets are already integers; shown for generality)
label_encoder = LabelEncoder()
y = label_encoder.fit_transform(y)

# Split into training and test sets
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

# Initialize the TabNetClassifier
model = TabNetClassifier(
    optimizer_fn=torch.optim.Adam,
    optimizer_params=dict(lr=2e-2),
    scheduler_fn=torch.optim.lr_scheduler.StepLR,
    scheduler_params={"step_size": 50, "gamma": 0.9},
    mask_type='entmax'  # or "sparsemax"
)

# Train the model
model.fit(
    X_train=X_train, y_train=y_train,
    eval_set=[(X_test, y_test)],
    eval_name=['test'],
    eval_metric=['accuracy'],
    max_epochs=1000,
    patience=20,  # early stopping patience
    batch_size=32, virtual_batch_size=16,
    num_workers=0,
    weights=1,
    drop_last=False
)

# Predict on the test set
y_pred = model.predict(X_test)

# Compute accuracy (accuracy_score expects y_true first, then y_pred)
test_acc = accuracy_score(y_test, y_pred)
print(f"Test Accuracy: {test_acc:.4f}")
epoch 0 | loss: 1.24055 | test_accuracy: 0.36667 | 0:00:00s
epoch 1 | loss: 0.70991 | test_accuracy: 0.46667 | 0:00:00s
epoch 2 | loss: 0.49401 | test_accuracy: 0.56667 | 0:00:00s
epoch 3 | loss: 0.3128 | test_accuracy: 0.36667 | 0:00:00s
epoch 4 | loss: 0.18969 | test_accuracy: 0.36667 | 0:00:00s
epoch 5 | loss: 0.2834 | test_accuracy: 0.36667 | 0:00:00s
epoch 6 | loss: 0.20733 | test_accuracy: 0.36667 | 0:00:00s
epoch 7 | loss: 0.12741 | test_accuracy: 0.53333 | 0:00:00s
epoch 8 | loss: 0.22969 | test_accuracy: 0.6 | 0:00:01s
epoch 9 | loss: 0.33124 | test_accuracy: 0.53333 | 0:00:01s
epoch 10 | loss: 0.10614 | test_accuracy: 0.56667 | 0:00:01s
epoch 11 | loss: 0.49213 | test_accuracy: 0.5 | 0:00:01s
epoch 12 | loss: 0.1962 | test_accuracy: 0.7 | 0:00:01s
epoch 13 | loss: 0.2068 | test_accuracy: 0.7 | 0:00:01s
epoch 14 | loss: 0.18044 | test_accuracy: 0.7 | 0:00:02s
epoch 15 | loss: 0.13931 | test_accuracy: 0.7 | 0:00:02s
epoch 16 | loss: 0.16601 | test_accuracy: 0.7 | 0:00:02s
epoch 17 | loss: 0.18134 | test_accuracy: 0.7 | 0:00:02s
epoch 18 | loss: 0.23732 | test_accuracy: 0.7 | 0:00:02s
epoch 19 | loss: 0.26662 | test_accuracy: 0.7 | 0:00:02s
epoch 20 | loss: 0.1595 | test_accuracy: 0.7 | 0:00:02s
epoch 21 | loss: 0.1163 | test_accuracy: 0.7 | 0:00:02s
epoch 22 | loss: 0.14593 | test_accuracy: 0.7 | 0:00:03s
epoch 23 | loss: 0.06628 | test_accuracy: 0.7 | 0:00:03s
epoch 24 | loss: 0.18258 | test_accuracy: 0.7 | 0:00:03s
epoch 25 | loss: 0.15508 | test_accuracy: 0.8 | 0:00:03s
epoch 26 | loss: 0.16986 | test_accuracy: 0.8 | 0:00:03s
epoch 27 | loss: 0.27515 | test_accuracy: 0.8 | 0:00:03s
epoch 28 | loss: 0.17788 | test_accuracy: 0.73333 | 0:00:03s
epoch 29 | loss: 0.1773 | test_accuracy: 0.73333 | 0:00:03s
epoch 30 | loss: 0.0827 | test_accuracy: 0.73333 | 0:00:03s
epoch 31 | loss: 0.11442 | test_accuracy: 0.8 | 0:00:04s
epoch 32 | loss: 0.13208 | test_accuracy: 0.8 | 0:00:04s
epoch 33 | loss: 0.08241 | test_accuracy: 0.8 | 0:00:04s
epoch 34 | loss: 0.14952 | test_accuracy: 0.8 | 0:00:04s
epoch 35 | loss: 0.16785 | test_accuracy: 0.8 | 0:00:04s
epoch 36 | loss: 0.18916 | test_accuracy: 0.8 | 0:00:04s
epoch 37 | loss: 0.15334 | test_accuracy: 0.83333 | 0:00:04s
epoch 38 | loss: 0.09802 | test_accuracy: 0.86667 | 0:00:04s
epoch 39 | loss: 0.09509 | test_accuracy: 0.86667 | 0:00:04s
epoch 40 | loss: 0.07871 | test_accuracy: 0.93333 | 0:00:04s
epoch 41 | loss: 0.16731 | test_accuracy: 0.96667 | 0:00:04s
epoch 42 | loss: 0.19114 | test_accuracy: 0.93333 | 0:00:05s
epoch 43 | loss: 0.13993 | test_accuracy: 0.93333 | 0:00:05s
epoch 44 | loss: 0.22223 | test_accuracy: 0.96667 | 0:00:05s
epoch 45 | loss: 0.09935 | test_accuracy: 0.93333 | 0:00:05s
epoch 46 | loss: 0.10629 | test_accuracy: 0.96667 | 0:00:05s
epoch 47 | loss: 0.09302 | test_accuracy: 0.93333 | 0:00:05s
epoch 48 | loss: 0.22515 | test_accuracy: 0.93333 | 0:00:05s
epoch 49 | loss: 0.07658 | test_accuracy: 0.96667 | 0:00:05s
epoch 50 | loss: 0.12239 | test_accuracy: 0.93333 | 0:00:05s
epoch 51 | loss: 0.07066 | test_accuracy: 0.93333 | 0:00:05s
epoch 52 | loss: 0.15448 | test_accuracy: 0.93333 | 0:00:05s
epoch 53 | loss: 0.0605 | test_accuracy: 0.96667 | 0:00:06s
epoch 54 | loss: 0.24626 | test_accuracy: 0.96667 | 0:00:06s
epoch 55 | loss: 0.24765 | test_accuracy: 0.96667 | 0:00:06s
epoch 56 | loss: 0.05851 | test_accuracy: 0.96667 | 0:00:06s
epoch 57 | loss: 0.16453 | test_accuracy: 0.96667 | 0:00:06s
epoch 58 | loss: 0.15358 | test_accuracy: 0.96667 | 0:00:06s
epoch 59 | loss: 0.08635 | test_accuracy: 0.96667 | 0:00:06s
epoch 60 | loss: 0.05952 | test_accuracy: 0.96667 | 0:00:06s
epoch 61 | loss: 0.06295 | test_accuracy: 0.96667 | 0:00:06s
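In the log above, the best test accuracy (0.96667) first appears at epoch 41, and with `patience=20` training halts at epoch 61, twenty epochs with no new best. The patience rule can be sketched in a few lines of plain Python; this is a simplified illustration of the idea, not pytorch-tabnet's actual implementation:

```python
def early_stopping_epoch(metric_per_epoch, patience):
    """Return the epoch at which patience-based early stopping fires:
    stop once `patience` epochs pass without a new best metric."""
    best, best_epoch = float("-inf"), 0
    for epoch, m in enumerate(metric_per_epoch):
        if m > best:
            best, best_epoch = m, epoch       # new best: reset the counter
        elif epoch - best_epoch >= patience:  # no improvement for `patience` epochs
            return epoch
    return len(metric_per_epoch) - 1          # ran to the end without stopping

# Best value at epoch 41, then no improvement -> stops at 41 + 20 = 61
accs = [i / 100 for i in range(42)] + [0.3] * 30
print(early_stopping_epoch(accs, 20))  # -> 61
```

Because the checkpoint with the best evaluation metric is what matters, pytorch-tabnet restores the best-scoring weights when early stopping triggers.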
import matplotlib.pyplot as plt
train_loss = model.history['loss']
epochs = range(1, len(train_loss) + 1)
plt.figure(figsize=(10, 5))
plt.plot(epochs, train_loss, 'b', label='Training loss')
plt.title('Training Loss')
plt.xlabel('Epochs')
plt.ylabel('Loss')
plt.legend()
plt.show()
test_acc_history = model.history['test_accuracy']
plt.figure(figsize=(10, 5))
plt.plot(epochs, test_acc_history, 'b', label='Test accuracy')
plt.title('Test Accuracy')
plt.xlabel('Epochs')
plt.ylabel('Accuracy')
plt.legend()
plt.show()
https://dreamquark-ai.github.io/tabnet/