
TabNet

by goatlab 2024. 4. 21.


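TabNet is a deep learning architecture designed for tabular data that builds the feature-selection behavior of tree-based models into the network structure. It uses sequential attention to choose which features to reason from at each decision step, so the model's learning capacity is spent on the most salient features. This makes the model interpretable (the selection masks show which features each step used) and enables more efficient learning.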
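The sequential, sparse feature selection described above can be sketched in a few lines of pure Python. This sketch assumes the mechanism from the TabNet paper: at each decision step the mask is `sparsemax(prior * logits)`, and the prior is then multiplied by `(gamma - mask)` so that features already attended to lose weight relative to unused ones at later steps. The logits are made-up numbers standing in for the output of TabNet's learned attentive transformer, and `gamma=1.3` is chosen for illustration (pytorch_tabnet exposes a `gamma` parameter with this default); the function names are hypothetical.

```python
# Hypothetical illustration of TabNet-style sequential feature masks.
# The attention logits are invented; in TabNet they come from a learned
# attentive transformer, and entmax can replace sparsemax (mask_type).

def sparsemax(z):
    """Euclidean projection of z onto the probability simplex.
    Unlike softmax, low-scoring entries become exactly 0."""
    z_sorted = sorted(z, reverse=True)
    cumsum, k_max, cumsum_at_k = 0.0, 0, 0.0
    for k, zk in enumerate(z_sorted, start=1):
        cumsum += zk
        if 1 + k * zk > cumsum:  # the k-th largest entry is still in the support
            k_max, cumsum_at_k = k, cumsum
    tau = (cumsum_at_k - 1) / k_max  # threshold subtracted from every entry
    return [max(zi - tau, 0.0) for zi in z]


def sequential_masks(step_logits, gamma=1.3):
    """Return one sparse feature mask per decision step.
    The prior starts at 1 for every feature and is multiplied by
    (gamma - mask) after each step, so heavily used features are
    down-weighted relative to unused ones at later steps."""
    prior = [1.0] * len(step_logits[0])
    masks = []
    for logits in step_logits:
        mask = sparsemax([p * a for p, a in zip(prior, logits)])
        masks.append(mask)
        prior = [p * (gamma - m) for p, m in zip(prior, mask)]
    return masks


features = ["sepal length", "sepal width", "petal length", "petal width"]
# Same invented logits at all 3 steps, so only the prior changes the masks.
logits = [0.2, 0.1, 1.5, 1.2]
for step, mask in enumerate(sequential_masks([logits] * 3), start=1):
    print(f"step {step}:", {f: round(m, 3) for f, m in zip(features, mask)})
```

With these numbers, step 1 puts all attention on the two petal features (0.65 and 0.35) and the sepal features get exactly zero weight; at step 2 the prior shifts weight from petal length toward petal width. Each mask sums to 1, which is what lets the masks be read as per-step feature importances.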

 

Example

 

pip install pytorch_tabnet
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import LabelEncoder
from sklearn.metrics import accuracy_score
from pytorch_tabnet.tab_model import TabNetClassifier
import numpy as np
from sklearn.datasets import load_iris
import torch

# Load the Iris dataset
data = load_iris()
X = data.data
y = data.target

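# Label encoding (the Iris targets are already the integers 0-2, so this is a no-op
# here; it only matters when the labels are strings)
label_encoder = LabelEncoder()
y = label_encoder.fit_transform(y)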

# Split into training and test sets (80/20)
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

# Initialize the TabNetClassifier model
model = TabNetClassifier(optimizer_fn=torch.optim.Adam,
                         optimizer_params=dict(lr=2e-2),
                         scheduler_fn=torch.optim.lr_scheduler.StepLR,
                         scheduler_params={"step_size": 50, "gamma": 0.9},  # StepLR: lr *= 0.9 every 50 steps
                         mask_type='entmax'  # alternative: "sparsemax" (the default)
                        )

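# Train the model
# Note: early stopping monitors the same test set that is scored below,
# so the final test accuracy is optimistically biased.
model.fit(
    X_train=X_train, y_train=y_train,
    eval_set=[(X_test, y_test)],
    eval_name=['test'],          # metrics are logged as 'test_<metric>'
    eval_metric=['accuracy'],
    max_epochs=1000,
    patience=20,                 # early-stopping patience (epochs without improvement)
    batch_size=32, virtual_batch_size=16,
    num_workers=0,
    weights=1,                   # 1 = class-balanced sampling (inverse class frequency)
    drop_last=False
)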

# Predict on the test set
y_pred = model.predict(X_test)

# Compute accuracy (scikit-learn expects y_true first, then y_pred)
test_acc = accuracy_score(y_test, y_pred)
print(f"Test Accuracy: {test_acc:.4f}")
epoch 0  | loss: 1.24055 | test_accuracy: 0.36667 |  0:00:00s
epoch 1  | loss: 0.70991 | test_accuracy: 0.46667 |  0:00:00s
epoch 2  | loss: 0.49401 | test_accuracy: 0.56667 |  0:00:00s
epoch 3  | loss: 0.3128  | test_accuracy: 0.36667 |  0:00:00s
epoch 4  | loss: 0.18969 | test_accuracy: 0.36667 |  0:00:00s
epoch 5  | loss: 0.2834  | test_accuracy: 0.36667 |  0:00:00s
epoch 6  | loss: 0.20733 | test_accuracy: 0.36667 |  0:00:00s
epoch 7  | loss: 0.12741 | test_accuracy: 0.53333 |  0:00:00s
epoch 8  | loss: 0.22969 | test_accuracy: 0.6     |  0:00:01s
epoch 9  | loss: 0.33124 | test_accuracy: 0.53333 |  0:00:01s
epoch 10 | loss: 0.10614 | test_accuracy: 0.56667 |  0:00:01s
epoch 11 | loss: 0.49213 | test_accuracy: 0.5     |  0:00:01s
epoch 12 | loss: 0.1962  | test_accuracy: 0.7     |  0:00:01s
epoch 13 | loss: 0.2068  | test_accuracy: 0.7     |  0:00:01s
epoch 14 | loss: 0.18044 | test_accuracy: 0.7     |  0:00:02s
epoch 15 | loss: 0.13931 | test_accuracy: 0.7     |  0:00:02s
epoch 16 | loss: 0.16601 | test_accuracy: 0.7     |  0:00:02s
epoch 17 | loss: 0.18134 | test_accuracy: 0.7     |  0:00:02s
epoch 18 | loss: 0.23732 | test_accuracy: 0.7     |  0:00:02s
epoch 19 | loss: 0.26662 | test_accuracy: 0.7     |  0:00:02s
epoch 20 | loss: 0.1595  | test_accuracy: 0.7     |  0:00:02s
epoch 21 | loss: 0.1163  | test_accuracy: 0.7     |  0:00:02s
epoch 22 | loss: 0.14593 | test_accuracy: 0.7     |  0:00:03s
epoch 23 | loss: 0.06628 | test_accuracy: 0.7     |  0:00:03s
epoch 24 | loss: 0.18258 | test_accuracy: 0.7     |  0:00:03s
epoch 25 | loss: 0.15508 | test_accuracy: 0.8     |  0:00:03s
epoch 26 | loss: 0.16986 | test_accuracy: 0.8     |  0:00:03s
epoch 27 | loss: 0.27515 | test_accuracy: 0.8     |  0:00:03s
epoch 28 | loss: 0.17788 | test_accuracy: 0.73333 |  0:00:03s
epoch 29 | loss: 0.1773  | test_accuracy: 0.73333 |  0:00:03s
epoch 30 | loss: 0.0827  | test_accuracy: 0.73333 |  0:00:03s
epoch 31 | loss: 0.11442 | test_accuracy: 0.8     |  0:00:04s
epoch 32 | loss: 0.13208 | test_accuracy: 0.8     |  0:00:04s
epoch 33 | loss: 0.08241 | test_accuracy: 0.8     |  0:00:04s
epoch 34 | loss: 0.14952 | test_accuracy: 0.8     |  0:00:04s
epoch 35 | loss: 0.16785 | test_accuracy: 0.8     |  0:00:04s
epoch 36 | loss: 0.18916 | test_accuracy: 0.8     |  0:00:04s
epoch 37 | loss: 0.15334 | test_accuracy: 0.83333 |  0:00:04s
epoch 38 | loss: 0.09802 | test_accuracy: 0.86667 |  0:00:04s
epoch 39 | loss: 0.09509 | test_accuracy: 0.86667 |  0:00:04s
epoch 40 | loss: 0.07871 | test_accuracy: 0.93333 |  0:00:04s
epoch 41 | loss: 0.16731 | test_accuracy: 0.96667 |  0:00:04s
epoch 42 | loss: 0.19114 | test_accuracy: 0.93333 |  0:00:05s
epoch 43 | loss: 0.13993 | test_accuracy: 0.93333 |  0:00:05s
epoch 44 | loss: 0.22223 | test_accuracy: 0.96667 |  0:00:05s
epoch 45 | loss: 0.09935 | test_accuracy: 0.93333 |  0:00:05s
epoch 46 | loss: 0.10629 | test_accuracy: 0.96667 |  0:00:05s
epoch 47 | loss: 0.09302 | test_accuracy: 0.93333 |  0:00:05s
epoch 48 | loss: 0.22515 | test_accuracy: 0.93333 |  0:00:05s
epoch 49 | loss: 0.07658 | test_accuracy: 0.96667 |  0:00:05s
epoch 50 | loss: 0.12239 | test_accuracy: 0.93333 |  0:00:05s
epoch 51 | loss: 0.07066 | test_accuracy: 0.93333 |  0:00:05s
epoch 52 | loss: 0.15448 | test_accuracy: 0.93333 |  0:00:05s
epoch 53 | loss: 0.0605  | test_accuracy: 0.96667 |  0:00:06s
epoch 54 | loss: 0.24626 | test_accuracy: 0.96667 |  0:00:06s
epoch 55 | loss: 0.24765 | test_accuracy: 0.96667 |  0:00:06s
epoch 56 | loss: 0.05851 | test_accuracy: 0.96667 |  0:00:06s
epoch 57 | loss: 0.16453 | test_accuracy: 0.96667 |  0:00:06s
epoch 58 | loss: 0.15358 | test_accuracy: 0.96667 |  0:00:06s
epoch 59 | loss: 0.08635 | test_accuracy: 0.96667 |  0:00:06s
epoch 60 | loss: 0.05952 | test_accuracy: 0.96667 |  0:00:06s
epoch 61 | loss: 0.06295 | test_accuracy: 0.96667 |  0:00:06s
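Why does the log end at epoch 61 even though `max_epochs=1000`? The sketch below replays the `test_accuracy` values copied from the log above through a simple patience rule: stop once the monitored metric has not strictly improved for `patience` consecutive epochs. The function `early_stopping_epoch` is a hypothetical illustration, not pytorch_tabnet's source, but with `patience=20` it reproduces the stopping point in the log: the best accuracy, 0.96667, first appears at epoch 41, and after 20 epochs without a strict improvement training stops at epoch 61.

```python
# Illustrative patience-based early stopping (hypothetical helper, not library code).
# Per-epoch test_accuracy values copied from the training log above (epochs 0-61).
test_accuracy = [
    0.36667, 0.46667, 0.56667, 0.36667, 0.36667, 0.36667, 0.36667, 0.53333, 0.6, 0.53333,      # 0-9
    0.56667, 0.5, 0.7, 0.7, 0.7, 0.7, 0.7, 0.7, 0.7, 0.7,                                     # 10-19
    0.7, 0.7, 0.7, 0.7, 0.7, 0.8, 0.8, 0.8, 0.73333, 0.73333,                                 # 20-29
    0.73333, 0.8, 0.8, 0.8, 0.8, 0.8, 0.8, 0.83333, 0.86667, 0.86667,                         # 30-39
    0.93333, 0.96667, 0.93333, 0.93333, 0.96667, 0.93333, 0.96667, 0.93333, 0.93333, 0.96667,  # 40-49
    0.93333, 0.93333, 0.93333, 0.96667, 0.96667, 0.96667, 0.96667, 0.96667, 0.96667, 0.96667,  # 50-59
    0.96667, 0.96667,                                                                        # 60-61
]


def early_stopping_epoch(scores, patience, tol=0.0):
    """Return (stop_epoch, best_epoch, best_score) for a metric to maximize.
    stop_epoch is None if the patience budget is never exhausted."""
    best_score, best_epoch, wait = float("-inf"), -1, 0
    for epoch, score in enumerate(scores):
        if score - best_score > tol:  # strict improvement resets the counter
            best_score, best_epoch, wait = score, epoch, 0
        else:
            wait += 1
            if wait >= patience:
                return epoch, best_epoch, best_score
    return None, best_epoch, best_score


print(early_stopping_epoch(test_accuracy, patience=20))  # → (61, 41, 0.96667)
```

Under this rule, later epochs that merely tie the best value (0.96667) do not reset the counter. pytorch_tabnet's early-stopping callback also restores the best-epoch weights when training ends, so `model.predict` in the example runs with the epoch-41 model.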
import matplotlib.pyplot as plt

train_loss = model.history['loss']

epochs = range(1, len(train_loss) + 1)

plt.figure(figsize=(10, 5))
plt.plot(epochs, train_loss, 'b', label='Training loss')
plt.title('Training Loss')
plt.xlabel('Epochs')
plt.ylabel('Loss')
plt.legend()
plt.show()

# Per-epoch accuracy on the eval_set named 'test' (not training accuracy)
test_acc_history = model.history['test_accuracy']

plt.figure(figsize=(10, 5))
plt.plot(epochs, test_acc_history, 'b', label='Test accuracy')
plt.title('Test Accuracy')
plt.xlabel('Epochs')
plt.ylabel('Accuracy')
plt.legend()
plt.show()
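The `fit` call in the example sets `batch_size=32, virtual_batch_size=16`. In pytorch_tabnet, `virtual_batch_size` is the size of the sub-batches used for Ghost Batch Normalization: each batch is split into smaller virtual batches, and each one is normalized with its own mean and variance. The pure-Python sketch below shows that idea for a single feature column. It is simplified (no learnable scale/shift, no running statistics for inference, plain consecutive slicing), the function name is hypothetical, and the library's exact chunking of uneven batches may differ.

```python
import math


def ghost_batch_norm(column, virtual_batch_size, eps=1e-5):
    """Normalize one feature column virtual batch by virtual batch.

    Simplified sketch: every consecutive slice of `virtual_batch_size` rows
    is standardized with its own mean and (biased) variance. Real batch-norm
    layers also learn a scale/shift and keep running statistics.
    """
    out = []
    for start in range(0, len(column), virtual_batch_size):
        chunk = column[start:start + virtual_batch_size]
        mean = sum(chunk) / len(chunk)
        var = sum((x - mean) ** 2 for x in chunk) / len(chunk)
        out.extend((x - mean) / math.sqrt(var + eps) for x in chunk)
    return out


# A made-up batch of 32 values whose second half is shifted by +100.
batch = [float(i) for i in range(16)] + [100.0 + i for i in range(16)]
ghost = ghost_batch_norm(batch, virtual_batch_size=16)  # two virtual batches
full = ghost_batch_norm(batch, virtual_batch_size=32)   # ordinary batch norm
print("ghost BN, first/last row:", round(ghost[0], 3), round(ghost[-1], 3))
print("full BN,  first/last row:", round(full[0], 3), round(full[-1], 3))
```

With virtual batches, each half is centered independently, so the +100 shift between halves disappears; with full-batch statistics it remains. Because each virtual batch's statistics are noisier than the full batch's, ghost batch normalization is often described as adding a regularizing effect when large batches are used.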

 

Reference: pytorch_tabnet documentation, https://dreamquark-ai.github.io/tabnet/
