신경망 모델
신경망은 데이터에 대한 연산을 수행하는 계층 (layer) / 모듈(module)로 구성되어 있다. torch.nn 네임스페이스는 신경망을 구성하는데 필요한 모든 구성 요소를 제공한다. PyTorch의 모든 module은 nn.Module 의 하위 클래스 (subclass)이다. 신경망은 다른 module (계층)로 구성된 module이다. 이 중첩된 구조는 복잡한 아키텍처를 쉽게 구축하고 관리할 수 있다.
# FashionMNIST 데이터셋의 이미지들을 분류하는 신경망
import os
import torch
from torch import nn
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
학습을 위한 장치 얻기
가능한 경우 GPU와 같은 하드웨어 가속기에서 모델을 학습한다. torch.cuda를 사용할 수 있는지 확인하고 그렇지 않으면 CPU를 계속 사용한다.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(f'Using {device} device')
Out:
Using cuda device
클래스 정의
신경망 모델을 nn.Module 의 subclass로 정의하고, __init__ 에서 신경망 layer를 초기화한다. nn.Module 을 상속받은 모든 class는 forward 방법에 입력 데이터에 대한 연산들을 구현한다.
class NeuralNetwork(nn.Module):
def __init__(self):
super(NeuralNetwork, self).__init__()
self.flatten = nn.Flatten()
self.linear_relu_stack = nn.Sequential(
nn.Linear(28*28, 512),
nn.ReLU(),
nn.Linear(512, 512),
nn.ReLU(),
nn.Linear(512, 10),
)
def forward(self, x):
x = self.flatten(x)
logits = self.linear_relu_stack(x)
return logits
NeuralNetwork의 인스턴스 (instance)를 생성하고 이를 device로 이동한 뒤, 구조 (structure)를 출력한다.
model = NeuralNetwork().to(device)
print(model)
Out:
NeuralNetwork(
(flatten): Flatten(start_dim=1, end_dim=-1)
(linear_relu_stack): Sequential(
(0): Linear(in_features=784, out_features=512, bias=True)
(1): ReLU()
(2): Linear(in_features=512, out_features=512, bias=True)
(3): ReLU()
(4): Linear(in_features=512, out_features=10, bias=True)
)
)
모델을 사용하기 위해 입력 데이터를 전달한다. 이는 일부 백그라운드 연산들과 함께 모델의 forward를 실행한다. 여기서 model.forward() 를 직접 호출하면 안된다. 모델에 입력을 호출하면 각 분류 (class)에 대한 원시 (raw) 예측값이 있는 10차원 tensor가 반환된다. raw 예측값을 nn.Softmax module의 instance에 통과시켜 예측 확률을 얻는다.
X = torch.rand(1, 28, 28, device=device)
logits = model(X)
pred_probab = nn.Softmax(dim=1)(logits)
y_pred = pred_probab.argmax(1)
print(f"Predicted class: {y_pred}")
Out:
Predicted class: tensor([6], device='cuda:0')
https://tutorials.pytorch.kr/beginner/basics/buildmodel_tutorial.html
'Python Library > PyTorch' 카테고리의 다른 글
[PyTorch] 모델 매개변수 (Parameter) (0) | 2022.01.13 |
---|---|
[PyTorch] 모델 계층 (Layer) (0) | 2022.01.13 |
[PyTorch] 변형 (Transform) (0) | 2022.01.13 |
[PyTorch] DATASET / DATALOADER (0) | 2022.01.13 |
[PyTorch] NumPy 변환 (Bridge) (0) | 2022.01.13 |