Focal Loss
Focal loss는 one-stage object detection에서 object와 background의 클래스간 불균형이 극도로 심한 상황을 해결하기 위해 제안되었다.
Focal Loss 작동 원리는 Focusing parameter에서 r은 일반적으로 0 ~ 5 사이의 값이다.
즉, 잘못 분류된 examples의 중요도를 상대적으로 높이는 역할을 한다. 감마 (gamma)값이 커질수록 가중치 규제가 강하게 들어간다.
ALPHA = 0.8
GAMMA = 2
def FocalLoss(targets, inputs, alpha=ALPHA, gamma=GAMMA):
inputs = K.flatten(inputs)
targets = K.flatten(targets)
BCE = K.binary_crossentropy(targets, inputs)
BCE_EXP = K.exp(-BCE)
focal_loss = K.mean(alpha * K.pow((1-BCE_EXP), gamma) * BCE)
return focal_loss
Dice coefficient Loss
Dice loss는 이진 분류 혹은 Segmentation에 자주 사용되는 손실 함수로 정답과 예측값의 유사도를 비교하는 함수이다. 라벨링된 영역과 예측한 영역이 정확히 같다면 1, 아니면 0을 반환한다.
# Keras
def DiceLoss(targets, inputs, smooth=1e-6):
# flatten label and prediction tensors
inputs = K.flatten(inputs)
targets = K.flatten(targets)
intersection = K.sum(K.dot(targets, inputs))
dice = (2*intersection + smooth) / (K.sum(targets) + K.sum(inputs) + smooth)
return 1 - dice
BCE-Dice Loss
BCE-Dice loss는 Binary cross entropy 함수와 dice coefficient 함수를 섞어서 사용하는 손실 함수이다. 일반적으로 이미지 분할에 손실 함수로 많이 사용된다. 두 가지의 손실 함수를 가지고 손실을 평가해 더욱 효율적인 학습을 진행할 수 있다는 장점이 있다.
def DiceBCELoss(targets, inputs, smooth=1e-6):
# flatten label and prediction tensors
inputs = K.flatten(inputs)
targets = K.flatten(targets)
BCE = binary_crossentropy(targets, inputs)
intersection = K.sum(K.dot(targets, inputs))
dice_loss = 1 - (2*intersection + smooth) / (K.sum(targets) + K.sum(inputs) + smooth)
Dice_BCE = BCE + dice_loss
return Dice_BCE
Jaccard / Intersection over Union (IoU) Loss
IoU 손실 함수는 정답과 예측의 겹치는 부분의 비율을 측정하는 방법으로 객체 검출 및 이미지 분할에 자주 사용된다. 수식은 아래와 같다.
아래 그림은 위의 수식을 도식화한 것으로 관심 영역이 어디인지 정확하게 알 수 있다.
다음 그림은 IoU 점수별 정답과 예측치의 겹치는 정도를 보여준다.
def IoULoss(targets, inputs, smooth=1e-6):
# flatten label and prediction tensors
inputs = K.flatten(inputs)
targets = K.flatten(targets)
intersection = K.sum(K.dot(targets, inputs))
total = K.sum(targets) + K.sum(inputs)
union = total - intersection
IoU = (intersection + smooth) / (union + smooth)
return 1 - IoU
Tversky Loss
Tversky loss는 ‘Tversky loss function for image segmentation using 3D fully convolutional deep networks’에 소개된 손실 함수이다. 이 손실 함수는 의료 영상 분할의 클래스 불균형을 해소하기 위해 제안되었다. 아래 코드는 논문에서 제안한 하이퍼파라미터를 그대로 사용하였다.
Tversky 손실 함수는 alpha와 beta를 통해 가중치 규제가 이루어지며 false positive와 false negative에 더욱 큰 가중치를 준다. Alpha 와 beta 값은 본인이 구현하고자 하는 바에 따라 적절하게 최적의 값을 찾아야 한다.
ALPHA = 0.5
BETA = 0.5
def TverskyLoss(targets, inputs, alpha=ALPHA, beta=BETA, smooth=1e-6):
# flatten label and prediction tensors
inputs = K.flatten(inputs)
targets = K.flatten(targets)
# True Positives, False Positives & False Negatives
TP = K.sum((inputs * targets))
FP = K.sum(((1-targets) * inputs))
FN = K.sum((targets * (1-inputs)))
Tversky = (TP + smooth) / (TP + alpha*FP + beta*FN + smooth)
return 1 - Tversky
Focal Tversky Loss
Focal Tversky 손실 함수는 위에서 설명한 focal 손실 함수와 Tversky 손실 함수를 결합한 손실 함수로 Tversky 손실 함수에 Focal 손실 함수의 gamma값 조정 능력을 추가한 형태이다.
ALPHA = 0.5
BETA = 0.5
GAMMA = 1
def FocalTverskyLoss(targets, inputs, alpha=ALPHA, beta=BETA, gamma=GAMMA, smooth=1e-6):
# flatten label and prediction tensors
inputs = K.flatten(inputs)
targets = K.flatten(targets)
# True Positives, False Positives & False Negatives
TP = K.sum((inputs * targets))
FP = K.sum(((1-targets) * inputs))
FN = K.sum((targets * (1-inputs)))
Tversky = (TP + smooth) / (TP + alpha*FP + beta*FN + smooth)
FocalTversky = K.pow((1 - Tversky), gamma)
return FocalTversky
Combo Loss
Combo 손실 함수는 ‘Combo loss: Handling input and outputimbalance in multi-organ segmentation’에 소개된 손실 함수이다. Combo 손실 함수는 Dice 함수와 변형된 Cross-Entropy 함수가 결합된 형태이다. Combo 손실 함수는 앞선 Tversky 손실 함수와 유사하게 false positive와 false negative에 더욱 규제를 가하는 조건을 가지고 있다.
ALPHA = 0.5 # < 0.5 penalises FP more, > 0.5 penalises FN more
CE_RATIO = 0.5 # weighted contribution of modified CE loss compared to Dice loss
def Combo_loss(targets, inputs):
targets = K.flatten(targets)
inputs = K.flatten(inputs)
intersection = K.sum(targets * inputs)
dice = (2. * intersection + smooth) / (K.sum(targets) + K.sum(inputs) + smooth)
inputs = K.clip(inputs, e, 1.0 - e)
out = - (ALPHA * ((targets * K.log(inputs)) + ((1 - ALPHA) * (1.0 - targets) * K.log(1.0 - inputs))))
weighted_ce = K.mean(out, axis=-1)
combo = (CE_RATIO * weighted_ce) - ((1 - CE_RATIO) * dice)
return combo
'Visual Intelligence > Image Segmentation' 카테고리의 다른 글
[Image Segmentation] 주목 메커니즘 (Attention Module) (0) | 2022.12.15 |
---|---|
[Image Segmentation] U-Net (2) (0) | 2022.12.15 |
[Image Segmentation] U-Net (1) (0) | 2022.12.15 |
이미지 분할 (Image Segmentation) (0) | 2022.12.15 |