본문 바로가기
Visual Intelligence/Image Deep Learning

[시각 지능] 사전 학습 모델 (Pre-Trained Model)

by goatlab 2022. 8. 20.
728x90
반응형
SMALL

사전 학습 모델 (Pre-Trained Model)

 

사전 학습 모델이란 다른 문제 (task)에 대해 미리 학습된 가중치를 가진 모델을 말한다. 기존에 자비어 (Xavier) 등 임의의 값으로 초기화하던 모델의 가중치들을, 이렇게 다른 문제 (task)에 학습시킨 가중치들로 초기화하여 사용하는 방법이다.

 

test_image_dir.zip
0.08MB

# MobileNetV2 and Xception are instantiated later in this script, so they
# have to be imported here alongside the other architectures.
from tensorflow.keras.applications import VGG16, ResNet50, MobileNet, MobileNetV2, InceptionV3, Xception

# Load MobileNet with its ImageNet weights; include_top = True keeps the
# classification head, so the model outputs 1000 class scores per image.
mobilenet_model = MobileNet(weights = 'imagenet', include_top = True, input_shape = (224, 224, 3))

# Print the layer-by-layer architecture and the parameter counts.
mobilenet_model.summary()
Model: "mobilenet_1.00_224"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
=================================================================
 input_1 (InputLayer)        [(None, 224, 224, 3)]     0         
                                                                 
 conv1 (Conv2D)              (None, 112, 112, 32)      864       
                                                                 
 conv1_bn (BatchNormalizatio  (None, 112, 112, 32)     128       
 n)                                                              
                                                                 
 conv1_relu (ReLU)           (None, 112, 112, 32)      0         
                                                                 
 conv_dw_1 (DepthwiseConv2D)  (None, 112, 112, 32)     288       
                                                                 
 conv_dw_1_bn (BatchNormaliz  (None, 112, 112, 32)     128       
 ation)                                                          
                                                                 
 conv_dw_1_relu (ReLU)       (None, 112, 112, 32)      0         
                                                                 
 conv_pw_1 (Conv2D)          (None, 112, 112, 64)      2048      
                                                                 
 conv_pw_1_bn (BatchNormaliz  (None, 112, 112, 64)     256       
 ation)                                                          
                                                                 
 conv_pw_1_relu (ReLU)       (None, 112, 112, 64)      0         
                                                                 
 conv_pad_2 (ZeroPadding2D)  (None, 113, 113, 64)      0         
                                                                 
 conv_dw_2 (DepthwiseConv2D)  (None, 56, 56, 64)       576       
                                                                 
 conv_dw_2_bn (BatchNormaliz  (None, 56, 56, 64)       256       
 ation)                                                          
                                                                 
 conv_dw_2_relu (ReLU)       (None, 56, 56, 64)        0         
                                                                 
 conv_pw_2 (Conv2D)          (None, 56, 56, 128)       8192      
                                                                 
 conv_pw_2_bn (BatchNormaliz  (None, 56, 56, 128)      512       
 ation)                                                          
                                                                 
 conv_pw_2_relu (ReLU)       (None, 56, 56, 128)       0         
                                                                 
 conv_dw_3 (DepthwiseConv2D)  (None, 56, 56, 128)      1152      
                                                                 
 conv_dw_3_bn (BatchNormaliz  (None, 56, 56, 128)      512       
 ation)                                                          
                                                                 
 conv_dw_3_relu (ReLU)       (None, 56, 56, 128)       0         
                                                                 
 conv_pw_3 (Conv2D)          (None, 56, 56, 128)       16384     
                                                                 
 conv_pw_3_bn (BatchNormaliz  (None, 56, 56, 128)      512       
 ation)                                                          
                                                                 
 conv_pw_3_relu (ReLU)       (None, 56, 56, 128)       0         
                                                                 
 conv_pad_4 (ZeroPadding2D)  (None, 57, 57, 128)       0         
                                                                 
 conv_dw_4 (DepthwiseConv2D)  (None, 28, 28, 128)      1152      
                                                                 
 conv_dw_4_bn (BatchNormaliz  (None, 28, 28, 128)      512       
 ation)                                                          
                                                                 
 conv_dw_4_relu (ReLU)       (None, 28, 28, 128)       0         
                                                                 
 conv_pw_4 (Conv2D)          (None, 28, 28, 256)       32768     
                                                                 
 conv_pw_4_bn (BatchNormaliz  (None, 28, 28, 256)      1024      
 ation)                                                          
                                                                 
 conv_pw_4_relu (ReLU)       (None, 28, 28, 256)       0         
                                                                 
 conv_dw_5 (DepthwiseConv2D)  (None, 28, 28, 256)      2304      
                                                                 
 conv_dw_5_bn (BatchNormaliz  (None, 28, 28, 256)      1024      
 ation)                                                          
                                                                 
 conv_dw_5_relu (ReLU)       (None, 28, 28, 256)       0         
                                                                 
 conv_pw_5 (Conv2D)          (None, 28, 28, 256)       65536     
                                                                 
 conv_pw_5_bn (BatchNormaliz  (None, 28, 28, 256)      1024      
 ation)                                                          
                                                                 
 conv_pw_5_relu (ReLU)       (None, 28, 28, 256)       0         
                                                                 
 conv_pad_6 (ZeroPadding2D)  (None, 29, 29, 256)       0         
                                                                 
 conv_dw_6 (DepthwiseConv2D)  (None, 14, 14, 256)      2304      
                                                                 
 conv_dw_6_bn (BatchNormaliz  (None, 14, 14, 256)      1024      
 ation)                                                          
                                                                 
 conv_dw_6_relu (ReLU)       (None, 14, 14, 256)       0         
                                                                 
 conv_pw_6 (Conv2D)          (None, 14, 14, 512)       131072    
                                                                 
 conv_pw_6_bn (BatchNormaliz  (None, 14, 14, 512)      2048      
 ation)                                                          
                                                                 
 conv_pw_6_relu (ReLU)       (None, 14, 14, 512)       0         
                                                                 
 conv_dw_7 (DepthwiseConv2D)  (None, 14, 14, 512)      4608      
                                                                 
 conv_dw_7_bn (BatchNormaliz  (None, 14, 14, 512)      2048      
 ation)                                                          
                                                                 
 conv_dw_7_relu (ReLU)       (None, 14, 14, 512)       0         
                                                                 
 conv_pw_7 (Conv2D)          (None, 14, 14, 512)       262144    
                                                                 
 conv_pw_7_bn (BatchNormaliz  (None, 14, 14, 512)      2048      
 ation)                                                          
                                                                 
 conv_pw_7_relu (ReLU)       (None, 14, 14, 512)       0         
                                                                 
 conv_dw_8 (DepthwiseConv2D)  (None, 14, 14, 512)      4608      
                                                                 
 conv_dw_8_bn (BatchNormaliz  (None, 14, 14, 512)      2048      
 ation)                                                          
                                                                 
 conv_dw_8_relu (ReLU)       (None, 14, 14, 512)       0         
                                                                 
 conv_pw_8 (Conv2D)          (None, 14, 14, 512)       262144    
                                                                 
 conv_pw_8_bn (BatchNormaliz  (None, 14, 14, 512)      2048      
 ation)                                                          
                                                                 
 conv_pw_8_relu (ReLU)       (None, 14, 14, 512)       0         
                                                                 
 conv_dw_9 (DepthwiseConv2D)  (None, 14, 14, 512)      4608      
                                                                 
 conv_dw_9_bn (BatchNormaliz  (None, 14, 14, 512)      2048      
 ation)                                                          
                                                                 
 conv_dw_9_relu (ReLU)       (None, 14, 14, 512)       0         
                                                                 
 conv_pw_9 (Conv2D)          (None, 14, 14, 512)       262144    
                                                                 
 conv_pw_9_bn (BatchNormaliz  (None, 14, 14, 512)      2048      
 ation)                                                          
                                                                 
 conv_pw_9_relu (ReLU)       (None, 14, 14, 512)       0         
                                                                 
 conv_dw_10 (DepthwiseConv2D  (None, 14, 14, 512)      4608      
 )                                                               
                                                                 
 conv_dw_10_bn (BatchNormali  (None, 14, 14, 512)      2048      
 zation)                                                         
                                                                 
 conv_dw_10_relu (ReLU)      (None, 14, 14, 512)       0         
                                                                 
 conv_pw_10 (Conv2D)         (None, 14, 14, 512)       262144    
                                                                 
 conv_pw_10_bn (BatchNormali  (None, 14, 14, 512)      2048      
 zation)                                                         
                                                                 
 conv_pw_10_relu (ReLU)      (None, 14, 14, 512)       0         
                                                                 
 conv_dw_11 (DepthwiseConv2D  (None, 14, 14, 512)      4608      
 )                                                               
                                                                 
 conv_dw_11_bn (BatchNormali  (None, 14, 14, 512)      2048      
 zation)                                                         
                                                                 
 conv_dw_11_relu (ReLU)      (None, 14, 14, 512)       0         
                                                                 
 conv_pw_11 (Conv2D)         (None, 14, 14, 512)       262144    
                                                                 
 conv_pw_11_bn (BatchNormali  (None, 14, 14, 512)      2048      
 zation)                                                         
                                                                 
 conv_pw_11_relu (ReLU)      (None, 14, 14, 512)       0         
                                                                 
 conv_pad_12 (ZeroPadding2D)  (None, 15, 15, 512)      0         
                                                                 
 conv_dw_12 (DepthwiseConv2D  (None, 7, 7, 512)        4608      
 )                                                               
                                                                 
 conv_dw_12_bn (BatchNormali  (None, 7, 7, 512)        2048      
 zation)                                                         
                                                                 
 conv_dw_12_relu (ReLU)      (None, 7, 7, 512)         0         
                                                                 
 conv_pw_12 (Conv2D)         (None, 7, 7, 1024)        524288    
                                                                 
 conv_pw_12_bn (BatchNormali  (None, 7, 7, 1024)       4096      
 zation)                                                         
                                                                 
 conv_pw_12_relu (ReLU)      (None, 7, 7, 1024)        0         
                                                                 
 conv_dw_13 (DepthwiseConv2D  (None, 7, 7, 1024)       9216      
 )                                                               
                                                                 
 conv_dw_13_bn (BatchNormali  (None, 7, 7, 1024)       4096      
 zation)                                                         
                                                                 
 conv_dw_13_relu (ReLU)      (None, 7, 7, 1024)        0         
                                                                 
 conv_pw_13 (Conv2D)         (None, 7, 7, 1024)        1048576   
                                                                 
 conv_pw_13_bn (BatchNormali  (None, 7, 7, 1024)       4096      
 zation)                                                         
                                                                 
 conv_pw_13_relu (ReLU)      (None, 7, 7, 1024)        0         
                                                                 
 global_average_pooling2d (G  (None, 1, 1, 1024)       0         
 lobalAveragePooling2D)                                          
                                                                 
 dropout (Dropout)           (None, 1, 1, 1024)        0         
                                                                 
 conv_preds (Conv2D)         (None, 1, 1, 1000)        1025000   
                                                                 
 reshape_2 (Reshape)         (None, 1000)              0         
                                                                 
 predictions (Activation)    (None, 1000)              0         
                                                                 
