본문 바로가기
Visual Intelligence/Image Deep Learning

[시각 지능] 사전 학습된 CIFAR-10 모델로 이미지 예측

by goatlab 2022. 8. 13.
728x90
반응형
SMALL

사전 학습된 CIFAR-10 모델로 이미지 예측

 

cifar10_accuracy_80.h5
7.91MB

import numpy as np
import tensorflow as tf
from tensorflow.keras.datasets import cifar10
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Conv2D, MaxPool2D
from tensorflow.keras.layers import Flatten, Dense, Dropout
import matplotlib.pyplot as plt

# Load the pre-trained CIFAR-10 model from the local .h5 file.
# On failure the error message is printed and `cnn` is left unassigned.
try:
    cnn = tf.keras.models.load_model('./cifar10_accuracy_80.h5')
except Exception as load_error:
    print(str(load_error))
else:
    print('pre-trained model is loaded.')

 

예측에 사용할 테스트 이미지는 캐글의 CIFAR-10 데이터셋에서 내려받거나 웹 크롤링으로 수집한다.

 

https://www.kaggle.com/datasets?search=cifar10 

 

Find Open Datasets and Machine Learning Projects | Kaggle

Download Open Datasets on 1000s of Projects + Share Projects on One Platform. Explore Popular Topics Like Government, Sports, Medicine, Fintech, Food, More. Flexible Data Ingestion.

www.kaggle.com

import cv2


def _read_color_image(path):
    """Read *path* as a 3-channel BGR image.

    Raises:
        FileNotFoundError: if OpenCV could not load the file.
    """
    image = cv2.imread(path, cv2.IMREAD_COLOR)
    # cv2.imread signals a missing/unreadable file by returning None instead
    # of raising, so fail here with the offending path rather than later
    # inside cvtColor.
    if image is None:
        raise FileNotFoundError('cannot read image: {}'.format(path))
    return image


# Load the four sample images (OpenCV returns them in BGR channel order).
src_img1 = _read_color_image('./cat_1.jpg')
src_img2 = _read_color_image('./dog_2.jpg')
src_img3 = _read_color_image('./bird_3.png')
src_img4 = _read_color_image('./airplane_4.jpg')

# Convert channel order from BGR to RGB.
dst_img1 = cv2.cvtColor(src_img1, cv2.COLOR_BGR2RGB)
dst_img2 = cv2.cvtColor(src_img2, cv2.COLOR_BGR2RGB)
dst_img3 = cv2.cvtColor(src_img3, cv2.COLOR_BGR2RGB)
dst_img4 = cv2.cvtColor(src_img4, cv2.COLOR_BGR2RGB)

print(type(src_img1), src_img1.shape, type(dst_img1), dst_img1.shape)
<class 'numpy.ndarray'> (32, 32, 3) <class 'numpy.ndarray'> (32, 32, 3)
import glob
import os

# glob returns matches in arbitrary order; sort so the listing (and the
# labels derived from it) is reproducible between runs.
test_image_data_list = sorted(glob.glob('test_image/*'))

print(test_image_data_list)

# Extract the ground-truth label from each file name: the text before the
# first '.' of the base name. os.path.basename is used instead of splitting
# on '/' so that paths using the platform's own separator also work.
label_list = [
    os.path.basename(image_path).split('.')[0].strip()
    for image_path in test_image_data_list
]

print(label_list)
import matplotlib.pyplot as plt

# Show the four RGB-converted images side by side in a 1x4 grid.
plt.figure(figsize=(8,8))

for position, image in enumerate((dst_img1, dst_img2, dst_img3, dst_img4), start=1):
    plt.subplot(1, 4, position)
    plt.imshow(image)

plt.tight_layout()
plt.show()

# Resize the RGB images to 32x32 to match the training data size.
dst_img1, dst_img2, dst_img3, dst_img4 = [
    cv2.resize(image, dsize=(32,32))
    for image in (dst_img1, dst_img2, dst_img3, dst_img4)
]

# Normalize by dividing every pixel value by 255.0.
src_img1, src_img2, src_img3, src_img4 = [
    image / 255.0 for image in (src_img1, src_img2, src_img3, src_img4)
]
dst_img1, dst_img2, dst_img3, dst_img4 = [
    image / 255.0 for image in (dst_img1, dst_img2, dst_img3, dst_img4)
]

# Print the source/converted shape pair for each image.
for src, dst in zip(
    (src_img1, src_img2, src_img3, src_img4),
    (dst_img1, dst_img2, dst_img3, dst_img4),
):
    print(src.shape, dst.shape)
(32, 32, 3) (32, 32, 3)
(32, 32, 3) (32, 32, 3)
(32, 32, 3) (32, 32, 3)
(32, 32, 3) (32, 32, 3)
# Display the resized, normalized images with matplotlib in a 1x4 grid.
plt.figure(figsize=(8,8))

for position, image in enumerate((dst_img1, dst_img2, dst_img3, dst_img4), start=1):
    plt.subplot(1, 4, position)
    plt.imshow(image)

plt.tight_layout()
plt.show()
# Prediction input: collect the four preprocessed images and stack them
# into a single array to pass to the model.
test_image_list = [dst_img1, dst_img2, dst_img3, dst_img4]

test_image_array = np.array(test_image_list)

print(test_image_array.shape)
# Class names used to turn a predicted index into a readable label.
class_names = [
    'airplane', 'automobile', 'bird', 'cat', 'deer',
    'dog', 'frog', 'horse', 'ship', 'truck',
]

# Run the loaded model on the image batch; one row of scores per image.
pred = cnn.predict(test_image_array)

print(pred.shape)
(4, 10)
# Print the highest-scoring class and its score for every image.
for scores in pred:
    class_index = np.argmax(scores)
    print('prediction => ',class_names[class_index], scores.max())
# Top-3 predictions per image.
top3 = 3

for index, scores in enumerate(pred):
    # argsort is ascending, so reverse it to get class indices by descending score
    sorted_index = scores.argsort()[::-1]

    print('=====================================')
    print(sorted_index, ', label = ', label_list[index])

    for class_index in sorted_index[:top3]:
        pred_val = scores[class_index]
        print('prediction => ',class_names[class_index], pred_val)
728x90
반응형
LIST