본문 바로가기
Visual Intelligence/Image Segmentation

[Image Segmentation] Custom Cost (Loss) Function

by goatlab 2022. 12. 15.
728x90
반응형
SMALL

Focal Loss

 

Focal loss는 one-stage object detection에서 object와 background의 클래스간 불균형이 극도로 심한 상황을 해결하기 위해 제안되었다.

 

 

Focal Loss 작동 원리는 Focusing parameter에서 r은 일반적으로 0 ~ 5 사이의 값이다.

 

 

즉, 잘못 분류된 examples의 중요도를 상대적으로 높이는 역할을 한다. 감마 (gamma)값이 커질수록 가중치 규제가 강하게 들어간다.

 

ALPHA = 0.8
GAMMA = 2

def FocalLoss(targets, inputs, alpha=ALPHA, gamma=GAMMA):
  inputs = K.flatten(inputs)
  targets = K.flatten(targets)
  
  BCE = K.binary_crossentropy(targets, inputs)
  BCE_EXP = K.exp(-BCE)  
  focal_loss = K.mean(alpha * K.pow((1-BCE_EXP), gamma) * BCE)
  
  return focal_loss

 

Dice coefficient Loss

 

Dice loss는 이진 분류 혹은 Segmentation에 자주 사용되는 손실 함수로 정답과 예측값의 유사도를 비교하는 함수이다. 라벨링된 영역과 예측한 영역이 정확히 같다면 1, 아니면 0을 반환한다.

 

# Keras
def DiceLoss(targets, inputs, smooth=1e-6):
  # flatten label and prediction tensors
  inputs = K.flatten(inputs)
  targets = K.flatten(targets)
  
  intersection = K.sum(K.dot(targets, inputs))
  dice = (2*intersection + smooth) / (K.sum(targets) + K.sum(inputs) + smooth)
  
  return 1 - dice

 

BCE-Dice Loss

 

BCE-Dice loss는 Binary cross entropy 함수와 dice coefficient 함수를 섞어서 사용하는 손실 함수이다. 일반적으로 이미지 분할에 손실 함수로 많이 사용된다. 두 가지의 손실 함수를 가지고 손실을 평가해 더욱 효율적인 학습을 진행할 수 있다는 장점이 있다.

 

def DiceBCELoss(targets, inputs, smooth=1e-6):
  # flatten label and prediction tensors
  inputs = K.flatten(inputs)
  targets = K.flatten(targets)

  BCE = binary_crossentropy(targets, inputs)
  intersection = K.sum(K.dot(targets, inputs))
  dice_loss = 1 - (2*intersection + smooth) / (K.sum(targets) + K.sum(inputs) + smooth)
  Dice_BCE = BCE + dice_loss
  
  return Dice_BCE

 

Jaccard / Intersection over Union (IoU) Loss

 

IoU 손실 함수는 정답과 예측의 겹치는 부분의 비율을 측정하는 방법으로 객체 검출 및 이미지 분할에 자주 사용된다. 수식은 아래와 같다.

 

 

아래 그림은 위의 수식을 도식화한 것으로 관심 영역이 어디인지 정확하게 알 수 있다.

 

 

다음 그림은 IoU 점수별 정답과 예측치의 겹치는 정도를 보여준다.

 

def IoULoss(targets, inputs, smooth=1e-6):
  # flatten label and prediction tensors
  inputs = K.flatten(inputs)
  targets = K.flatten(targets)
  
  intersection = K.sum(K.dot(targets, inputs))
  total = K.sum(targets) + K.sum(inputs)
  union = total - intersection
  
  IoU = (intersection + smooth) / (union + smooth)
  
  return 1 - IoU

 

Tversky Loss

 

Tversky loss는 ‘Tversky loss function for image segmentation using 3D fully convolutional deep networks’에 소개된 손실 함수이다. 이 손실 함수는 의료 영상 분할의 클래스 불균형을 해소하기 위해 제안되었다. 아래 코드는 논문에서 제안한 하이퍼파라미터를 그대로 사용하였다.

 

Tversky 손실 함수는 alpha와 beta를 통해 가중치 규제가 이루어지며 false positive와 false negative에 더욱 큰 가중치를 준다. Alpha 와 beta 값은 본인이 구현하고자 하는 바에 따라 적절하게 최적의 값을 찾아야 한다.

 

ALPHA = 0.5
BETA = 0.5

def TverskyLoss(targets, inputs, alpha=ALPHA, beta=BETA, smooth=1e-6):
  # flatten label and prediction tensors
  inputs = K.flatten(inputs)
  targets = K.flatten(targets)
  
  # True Positives, False Positives & False Negatives
  TP = K.sum((inputs * targets))
  FP = K.sum(((1-targets) * inputs))
  FN = K.sum((targets * (1-inputs)))

  Tversky = (TP + smooth) / (TP + alpha*FP + beta*FN + smooth)
  
  return 1 - Tversky

 

Focal Tversky Loss

 

Focal Tversky 손실 함수는 위에서 설명한 focal 손실 함수와 Tversky 손실 함수를 결합한 손실 함수로 Tversky 손실 함수에 Focal 손실 함수의 gamma값 조정 능력을 추가한 형태이다.

 

ALPHA = 0.5
BETA = 0.5
GAMMA = 1

def FocalTverskyLoss(targets, inputs, alpha=ALPHA, beta=BETA, gamma=GAMMA, smooth=1e-6):
  # flatten label and prediction tensors
  inputs = K.flatten(inputs)
  targets = K.flatten(targets)
  
  # True Positives, False Positives & False Negatives
  TP = K.sum((inputs * targets))
  FP = K.sum(((1-targets) * inputs))
  FN = K.sum((targets * (1-inputs)))
  
  Tversky = (TP + smooth) / (TP + alpha*FP + beta*FN + smooth)
  FocalTversky = K.pow((1 - Tversky), gamma)
  
  return FocalTversky

 

Combo Loss

 

Combo 손실 함수는 ‘Combo loss: Handling input and outputimbalance in multi-organ segmentation’에 소개된 손실 함수이다. Combo 손실 함수는 Dice 함수와 변형된 Cross-Entropy 함수가 결합된 형태이다. Combo 손실 함수는 앞선 Tversky 손실 함수와 유사하게 false positive와 false negative에 더욱 규제를 가하는 조건을 가지고 있다.

 

ALPHA = 0.5 # < 0.5 penalises FP more, > 0.5 penalises FN more
CE_RATIO = 0.5 # weighted contribution of modified CE loss compared to Dice loss

def Combo_loss(targets, inputs):
  targets = K.flatten(targets)
  inputs = K.flatten(inputs)
  
  intersection = K.sum(targets * inputs)
  dice = (2. * intersection + smooth) / (K.sum(targets) + K.sum(inputs) + smooth)
  inputs = K.clip(inputs, e, 1.0 - e)
  out = - (ALPHA * ((targets * K.log(inputs)) + ((1 - ALPHA) * (1.0 - targets) * K.log(1.0 - inputs))))

  weighted_ce = K.mean(out, axis=-1)
  combo = (CE_RATIO * weighted_ce) - ((1 - CE_RATIO) * dice)
  
  return combo
728x90
반응형
LIST