본문 바로가기
Visual Intelligence/Image Segmentation

[Image Segmentation] U-Net (1)

by goatlab 2022. 12. 15.
728x90
반응형
SMALL

U-Net

 

U-Net은 의료영상 분할 분야에서 처음 제안된 네트워크로 End-to-End 방식이다. U-Net은 FCN (Fully Convolutional Network)을 기반으로 구성되어 있으며 의료 이미지 특상 적은 데이터를 가지고도 더욱 정확한 Segmentation을 하기 위해 일부 수정을 했다.

 

 

각각의 파란색 박스는 멀티 채널 피쳐맵을 의미한다. 채널 수는 박스위에 표시되어 있다. U-Net은 가운데를 기준으로 왼쪽을 Contracting path, 오른쪽을 Expanding path라 부른다. Contracting path는 이미지의 문맥 (context)를 추출할 수 있도록 도와주는 역할을 한다. Expanding path는 피쳐맵을 업 샘플링하고 이를 contracting path에서 추출한 피쳐맵과 결합하여 더욱 정확한 Localization을 수행한다.

 

 

U-Net은 3 × 3 컨볼루션이 주를 이루고 있으며 각 블록은 2개의 3 × 3 컨볼루션 레이어로 이루어져 있다. Contracting path는 4개의 블록으로 이루어져 있으며 각 블록은 Maxpool을 이용해 사이즈를 줄이면서 다음 블록으로 내려간다. 반면, Expanding path는 컨볼루션 블록에 up-conv 레이어를 붙인 형태이다. 즉, contracting 과정에서 줄어든 피쳐맵의 사이즈를 다시 키워가는 과정이다. 또한, expanding 과정에서 얻어진 피쳐맵과 contracting 과정에서 얻어진 피쳐맵을 concatenate하여 사용한다. 마지막 레이어는 1 × 1 컨볼루션 연산을 사용해 피쳐들을 정리하고 최종 output으로 매핑하게 된다.

 

U-Net에서는 인코더-디코더 구조에 스킵 커넥션 (skip-connection)을 추가했다. 영상 크기를 줄였다가 다시 키우면서 정교한 픽셀 정보가 사라지게 된다. 이는 픽셀 단위로 정밀한 예측이 필요한 Segmentation에서 치명적인 문제에 해당한다. 이를 해결하기 위해 인코더에서 추출한 중요한 정보를 디코더에 바로 넘겨주기 위해 스킵 커넥션을 사용한 것이다. 이를 통해 디코더 부분에서 더욱 선명한 이미지 결과를 얻게 되어 기존의 네트워크들보다 정확한 예측이 가능해졌다.

 

라이브러리

 

import numpy as np
import matplotlib.pyplot as plt
import os
from PIL import Image
import keras
from keras.models import Model
from keras.layers import Conv2D, MaxPooling2D, Input, Conv2DTranspose, Concatenate, BatchNormalization, UpSampling2D
from keras.layers import Dropout, Activation
from keras.optimizers import Adam, SGD
from keras.layers import ELU, PReLU, LeakyReLU
from keras.callbacks import ModelCheckpoint, ReduceLROnPlateau, EarlyStopping
from keras import backend as K
from keras.utils import plot_model
import tensorflow as tf
import glob
import random
import cv2
from random import shuffle
import math

