본문 바로가기
Python Library/PyTorch

[Pytorch] 모델 파라미터 계산

by goatlab 2023. 5. 26.
728x90
반응형
SMALL

Trainable parameters

 

 

import torch
import torch.nn as nn

def count_parameters(model):
    return sum(p.numel() for p in model.parameters() if p.requires_grad)

# 예시 모델 정의
class ExampleModel(nn.Module):
    def __init__(self):
        super(ExampleModel, self).__init__()
        self.conv1 = nn.Conv2d(1, 32, kernel_size=3, stride=1, padding=1)
        self.fc = nn.Linear(32 * 28 * 28, 10)

    def forward(self, x):
        x = self.conv1(x)
        x = x.view(x.size(0), -1)
        x = self.fc(x)
        return x

# 모델 인스턴스화 및 파라미터 계산
model = ExampleModel()
print(f"Total trainable parameters: {count_parameters(model)}")

 

Total parameters

 

import torch
import torch.nn as nn

def count_all_parameters(model):
    return sum(p.numel() for p in model.parameters())

# 예시 모델 정의
class ExampleModel(nn.Module):
    def __init__(self):
        super(ExampleModel, self).__init__()
        self.conv1 = nn.Conv2d(1, 32, kernel_size=3, stride=1, padding=1)
        self.fc = nn.Linear(32 * 28 * 28, 10)

    def forward(self, x):
        x = self.conv1(x)
        x = x.view(x.size(0), -1)
        x = self.fc(x)
        return x

# 모델 인스턴스화 및 모든 파라미터 계산
model = ExampleModel()
print(f"Total parameters (including non-trainable): {count_all_parameters(model)}")

 

728x90
반응형
LIST