=================================================================
Total params: 4,253,864
Trainable params: 4,231,976
Non-trainable params: 21,888
_________________________________________________________________
# MobileNetV2 is built for the same 224x224 RGB input as MobileNet.
mobilenetv2_model = MobileNetV2(weights='imagenet', include_top=True, input_shape=(224, 224, 3))
mobilenetv2_model.summary()

# Xception and InceptionV3 are both built for 299x299 RGB inputs.
xception_model = Xception(weights='imagenet', include_top=True, input_shape=(299, 299, 3))
xception_model.summary()

inceptionv3_model = InceptionV3(weights='imagenet', include_top=True, input_shape=(299, 299, 3))
inceptionv3_model.summary()
import glob

# Collect the paths of every JPEG test image in the extracted directory.
test_image_list = glob.glob('/content/test_image_dir/*.jpg')

# Report how many images were found, then the paths themselves.
image_count = len(test_image_list)
print('total image # => ', image_count)
print(test_image_list)
total image # =>  8
['/content/test_image_dir/chihuahua.jpg', '/content/test_image_dir/fighter_plane.jpg', '/content/test_image_dir/forget_me_not.jpg', '/content/test_image_dir/transport.jpg', '/content/test_image_dir/tulip.jpg', '/content/test_image_dir/airliner.jpg', '/content/test_image_dir/yorkshire_terrier.jpg', '/content/test_image_dir/satellite.jpg']
import cv2
import numpy as np
import matplotlib.pyplot as plt

