본문 바로가기
Visual Intelligence/Image Deep Learning

[시각 지능] Fashion MNIST

by goatlab 2022. 7. 31.
728x90
반응형
SMALL

Fashion MNIST

 

https://www.kaggle.com/datasets/zalando-research/fashionmnist

 

Fashion MNIST는 Zalando의 상품(의류) 이미지로 이루어진 데이터 세트로, 60,000개의 예제로 구성된 훈련 세트와 10,000개의 예제로 구성된 테스트 세트로 나뉜다. 각 예제는 10개 클래스 중 하나의 레이블이 지정된 28x28 회색조 이미지이다.

 

import tensorflow as tf
import numpy as np

from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Flatten, Dense
from tensorflow.keras.datasets import fashion_mnist

# Load the Fashion MNIST train/test splits (downloaded on first use).
train_split, test_split = fashion_mnist.load_data()
x_train, t_train = train_split
x_test, t_test = test_split

# Report the array shapes of both splits.
print('')
print('x_train.shape = ', x_train.shape, ', t_train.shape = ', t_train.shape)
print('x_test.shape = ', x_test.shape, ', t_test.shape = ', t_test.shape)
Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/train-labels-idx1-ubyte.gz
32768/29515 [=================================] - 0s 0us/step
40960/29515 [=========================================] - 0s 0us/step
Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/train-images-idx3-ubyte.gz
26427392/26421880 [==============================] - 0s 0us/step
26435584/26421880 [==============================] - 0s 0us/step
Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/t10k-labels-idx1-ubyte.gz
16384/5148 [===============================================================================================] - 0s 0us/step
Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/t10k-images-idx3-ubyte.gz
4423680/4422102 [==============================] - 0s 0us/step
4431872/4422102 [==============================] - 0s 0us/step

x_train.shape =  (60000, 28, 28) , t_train.shape =  (60000,)
x_test.shape =  (10000, 28, 28) , t_test.shape =  (10000,)

 

정규화 (Normalization)

 

# Scale the pixel values of both splits by 1/255 so they fall in the 0 ~ 1 range.
x_train, x_test = x_train / 255.0, x_test / 255.0

 

model 생성

 

from tensorflow.keras.optimizers import SGD

# Fully connected classifier: flatten the image, one hidden ReLU layer of
# 100 units, then a 10-way softmax output (one unit per class).
model = Sequential()

# Declare the per-sample input shape as (28, 28) so it matches the arrays
# returned by fashion_mnist.load_data(), which have no trailing channel axis.
# Flatten still yields 784 features, so the parameter counts are unchanged.
model.add(Flatten(input_shape=(28, 28)))
model.add(Dense(100, activation='relu'))
model.add(Dense(10, activation='softmax'))

# t_train is a 1-D vector with one class id per sample (not one-hot),
# hence the sparse variant of categorical cross-entropy.
model.compile(optimizer=SGD(learning_rate=0.1),
              loss='sparse_categorical_crossentropy',
              metrics=['accuracy'])

model.summary()
Model: "sequential"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
=================================================================
 flatten (Flatten)           (None, 784)               0         
                                                                 
 dense (Dense)               (None, 100)               78500     
                                                                 
 dense_1 (Dense)             (None, 10)                1010      
                                                                 
=================================================================
Total params: 79,510
Trainable params: 79,510
Non-trainable params: 0
_________________________________________________________________

 

ReduceLROnPlateau

 

from tensorflow.keras.callbacks import ReduceLROnPlateau

# Watch val_loss; once it has not improved for 5 epochs, multiply the
# learning rate by 0.5. verbose=1 logs each reduction.
reduceLR = ReduceLROnPlateau(
    monitor='val_loss',
    patience=5,
    factor=0.5,
    verbose=1,
)

# Train for up to 50 epochs, holding out 20% of the training data for validation.
hist = model.fit(
    x_train,
    t_train,
    validation_split=0.2,
    epochs=50,
    callbacks=[reduceLR],
)

 