!wget http://www.robots.ox.ac.uk/~vgg/data/pets/data/images.tar.gz
!wget http://www.robots.ox.ac.uk/~vgg/data/pets/data/annotations.tar.gz
!tar -xvzf images.tar.gz && tar -xvzf annotations.tar.gz
!rm images/*.mat

 

데이터 전처리

 

def image_generator(files, batch_size = 32, sz = (256, 256)):
  while True:
    # extract a random batch
    batch = np.random.choice(files, size = batch_size)
    
    # variables for collecting batches of inputs and outputs
    batch_x = []
    batch_y = []
    
    for f in batch:
      # get the masks. Note that masks are png files
      mask = Image.open(f'annotations/trimaps/{f[:-4]}.png')
      mask = np.array(mask.resize(sz))
      
      # preprocess the mask
      mask[mask >= 2] = 0
      mask[mask != 0 ] = 1
      
      batch_y.append(mask)
      
      # preprocess the raw images
      raw = Image.open(f'images/{f}')
      raw = raw.resize(sz)
      raw = np.array(raw)
      # check the number of channels because some of the images are RGBA or GRAY
      
      if len(raw.shape) == 2:
        raw = np.stack((raw,)*3, axis=-1)
        
      else:
        raw = raw[:,:,0:3]
      
      batch_x.append(raw)
      
    # preprocess a batch of images and masks
    batch_x = np.array(batch_x)/255.
    batch_y = np.array(batch_y)
    batch_y = np.expand_dims(batch_y,3)

    yield (batch_x, batch_y)

batch_size = 32

all_files = os.listdir('images')
shuffle(all_files)
split = int(0.95 * len(all_files))

# split into training and testing
train_files = all_files[0:split]
test_files = all_files[split:]
train_generator = image_generator(train_files, batch_size = batch_size)
test_generator = image_generator(test_files, batch_size = batch_size)
x, y= next(train_generator)
plt.axis('off')
img = x[0]
msk = y[0].squeeze()
msk = np.stack((msk,)*3, axis=-1)
plt.imshow(np.concatenate([img, msk, img*msk], axis = 1))

 

모델의 성능을 측정하기 위한 지표로 IoU (Intersection of Union) 메트릭을 구현해 본다.

 

def mean_iou(y_true, y_pred):
  yt0 = y_true[:,:,:,0]
  yp0 = K.cast(y_pred[:,:,:,0] > 0.5, 'float32')
  inter = tf.math.count_nonzero(tf.logical_and(tf.equal(yt0, 1), tf.equal(yp0, 1)))
  union = tf.math.count_nonzero(tf.add(yt0, yp0))
  iou = tf.where(tf.equal(union, 0), 1., tf.cast(inter/union, 'float32'))
  
  return iou

 

U-Net 모델

 

def unet(sz = (256, 256, 3)):
  x = Input(sz)
  inputs = x
  
  #down sampling 
  f = 8
  layers = []
  
  for i in range(0, 6):
    x = Conv2D(f, 3, activation='relu', padding='same') (x)
    x = Conv2D(f, 3, activation='relu', padding='same') (x)
    layers.append(x)
    x = MaxPooling2D() (x)
    f = f*2
  ff2 = 64 
  
  #bottleneck 
  j = len(layers) - 1
  x = Conv2D(f, 3, activation='tanh', padding='same') (x)
  x = Conv2D(f, 3, activation='tanh', padding='same') (x)
  x = Conv2DTranspose(ff2, 2, strides=(2, 2), padding='same') (x)
  x = Concatenate(axis=3)([x, layers[j]])
  j = j -1 
  
  #upsampling 
  for i in range(0, 5):
    ff2 = ff2//2
    f = f // 2 
    x = Conv2D(f, 3, activation='tanh', padding='same') (x)
    x = Conv2D(f, 3, activation='tanh', padding='same') (x)
    x = Conv2DTranspose(ff2, 2, strides=(2, 2), padding='same') (x)
    x = Concatenate(axis=3)([x, layers[j]])
    j = j -1 
    
  
  #classification 
  x = Conv2D(f, 3, activation='tanh', padding='same') (x)
  x = Conv2D(f, 3, activation='tanh', padding='same') (x)
  outputs = Conv2D(1, 1, activation='sigmoid') (x)
  
  #model creation 
  model = Model(inputs=[inputs], outputs=[outputs])
  model.compile(optimizer = 'rmsprop', loss = 'binary_crossentropy', metrics = [mean_iou])
  
  return model

model = unet()

 

모델 정의가 완료되었으면 학습 중간중간 모델을 저장하고 학습 진행상황을 확인하기 위한 callback 함수를 구현한다.

 

def build_callbacks():
  checkpointer = ModelCheckpoint(filepath='unet.h5', verbose=0, save_best_only=True, save_weights_only=True)
  callbacks = [checkpointer, PlotLearning()]
  
  return callbacks
  
# inheritance for training process plot
class PlotLearning(keras.callbacks.Callback):
  def on_train_begin(self, logs={}):
    self.i = 0
    self.x = []
    self.losses = []
    self.val_losses = []
    self.acc = []
    self.val_acc = []
    
    # self.fig = plt.figure()
    self.logs = []
  
  def on_epoch_end(self, epoch, logs={}):
    self.logs.append(logs)
    self.x.append(self.i)
    self.losses.append(logs.get('loss'))
    self.val_losses.append(logs.get('val_loss'))
    self.acc.append(logs.get('mean_iou'))
    self.val_acc.append(logs.get('val_mean_iou'))
    self.i += 1
    
    print('i=',self.i,'loss=',logs.get('loss'),'val_loss=',logs.get('val_loss'),'mean_iou=',logs.get('mean_iou'),'val_mean_iou=',logs.get('val_mean_iou'))
    
    # choose a random test image and preprocess
    path = np.random.choice(test_files)
    raw = Image.open(f'images/{path}')
    raw = np.array(raw.resize((256, 256)))/255.
    raw = raw[:,:,0:3]
    
    # predict the mask
    pred = model.predict(np.expand_dims(raw, 0))
    
    # mask post-processing
    msk = pred.squeeze()
    msk = np.stack((msk,)*3, axis=-1)
    msk[msk >= 0.5] = 1
    msk[msk < 0.5] = 0
    
    # show the mask and the segmented image
    combined = np.concatenate([raw, msk, raw* msk], axis = 1)
    plt.axis('off')
    plt.imshow(combined)
    plt.show()

 

모델 학습

 

train_steps = len(train_files) //batch_size
test_steps = len(test_files) //batch_size

model.fit_generator(train_generator, epochs = 20, steps_per_epoch = train_steps,validation_data = test_generator, validation_steps = test_steps, callbacks = build_callbacks(), verbose = 1)
!wget http://r.ddmcdn.com/s_f/o_1/cx_462/cy_245/cw_1349/ch_1349/w_720/APL/uploads/2015/06/caturday-shutterstock_149320799.jpg -O test.jpg

raw = Image.open('test.jpg')
raw = np.array(raw.resize((256, 256)))/255.
raw = raw[:,:,0:3]

# predict the mask
pred = model.predict(np.expand_dims(raw, 0))

# mask post-processing
msk = pred.squeeze()
msk = np.stack((msk,)*3, axis=-1)
msk[msk >= 0.5] = 1
msk[msk < 0.5] = 0

# show the mask and the segmented image
combined = np.concatenate([raw, msk, raw* msk], axis = 1)
plt.axis('off')
plt.imshow(combined)
plt.show()

 

728x90
반응형
LIST