# RGB images scaled to [0, 1]. These lists are kept for display: plt.imshow
# expects float images in the [0, 1] range.
dst_img_list_224 = []
dst_img_list_299 = []
# Ground-truth label of each image, in the same order as the image lists.
label_str_list = []

for image_path in test_image_list:
  src_img = cv2.imread(image_path, cv2.IMREAD_COLOR)

  # cv2.imread returns None instead of raising when a file cannot be read,
  # so fail here with a message that names the offending file.
  if src_img is None:
    raise ValueError('failed to read image: ' + image_path)

  # OpenCV loads images in BGR channel order; convert to RGB once up front.
  # Swapping channels commutes with resizing, which works per channel.
  rgb_img = cv2.cvtColor(src_img, cv2.COLOR_BGR2RGB)

  dst_img_224 = cv2.resize(rgb_img, dsize = (224, 224)) / 255.0 # MobileNet, MobileNetV2
  dst_img_299 = cv2.resize(rgb_img, dsize = (299, 299)) / 255.0 # InceptionV3, Xception

  # The file name without directory and extension is the expected label.
  label_str = image_path.split('/')[-1].split('.')[0].strip()

  dst_img_list_224.append(dst_img_224)
  dst_img_list_299.append(dst_img_299)

  label_str_list.append(label_str)

