
[Visual Intelligence] Coffee Classification

by goatlab 2022. 8. 28.

Coffee Classification

 

There are many kinds of coffee: espresso, americano, cappuccino, caffè latte, caffè mocha, and more.

 

Images are collected by crawling Google Image Search, and the model classifies three of them: americano, latte, and mocha.
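The crawling code itself is not shown in the post. One possible approach is the third-party icrawler package; the keywords, counts, and directory layout below are illustrative assumptions:

# a minimal crawling sketch, assuming icrawler (pip install icrawler)
from icrawler.builtin import GoogleImageCrawler

for keyword in ['americano coffee', 'latte coffee', 'mocha coffee']:
    label = keyword.split()[0]  # 'americano', 'latte', 'mocha'
    crawler = GoogleImageCrawler(storage={'root_dir': f'coffee/dataset/{label}'})
    crawler.crawl(keyword=keyword, max_num=400)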

 

import os
import glob
import math
import shutil
import numpy as np
import tensorflow as tf
from tensorflow.keras.layers import Dropout, Dense, GlobalAveragePooling2D
from tensorflow.keras.applications import MobileNet 
from tensorflow.keras.models import Sequential
from tensorflow.keras.preprocessing.image import ImageDataGenerator

ROOT_DIR = 'C:/Users/'

DATA_ROOT_DIR = os.path.join(ROOT_DIR, 'coffee')

TRAIN_DATA_ROOT_DIR = os.path.join(DATA_ROOT_DIR, 'train')

TEST_DATA_ROOT_DIR = os.path.join(DATA_ROOT_DIR, 'test')

 

Checking the Data

 

total_file_list = glob.glob(os.path.join(DATA_ROOT_DIR, 'dataset', '*/'))

print(total_file_list)

# extract the label list (each subdirectory name is a label)
label_name_list = [os.path.basename(os.path.normpath(file_name))
                   for file_name in total_file_list if os.path.isdir(file_name)]

print(label_name_list)
['americano', 'latte', 'mocha']
for label_name in label_name_list:
    src_dir_path = os.path.join(DATA_ROOT_DIR, 'dataset', label_name)
    dst_dir_path = os.path.join(DATA_ROOT_DIR, 'train', label_name)

    try:
        shutil.copytree(src_dir_path, dst_dir_path)
        print(label_name+' copytree is done.')

    except Exception as err:
        print(str(err))
americano copytree is done.
latte copytree is done.
mocha copytree is done.

# check the train labels and per-class file counts
train_label_name_list = os.listdir(TRAIN_DATA_ROOT_DIR)

print(train_label_name_list)

for label_name in train_label_name_list:
    print('train label : ', label_name,' => ', len(os.listdir(os.path.join(TRAIN_DATA_ROOT_DIR, label_name))))

print('=====================================================')
['americano', 'latte', 'mocha']
train label :  americano  =>  381
train label :  latte  =>  400
train label :  mocha  =>  304
=====================================================

 

데이터 전처리

 

# create the test directory
if not os.path.exists(TEST_DATA_ROOT_DIR):
    os.mkdir(TEST_DATA_ROOT_DIR)
    print(TEST_DATA_ROOT_DIR + ' is created.')
    
else:
    print(TEST_DATA_ROOT_DIR + ' already exists')

# create one label subdirectory per class under the test directory
for label_name in label_name_list:
    if not os.path.exists(os.path.join(TEST_DATA_ROOT_DIR, label_name)):

        os.mkdir(os.path.join(TEST_DATA_ROOT_DIR, label_name))
        print(os.path.join(TEST_DATA_ROOT_DIR, label_name) + ' is created.')

    else:
        print(os.path.join(TEST_DATA_ROOT_DIR, label_name) + ' already exists.')
import random

# fraction of files to move
MOVE_RATIO = 0.2  # train : test = 80 : 20; 20% of the train data becomes test data

# move files from the train directory to the test directory
label_name_list = os.listdir(TRAIN_DATA_ROOT_DIR)

for label_name in label_name_list:
    # source and destination directories for the move
    src_dir_path = os.path.join(TRAIN_DATA_ROOT_DIR,label_name)  
    dst_dir_path = os.path.join(TEST_DATA_ROOT_DIR,label_name)  

    train_data_file_list = os.listdir(src_dir_path)

    print('========================================================================')
    print('total [%s] data file nums => [%s]' % (label_name ,len(train_data_file_list)))

    # data shuffle
    random.shuffle(train_data_file_list)
    print('train data shuffle is done.')

    split_num = int(MOVE_RATIO*len(train_data_file_list))

    print('split nums => ', split_num)

    # extract test data from train data
    test_data_file_list = train_data_file_list[0:split_num]

    move_nums = 0

    for test_data_file in test_data_file_list:
        try:
            shutil.move(os.path.join(src_dir_path, test_data_file),
                        os.path.join(dst_dir_path, test_data_file))
            move_nums = move_nums + 1  # count only successful moves
        except Exception as err:
            print(str(err))

    print('total move nums => ', move_nums)
    print('========================================================================')
# check the train file counts
label_name_list = os.listdir(TRAIN_DATA_ROOT_DIR)

print(label_name_list)

for label_name in label_name_list:
    label_dir = os.path.join(TRAIN_DATA_ROOT_DIR, label_name)

    print('train label : ' + label_name + ' => ', len(os.listdir(label_dir)))

print('=====================================================')

# check the test file counts
label_name_list = os.listdir(TEST_DATA_ROOT_DIR)

print(label_name_list)

for label_name in label_name_list:
    label_dir = os.path.join(TEST_DATA_ROOT_DIR, label_name)

    print('test label : ' + label_name + ' => ', len(os.listdir(label_dir)))

print('=====================================================')
['latte', 'mocha', 'americano']
train label : latte =>  320
train label : mocha =>  244
train label : americano =>  305
=====================================================
['latte', 'mocha', 'americano']
test label : latte =>  80
test label : mocha =>  60
test label : americano =>  76
=====================================================
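The counts above match the 80:20 split exactly; for instance, int(0.2 * 381) = 76 americano files were moved to test, leaving 305 for training. A quick arithmetic check:

# sanity check: moved test files per class equal int(MOVE_RATIO * total)
for total, n_train, n_test in [(381, 305, 76), (400, 320, 80), (304, 244, 60)]:
    assert n_test == int(0.2 * total)
    assert n_train == total - n_test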
IMG_WIDTH = 224
IMG_HEIGHT = 224

train_datagen = ImageDataGenerator(rescale=1./255, validation_split=0.2) 

validation_data_gen = ImageDataGenerator(rescale=1./255, validation_split=0.2)

test_datagen = ImageDataGenerator(rescale=1./255)

train_generator = train_datagen.flow_from_directory(TRAIN_DATA_ROOT_DIR, batch_size=32,
                                                color_mode='rgb', class_mode='sparse',
                                                target_size=(IMG_WIDTH,IMG_HEIGHT),
                                                subset='training')

validation_generator = validation_data_gen.flow_from_directory(TRAIN_DATA_ROOT_DIR, batch_size=32,
                                                color_mode='rgb', class_mode='sparse',
                                                target_size=(IMG_WIDTH,IMG_HEIGHT),
                                                subset='validation')

test_generator = test_datagen.flow_from_directory(TEST_DATA_ROOT_DIR, batch_size=32, 
                                              color_mode='rgb', class_mode='sparse',
                                              target_size=(IMG_WIDTH,IMG_HEIGHT))
Found 696 images belonging to 3 classes.
Found 173 images belonging to 3 classes.
Found 216 images belonging to 3 classes.
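Because class_mode='sparse', the generators yield integer class indices rather than one-hot vectors, which is why the model is later compiled with sparse_categorical_crossentropy. A quick way to confirm the batch format:

# inspect one batch: images are (32, 224, 224, 3), labels are a flat vector of class indices
x_batch, y_batch = next(train_generator)
print(x_batch.shape, y_batch.shape)  # (32, 224, 224, 3) (32,)
print(y_batch[:5])                   # e.g. [0. 2. 1. 1. 0.]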
print(train_generator.class_indices)
print(train_generator.num_classes)
{'americano': 0, 'latte': 1, 'mocha': 2}
3
import matplotlib.pyplot as plt

IMG_NUMS = 16

image_data, image_label = train_generator.next()

data = image_data[:IMG_NUMS]

label = image_label[:IMG_NUMS]

class_dict = {0:'americano', 1:'latte', 2:'mocha'}

plt.figure(figsize=(9,9))

for i in range(len(label)):
    plt.subplot(4, 4, i+1)
    plt.title(class_dict[int(label[i])])  # labels come out as floats; cast to int for the lookup
    plt.xticks([]);  plt.yticks([])

    plt.imshow(data[i])

plt.tight_layout()
plt.show()

 

Building the Model

 


# MobileNet backbone pre-trained on ImageNet, with its classification head removed
pre_trained_model = MobileNet(weights='imagenet', include_top=False, input_shape=(IMG_WIDTH,IMG_HEIGHT,3))

class_nums = train_generator.num_classes 

model = Sequential()

model.add(pre_trained_model)

model.add(GlobalAveragePooling2D())

model.add(Dense(64,activation='relu'))
model.add(Dropout(0.5))
model.add(Dense(class_nums, activation='softmax'))

model.summary()
Model: "sequential"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
=================================================================
 mobilenet_1.00_224 (Functio  (None, 7, 7, 1024)       3228864   
 nal)                                                            
                                                                 
 global_average_pooling2d (G  (None, 1024)             0         
 lobalAveragePooling2D)                                          
                                                                 
 dense (Dense)               (None, 64)                65600     
                                                                 
 dropout (Dropout)           (None, 64)                0         
                                                                 
 dense_1 (Dense)             (None, 3)                 195       
                                                                 
=================================================================
Total params: 3,294,659
Trainable params: 3,272,771
Non-trainable params: 21,888
_________________________________________________________________
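Note that the whole MobileNet backbone stays trainable here, which is why a very small learning rate (2e-5) is used below. A common alternative, sketched here as an option rather than what the post does, is to freeze the backbone and train only the new head:

# optional: freeze the pre-trained backbone so only the new head is trained
pre_trained_model.trainable = False

# recompile after changing trainable; a larger learning rate is fine for the head alone
model.compile(loss='sparse_categorical_crossentropy',
              optimizer=tf.keras.optimizers.Adam(1e-3), metrics=['acc'])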
from tensorflow.keras.callbacks import ModelCheckpoint, EarlyStopping
from datetime import datetime

model.compile(loss='sparse_categorical_crossentropy', 
            optimizer=tf.keras.optimizers.Adam(2e-5), metrics=['acc'])

save_file_name = './coffee.h5'

checkpoint = ModelCheckpoint(save_file_name,       # path of the saved model file
                             monitor='val_loss',   # triggered when val_loss improves
                             verbose=1,            # print a log message on save
                             save_best_only=True,  # keep only the best weights
                             mode='auto'           # infer min/max automatically
                            )

earlystopping = EarlyStopping(monitor='val_loss',  # quantity to monitor
                              patience=5,          # stop after 5 epochs without improvement
                              verbose=1
                             )

start_time = datetime.now()

hist = model.fit(train_generator, epochs=20, validation_data=validation_generator,
                 callbacks=[checkpoint, earlystopping])  # pass the callbacks defined above

end_time = datetime.now()

print('Elapsed Time => ', end_time-start_time)
plt.plot(hist.history['acc'], label='train')
plt.plot(hist.history['val_acc'], label='validation')
plt.title('Accuracy Trend')
plt.ylabel('accuracy')
plt.xlabel('epoch')
plt.legend(loc='best')
plt.grid()
plt.show()

plt.plot(hist.history['loss'], label='train')
plt.plot(hist.history['val_loss'], label='validation')
plt.title('Loss Trend')
plt.ylabel('loss')
plt.xlabel('epoch')
plt.legend(loc='best')
plt.grid()
plt.show()

 

Model Evaluation

 

model.evaluate(test_generator)
[0.5005858540534973, 0.8101851940155029]

The model reaches a test loss of about 0.50 and a test accuracy of about 81% on the three coffee classes.
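To run the trained classifier on a new photo, the best checkpoint saved by ModelCheckpoint can be reloaded and applied to a single image. A minimal sketch; 'sample.jpg' is a placeholder path, not a file from the post:

# reload the best weights saved during training
best_model = tf.keras.models.load_model('./coffee.h5')

# 'sample.jpg' is a hypothetical image path; use the same size and rescaling as the generators
img = tf.keras.preprocessing.image.load_img('sample.jpg', target_size=(IMG_WIDTH, IMG_HEIGHT))
x = np.expand_dims(tf.keras.preprocessing.image.img_to_array(img) / 255.0, axis=0)

probs = best_model.predict(x)[0]
print(class_dict[int(np.argmax(probs))], probs)  # class_dict maps indices back to names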