
TabNet

 

TabNet is a learning architecture designed for tabular data; it brings the feature-selection behavior of tree-based models into a neural network structure. TabNet uses sequential attention to choose which features to reason over at each decision step, so its learning capacity is spent on the most salient features, which enables both interpretability and more efficient learning.
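
To make the masking idea concrete, below is a minimal, hypothetical sketch of a single attentive step: a learned projection is scaled by a prior, turned into a per-feature mask, and the prior is updated so later steps are nudged toward features that have not been used yet. This is illustrative only; the real TabNet attentive transformer uses sparsemax/entmax instead of softmax, a relaxation factor gamma in the prior update, and shared feature transformers.

import torch
import torch.nn as nn

# Hypothetical toy module (not the library's implementation): one attentive step
class ToyAttentiveStep(nn.Module):
    def __init__(self, n_features):
        super().__init__()
        self.fc = nn.Linear(n_features, n_features)

    def forward(self, x, prior):
        # Feature mask for this decision step; real TabNet uses sparsemax/entmax here
        mask = torch.softmax(self.fc(x) * prior, dim=-1)
        # Down-weight already-used features so later steps look elsewhere
        # (real TabNet uses prior * (gamma - mask) with a relaxation factor gamma)
        new_prior = prior * (1.0 - mask)
        return x * mask, mask, new_prior

x = torch.randn(4, 10)        # batch of 4 samples, 10 features
prior = torch.ones_like(x)    # every feature equally available at the first step
step = ToyAttentiveStep(10)
masked_x, mask, prior = step(x, prior)
print(mask.sum(dim=-1))       # each row of the mask sums to 1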

 

Example

 

pip install pytorch_tabnet

from sklearn.model_selection import train_test_split
from sklearn.preprocessing import LabelEncoder
from sklearn.metrics import accuracy_score
from pytorch_tabnet.tab_model import TabNetClassifier
import numpy as np
from sklearn.datasets import load_iris
import torch

# Load the Iris dataset
data = load_iris()
X = data.data
y = data.target

# Encode the labels (the iris targets are already integers, so this is shown for generality)
label_encoder = LabelEncoder()
y = label_encoder.fit_transform(y)

# Split into training and test sets
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

# Initialize the TabNetClassifier model
model = TabNetClassifier(optimizer_fn=torch.optim.Adam,
                         optimizer_params=dict(lr=2e-2),
                         scheduler_params={"step_size":50, "gamma":0.9},
                         scheduler_fn=torch.optim.lr_scheduler.StepLR,
                         mask_type='entmax' # or "sparsemax"
                        )

# Train the model
model.fit(
    X_train=X_train, y_train=y_train,
    eval_set=[(X_test, y_test)],
    eval_name=['test'],
    eval_metric=['accuracy'],
    max_epochs=1000,
    patience=20, # early stopping patience
    batch_size=32, virtual_batch_size=16,
    num_workers=0,
    weights=1,
    drop_last=False
)

# Predict on the test set
y_pred = model.predict(X_test)

# Compute accuracy on the test set
test_acc = accuracy_score(y_test, y_pred)
print(f"Test Accuracy: {test_acc:.4f}")
epoch 0  | loss: 1.24055 | test_accuracy: 0.36667 |  0:00:00s
epoch 1  | loss: 0.70991 | test_accuracy: 0.46667 |  0:00:00s
epoch 2  | loss: 0.49401 | test_accuracy: 0.56667 |  0:00:00s
epoch 3  | loss: 0.3128  | test_accuracy: 0.36667 |  0:00:00s
epoch 4  | loss: 0.18969 | test_accuracy: 0.36667 |  0:00:00s
epoch 5  | loss: 0.2834  | test_accuracy: 0.36667 |  0:00:00s
epoch 6  | loss: 0.20733 | test_accuracy: 0.36667 |  0:00:00s
epoch 7  | loss: 0.12741 | test_accuracy: 0.53333 |  0:00:00s
epoch 8  | loss: 0.22969 | test_accuracy: 0.6     |  0:00:01s
epoch 9  | loss: 0.33124 | test_accuracy: 0.53333 |  0:00:01s
epoch 10 | loss: 0.10614 | test_accuracy: 0.56667 |  0:00:01s
epoch 11 | loss: 0.49213 | test_accuracy: 0.5     |  0:00:01s
epoch 12 | loss: 0.1962  | test_accuracy: 0.7     |  0:00:01s
epoch 13 | loss: 0.2068  | test_accuracy: 0.7     |  0:00:01s
epoch 14 | loss: 0.18044 | test_accuracy: 0.7     |  0:00:02s
epoch 15 | loss: 0.13931 | test_accuracy: 0.7     |  0:00:02s
epoch 16 | loss: 0.16601 | test_accuracy: 0.7     |  0:00:02s
epoch 17 | loss: 0.18134 | test_accuracy: 0.7     |  0:00:02s
epoch 18 | loss: 0.23732 | test_accuracy: 0.7     |  0:00:02s
epoch 19 | loss: 0.26662 | test_accuracy: 0.7     |  0:00:02s
epoch 20 | loss: 0.1595  | test_accuracy: 0.7     |  0:00:02s
epoch 21 | loss: 0.1163  | test_accuracy: 0.7     |  0:00:02s
epoch 22 | loss: 0.14593 | test_accuracy: 0.7     |  0:00:03s
epoch 23 | loss: 0.06628 | test_accuracy: 0.7     |  0:00:03s
epoch 24 | loss: 0.18258 | test_accuracy: 0.7     |  0:00:03s
epoch 25 | loss: 0.15508 | test_accuracy: 0.8     |  0:00:03s
epoch 26 | loss: 0.16986 | test_accuracy: 0.8     |  0:00:03s
epoch 27 | loss: 0.27515 | test_accuracy: 0.8     |  0:00:03s
epoch 28 | loss: 0.17788 | test_accuracy: 0.73333 |  0:00:03s
epoch 29 | loss: 0.1773  | test_accuracy: 0.73333 |  0:00:03s
epoch 30 | loss: 0.0827  | test_accuracy: 0.73333 |  0:00:03s
epoch 31 | loss: 0.11442 | test_accuracy: 0.8     |  0:00:04s
epoch 32 | loss: 0.13208 | test_accuracy: 0.8     |  0:00:04s
epoch 33 | loss: 0.08241 | test_accuracy: 0.8     |  0:00:04s
epoch 34 | loss: 0.14952 | test_accuracy: 0.8     |  0:00:04s
epoch 35 | loss: 0.16785 | test_accuracy: 0.8     |  0:00:04s
epoch 36 | loss: 0.18916 | test_accuracy: 0.8     |  0:00:04s
epoch 37 | loss: 0.15334 | test_accuracy: 0.83333 |  0:00:04s
epoch 38 | loss: 0.09802 | test_accuracy: 0.86667 |  0:00:04s
epoch 39 | loss: 0.09509 | test_accuracy: 0.86667 |  0:00:04s
epoch 40 | loss: 0.07871 | test_accuracy: 0.93333 |  0:00:04s
epoch 41 | loss: 0.16731 | test_accuracy: 0.96667 |  0:00:04s
epoch 42 | loss: 0.19114 | test_accuracy: 0.93333 |  0:00:05s
epoch 43 | loss: 0.13993 | test_accuracy: 0.93333 |  0:00:05s
epoch 44 | loss: 0.22223 | test_accuracy: 0.96667 |  0:00:05s
epoch 45 | loss: 0.09935 | test_accuracy: 0.93333 |  0:00:05s
epoch 46 | loss: 0.10629 | test_accuracy: 0.96667 |  0:00:05s
epoch 47 | loss: 0.09302 | test_accuracy: 0.93333 |  0:00:05s
epoch 48 | loss: 0.22515 | test_accuracy: 0.93333 |  0:00:05s
epoch 49 | loss: 0.07658 | test_accuracy: 0.96667 |  0:00:05s
epoch 50 | loss: 0.12239 | test_accuracy: 0.93333 |  0:00:05s
epoch 51 | loss: 0.07066 | test_accuracy: 0.93333 |  0:00:05s
epoch 52 | loss: 0.15448 | test_accuracy: 0.93333 |  0:00:05s
epoch 53 | loss: 0.0605  | test_accuracy: 0.96667 |  0:00:06s
epoch 54 | loss: 0.24626 | test_accuracy: 0.96667 |  0:00:06s
epoch 55 | loss: 0.24765 | test_accuracy: 0.96667 |  0:00:06s
epoch 56 | loss: 0.05851 | test_accuracy: 0.96667 |  0:00:06s
epoch 57 | loss: 0.16453 | test_accuracy: 0.96667 |  0:00:06s
epoch 58 | loss: 0.15358 | test_accuracy: 0.96667 |  0:00:06s
epoch 59 | loss: 0.08635 | test_accuracy: 0.96667 |  0:00:06s
epoch 60 | loss: 0.05952 | test_accuracy: 0.96667 |  0:00:06s
epoch 61 | loss: 0.06295 | test_accuracy: 0.96667 |  0:00:06s
import matplotlib.pyplot as plt

train_loss = model.history['loss']

epochs = range(1, len(train_loss) + 1)

plt.figure(figsize=(10, 5))
plt.plot(epochs, train_loss, 'b', label='Training loss')
plt.title('Training Loss')
plt.xlabel('Epochs')
plt.ylabel('Loss')
plt.legend()
plt.show()

test_acc_history = model.history['test_accuracy']

plt.figure(figsize=(10, 5))
plt.plot(epochs, test_acc_history, 'b', label='Test accuracy')
plt.title('Test accuracy')
plt.xlabel('Epochs')
plt.ylabel('Accuracy')
plt.legend()
plt.show()
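
The interpretability claim at the top can be checked on the trained model. The sketch below assumes the model, data, and X_test objects from the example above and uses the feature_importances_ attribute and explain() method that pytorch_tabnet exposes after fitting; the plotting details are only illustrative.

# Global importances: attention aggregated over all decision steps (sums to 1)
for name, imp in zip(data.feature_names, model.feature_importances_):
    print(f"{name}: {imp:.3f}")

# Local explanations: per-sample feature masks for each decision step
explain_matrix, masks = model.explain(X_test)

plt.figure(figsize=(10, 4))
for step, mask in masks.items():
    plt.subplot(1, len(masks), step + 1)
    plt.imshow(mask[:20], aspect='auto')
    plt.title(f'Mask {step}')
plt.tight_layout()
plt.show()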

 

https://dreamquark-ai.github.io/tabnet/

 


 