ModelCheckpoint

 

from tensorflow.keras.callbacks import ModelCheckpoint

# Destination of the saved model file.
file_path = './modelchpoint_test.h5'

# Save the model to file_path only on epochs where val_loss improves.
# mode='auto' lets Keras infer the improvement direction from the monitored
# quantity; verbose=1 logs each save.
checkpoint = ModelCheckpoint(
    filepath=file_path,
    monitor='val_loss',
    mode='auto',
    save_best_only=True,
    verbose=1,
)

# Train for up to 50 epochs, holding out 20% of the training data for validation.
hist = model.fit(
    x_train,
    t_train,
    validation_split=0.2,
    epochs=50,
    callbacks=[checkpoint],
)

 

EarlyStopping

 

from tensorflow.keras.callbacks import ModelCheckpoint, EarlyStopping

# Destination of the saved model file.
file_path = './modelchpoint_test.h5'

# Save the model to file_path only on epochs where val_loss improves.
# mode='auto' lets Keras infer the improvement direction from the monitored
# quantity; verbose=1 logs each save.
checkpoint = ModelCheckpoint(
    filepath=file_path,
    monitor='val_loss',
    mode='auto',
    save_best_only=True,
    verbose=1,
)

# Stop training early once val_loss has not improved for 5 consecutive
# epochs; verbose=1 logs when the stop happens.
stopping = EarlyStopping(
    monitor='val_loss',
    verbose=1,
    patience=5,
)

# Train for up to 50 epochs with both callbacks active, holding out 20% of
# the training data for validation.
hist = model.fit(
    x_train,
    t_train,
    validation_split=0.2,
    epochs=50,
    callbacks=[checkpoint, stopping],
)

 

 

정확도 검증

 

# Measure the loss and the compiled metrics on the held-out test split.
model.evaluate(x=x_test, y=t_test)

 

손실 함수 그래프

 

import matplotlib.pyplot as plt

# Per-epoch training vs. validation loss from the most recent fit() call.
plt.title('Loss Trend')
plt.xlabel('epochs')
plt.ylabel('loss')
plt.grid()

for history_key, curve_label in (('loss', 'training loss'),
                                 ('val_loss', 'validation loss')):
    plt.plot(hist.history[history_key], label=curve_label)
plt.legend(loc='best')

plt.show()

 

정확도 그래프

 

# Per-epoch training vs. validation accuracy from the most recent fit() call.
plt.title('Accuracy Trend')
plt.xlabel('epochs')
plt.ylabel('accuracy')
plt.grid()

for history_key, curve_label in (('accuracy', 'training accuracy'),
                                 ('val_accuracy', 'validation accuracy')):
    plt.plot(hist.history[history_key], label=curve_label)
plt.legend(loc='best')

plt.show()
# Side-by-side figure: loss curves on the left, accuracy curves on the right,
# both taken from the most recent fit() call.
plt.figure(figsize=(8, 4))

# Left panel: training vs. validation loss.
plt.subplot(1, 2, 1)

plt.title('Loss Trend')
plt.grid()
plt.xlabel('epochs')
plt.ylabel('loss')

plt.plot(hist.history['loss'], label='training loss')
plt.plot(hist.history['val_loss'], label='validation loss')
plt.legend(loc='best')

# Right panel: training vs. validation accuracy.
plt.subplot(1, 2, 2)

# Title added so this panel is labelled like the loss panel.
plt.title('Accuracy Trend')
plt.xlabel('epochs')
plt.ylabel('accuracy')
plt.grid()

plt.plot(hist.history['accuracy'], label='training accuracy')
plt.plot(hist.history['val_accuracy'], label='validation accuracy')
plt.legend(loc='best')

# Adjust subplot spacing so the right panel's y-label does not crowd the
# left panel in this narrow figure.
plt.tight_layout()
plt.show()

728x90
반응형
LIST