# The Keras preprocess_input for MobileNet, MobileNetV2, InceptionV3 and
# Xception scales pixels to [-1, 1], so map the [0, 1] display images onto
# that range for the batches that are fed to the models.
dst_img_array_224 = np.array(dst_img_list_224) * 2.0 - 1.0
dst_img_array_299 = np.array(dst_img_list_299) * 2.0 - 1.0

plt.figure(figsize=(8, 6))

# Draw the 224x224 images on a 2 x 4 grid, each titled with its label.
for idx, image in enumerate(dst_img_list_224):
  plt.subplot(2, 4, idx + 1)
  plt.axis('off')
  plt.title(label_str_list[idx])
  plt.imshow(image)

plt.tight_layout()
plt.show()

# The 224x224 batch feeds both MobileNet variants; the 299x299 batch feeds
# InceptionV3 and Xception.
mobilenet_pred = mobilenet_model.predict(dst_img_array_224)
mobilenetv2_pred = mobilenetv2_model.predict(dst_img_array_224)
inceptionv3_pred = inceptionv3_model.predict(dst_img_array_299)
xception_pred = xception_model.predict(dst_img_array_299)

# Show the shape of each model's output array.
for scores in (mobilenet_pred, mobilenetv2_pred, inceptionv3_pred, xception_pred):
  print(scores.shape)
(8, 1000)
(8, 1000)
(8, 1000)
(8, 1000)
from tensorflow.keras.applications.imagenet_utils import decode_predictions

# Turn each model's raw scores into its top-3 decoded predictions per image.
mobilenet_prediction = decode_predictions(mobilenet_pred, top=3)
mobilenetv2_prediction = decode_predictions(mobilenetv2_pred, top=3)
inceptionv3_prediction = decode_predictions(inceptionv3_pred, top=3)
xception_prediction = decode_predictions(xception_pred, top=3)

# Show the container type of each decoded result.
for decoded in (mobilenet_prediction, mobilenetv2_prediction, inceptionv3_prediction, xception_prediction):
  print(type(decoded))
Downloading data from https://storage.googleapis.com/download.tensorflow.org/data/imagenet_class_index.json
40960/35363 [==================================] - 0s 0us/step
49152/35363 [=========================================] - 0s 0us/step
<class 'list'>
<class 'list'>
<class 'list'>
<class 'list'>
print('MobileNet Predition Result')

# For every test image: a separator, its expected label, then the top-3 list.
for idx, top3 in enumerate(mobilenet_prediction):
  print('=' * 26)
  print('label', label_str_list[idx])
  print(top3)
MobileNet Predition Result
==========================
label chihuahua
[('n02086910', 'papillon', 0.7842499), ('n02085782', 'Japanese_spaniel', 0.09525943), ('n02085620', 'Chihuahua', 0.076303236)]
==========================
label fighter_plane
[('n04552348', 'warplane', 0.9673468), ('n04008634', 'projectile', 0.009557659), ('n02687172', 'aircraft_carrier', 0.007976011)]
==========================
label forget_me_not
[('n03476684', 'hair_slide', 0.21029857), ('n11939491', 'daisy', 0.1792231), ('n02219486', 'ant', 0.15513411)]
==========================
label transport
[('n03773504', 'missile', 0.42889836), ('n04552348', 'warplane', 0.4233392), ('n04266014', 'space_shuttle', 0.069459476)]
==========================
label tulip
[('n12620546', 'hip', 0.39781347), ('n12057211', "yellow_lady's_slipper", 0.14410336), ('n03942813', 'ping-pong_ball', 0.044616114)]
==========================
label airliner
[('n02690373', 'airliner', 0.97014165), ('n04592741', 'wing', 0.029005343), ('n04552348', 'warplane', 0.0005167596)]
==========================
label yorkshire_terrier
[('n02094433', 'Yorkshire_terrier', 0.8546715), ('n02087046', 'toy_terrier', 0.044203714), ('n02112706', 'Brabancon_griffon', 0.042138748)]
==========================
label satellite
[('n04286575', 'spotlight', 0.30719888), ('n04258138', 'solar_dish', 0.10327921), ('n04266014', 'space_shuttle', 0.08793576)]
print('MobileNetV2 Predition Result')

# For every test image: a separator, its expected label, then the top-3 list.
for idx, top3 in enumerate(mobilenetv2_prediction):
  print('=' * 26)
  print('label', label_str_list[idx])
  print(top3)
print('Xception Predition Result')

# For every test image: a separator, its expected label, then the top-3 list.
for idx, top3 in enumerate(xception_prediction):
  print('=' * 26)
  print('label', label_str_list[idx])
  print(top3)
print('InceptionV3 Predition Result')

# For every test image: a separator, its expected label, then the top-3 list.
for idx, top3 in enumerate(inceptionv3_prediction):
  print('=' * 26)
  print('label', label_str_list[idx])
  print(top3)
728x90
반응형
LIST