Unverified Commit 2eb584a9 authored by Pavel Yakubovskiy, committed by GitHub

Add `beta` to dice; Improve losses docs (#103)

parent 2ee83240
@@ -36,12 +36,44 @@ def jaccard_loss(gt, pr, class_weights=1., smooth=SMOOTH, per_image=True):
def bce_jaccard_loss(gt, pr, bce_weight=1., smooth=SMOOTH, per_image=True):
r"""Sum of binary crossentropy and jaccard losses:
.. math:: L(A, B) = bce_weight * binary_crossentropy(A, B) + jaccard_loss(A, B)
Args:
gt: ground truth 4D keras tensor (B, H, W, C)
pr: prediction 4D keras tensor (B, H, W, C)
bce_weight: weight of the binary crossentropy term in the loss sum
smooth: value to avoid division by zero
per_image: if ``True``, jaccard loss is calculated as mean over images in batch (B),
else over whole batch (only for jaccard loss)
Returns:
loss
"""
bce = K.mean(binary_crossentropy(gt, pr))
loss = bce_weight * bce + jaccard_loss(gt, pr, smooth=smooth, per_image=per_image)
return loss
def cce_jaccard_loss(gt, pr, cce_weight=1., class_weights=1., smooth=SMOOTH, per_image=True):
r"""Sum of categorical crossentropy and jaccard losses:
.. math:: L(A, B) = cce_weight * categorical_crossentropy(A, B) + jaccard_loss(A, B)
Args:
gt: ground truth 4D keras tensor (B, H, W, C)
pr: prediction 4D keras tensor (B, H, W, C)
cce_weight: weight of the categorical crossentropy term in the loss sum
class_weights: 1. or list of class weights for jaccard loss, len(weights) = C
smooth: value to avoid division by zero
per_image: if ``True``, jaccard loss is calculated as mean over images in batch (B),
else over whole batch
Returns:
loss
"""
cce = categorical_crossentropy(gt, pr) * class_weights
cce = K.mean(cce)
return cce_weight * cce + jaccard_loss(gt, pr, smooth=smooth, class_weights=class_weights, per_image=per_image)
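A minimal usage sketch for the combined crossentropy + jaccard losses (not part of the diff): both are plain `(gt, pr)` callables, so they can be passed directly to `model.compile`, and a non-default `bce_weight` only needs a thin wrapper so Keras still receives a two-argument function. The `segmentation_models.losses` import path and the pre-built `model` below are assumptions.

```python
# Sketch only: assumes the losses are importable from segmentation_models.losses
# and that `model` is an already-built Keras segmentation model.
from segmentation_models.losses import bce_jaccard_loss

# Default weighting: loss = 1.0 * binary_crossentropy + jaccard_loss
model.compile(optimizer='adam', loss=bce_jaccard_loss, metrics=['accuracy'])

def weighted_bce_jaccard(gt, pr):
    # Hypothetical wrapper: give the crossentropy term half the weight
    # of the jaccard term.
    return bce_jaccard_loss(gt, pr, bce_weight=0.5)

model.compile(optimizer='adam', loss=weighted_bce_jaccard, metrics=['accuracy'])
```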
@@ -57,7 +89,7 @@ get_custom_objects().update({
# ============================== Dice Losses ================================
def dice_loss(gt, pr, class_weights=1., smooth=SMOOTH, per_image=True):
def dice_loss(gt, pr, class_weights=1., smooth=SMOOTH, per_image=True, beta=1.):
r"""Dice loss function for imbalanced datasets:
.. math:: L(precision, recall) = 1 - (1 + \beta^2) \frac{precision \cdot recall}
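Note on the formula above: the hunk boundary cuts the docstring's math after the numerator. Written out, the dice loss is the complement of the F-beta score of precision and recall; the denominator below is the standard F-beta form, reconstructed rather than copied from the file.

```latex
% Dice loss as 1 minus the F-beta score of precision P and recall R.
% beta = 1 recovers the plain Dice (F1) loss; beta > 1 weights recall higher.
L(P, R) = 1 - (1 + \beta^2)\,\frac{P \cdot R}{\beta^2 \cdot P + R}
```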
@@ -70,24 +102,59 @@ def dice_loss(gt, pr, class_weights=1., smooth=SMOOTH, per_image=True):
smooth: value to avoid division by zero
per_image: if ``True``, metric is calculated as mean over images in batch (B),
else over whole batch
beta: coefficient for precision/recall balance (beta > 1 weights recall higher)
Returns:
Dice loss in range [0, 1]
"""
return 1 - f_score(gt, pr, class_weights=class_weights, smooth=smooth, per_image=per_image, beta=1.)
return 1 - f_score(gt, pr, class_weights=class_weights, smooth=smooth, per_image=per_image, beta=beta)
def bce_dice_loss(gt, pr, bce_weight=1., smooth=SMOOTH, per_image=True):
def bce_dice_loss(gt, pr, bce_weight=1., smooth=SMOOTH, per_image=True, beta=1.):
r"""Sum of binary crossentropy and dice losses:
.. math:: L(A, B) = bce_weight * binary_crossentropy(A, B) + dice_loss(A, B)
Args:
gt: ground truth 4D keras tensor (B, H, W, C)
pr: prediction 4D keras tensor (B, H, W, C)
bce_weight: weight of the binary crossentropy term in the loss sum
smooth: value to avoid division by zero
per_image: if ``True``, dice loss is calculated as mean over images in batch (B),
else over whole batch
beta: coefficient for precision/recall balance (beta > 1 weights recall higher)
Returns:
loss
"""
bce = K.mean(binary_crossentropy(gt, pr))
loss = bce_weight * bce + dice_loss(gt, pr, smooth=smooth, per_image=per_image)
loss = bce_weight * bce + dice_loss(gt, pr, smooth=smooth, per_image=per_image, beta=beta)
return loss
def cce_dice_loss(gt, pr, cce_weight=1., class_weights=1., smooth=SMOOTH, per_image=True):
def cce_dice_loss(gt, pr, cce_weight=1., class_weights=1., smooth=SMOOTH, per_image=True, beta=1.):
r"""Sum of categorical crossentropy and dice losses:
.. math:: L(A, B) = cce_weight * categorical_crossentropy(A, B) + dice_loss(A, B)
Args:
gt: ground truth 4D keras tensor (B, H, W, C)
pr: prediction 4D keras tensor (B, H, W, C)
cce_weight: weight of the categorical crossentropy term in the loss sum
class_weights: 1. or list of class weights for dice loss, len(weights) = C
smooth: value to avoid division by zero
per_image: if ``True``, dice loss is calculated as mean over images in batch (B),
else over whole batch
beta: coefficient for precision/recall balance (beta > 1 weights recall higher)
Returns:
loss
"""
cce = categorical_crossentropy(gt, pr) * class_weights
cce = K.mean(cce)
return cce_weight * cce + dice_loss(gt, pr, smooth=smooth, class_weights=class_weights, per_image=per_image)
return cce_weight * cce + dice_loss(gt, pr, smooth=smooth, class_weights=class_weights, per_image=per_image, beta=beta)
# Update custom objects
......
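A short sketch of what the new `beta` argument enables (not part of the diff): beta > 1 shifts the F-beta score towards recall, so missed foreground pixels are penalised more than false positives, while the `beta=1.` default keeps the previous Dice behaviour. The import path and the `model` object are assumptions, and `beta=2.` is only illustrative.

```python
# Sketch only: assumes bce_dice_loss is importable from segmentation_models.losses
# and `model` is an already-built Keras segmentation model.
from segmentation_models.losses import bce_dice_loss

def recall_weighted_bce_dice(gt, pr):
    # beta > 1 weights recall above precision in the dice term,
    # i.e. missed foreground pixels cost more than false alarms.
    return bce_dice_loss(gt, pr, beta=2.)

model.compile(optimizer='adam', loss=recall_weighted_bce_dice)

# beta defaults to 1., so existing calls like dice_loss(gt, pr) keep their
# previous behaviour (plain Dice / F1 loss).
```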