The project is many thousands of lines long. It fails inside the TensorFlow model.fit call. The only thing different from other versions that work is the loss function, which I share below. The relevant class is BoundaryWithCategoricalDiceLoss, which is called with boundary_loss_type="GRAD". When I use the loss with boundary_loss_type="MSE", everything works fine. This class is a subclass of CategoricalDiceLoss, which is itself a subclass of keras.losses.Loss.
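For context, this is roughly how the loss is wired into training; build_model and train_ds below are placeholders standing in for the real project code, not actual names from it:

loss = BoundaryWithCategoricalDiceLoss(
    boundary_loss_type="GRAD",  # "MSE" works, "GRAD" fails in model.fit
    thickness=2,
    alpha=0.5,
    from_logits=True,
)
model = build_model()  # hypothetical segmentation model with a matching output shape
model.compile(optimizer="adam", loss=loss)
model.fit(train_ds, epochs=1)  # <-- the failing call when boundary_loss_type == "GRAD"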
from typing import Dict, Iterable, Any, Optional, Union, Literal
import tensorflow as tf
import keras
from src.utils.arrays.spatial import get_segmentation_boundary
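# get_segmentation_boundary comes from the project and is not shown in the
# question. For a self-contained repro, a rough stand-in is sketched below
# (my assumption, NOT the project's implementation); comment out the import
# above if you use it:
#
# def get_segmentation_boundary(mask, thickness=2):
#     # Dilate and erode a binary [B, H, W, 1] mask with max-pooling; their
#     # difference is a boundary band roughly `thickness` pixels wide.
#     k = 2 * thickness + 1
#     dilated = tf.nn.max_pool2d(mask, ksize=k, strides=1, padding="SAME")
#     eroded = 1.0 - tf.nn.max_pool2d(1.0 - mask, ksize=k, strides=1, padding="SAME")
#     return dilated - eroded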
class CategoricalDiceLoss(keras.losses.Loss):
    def __init__(self, include_background: bool = False, from_logits: bool = True,
                 class_weights: Union[Literal['uniform', 'linear', 'square'], Dict[Any, float]] = 'uniform',
                 name: str = 'categorical_dice', **kwargs):
        super().__init__(name=name, **kwargs)
        self._include_background = include_background
        self._from_logits = from_logits
        if isinstance(class_weights, dict):
            self._class_weights = list(class_weights.values())
        else:
            self._class_weights = class_weights
    def call(self, y_true, y_pred):
        # If the input image was concatenated onto the model output, strip it
        # before computing the loss.
        if y_pred.shape[-1] > y_true.shape[-1]:
            y_pred = y_pred[..., :-1]
        if self._from_logits:
            y_pred = tf.nn.softmax(y_pred, axis=-1)
        num_classes = tf.shape(y_true)[-1]
        batch_size = tf.shape(y_true)[0]
        # Flatten the spatial dimensions: [batch, pixels, classes].
        y_true = tf.reshape(y_true, [batch_size, -1, num_classes])
        y_pred = tf.reshape(y_pred, [batch_size, -1, num_classes])
        match self._class_weights:
            case "uniform":
                class_weights = tf.ones([num_classes], tf.float32)
                if not self._include_background:
                    class_weights = tf.concat([[0.0], class_weights[1:]], axis=0)
                class_weights = class_weights / tf.reduce_sum(class_weights)
            case "linear":
                raise NotImplementedError("'linear' class weights are not implemented yet")
                # class_weights = tf.reduce_sum(y_true, axis=1)
                # if not self._include_background:
                #     class_weights = tf.concat([tf.zeros((batch_size, 1)), class_weights[:, 1:]], axis=-1)
                # class_weights = class_weights / tf.reduce_sum(class_weights)
            case "square":
                raise NotImplementedError("'square' class weights are not implemented yet")
            case _:
                class_weights = tf.constant(self._class_weights, tf.float32)
                class_weights = class_weights / tf.reduce_sum(class_weights)
        # Per-class Dice over the flattened spatial axis.
        products = tf.reduce_sum(y_true * y_pred, axis=1)
        sums = tf.reduce_sum(y_true, axis=1) + tf.reduce_sum(y_pred, axis=1)
        dices = (2.0 * products) / (sums + 1e-8)
        # The (tf.reduce_sum(y_true, axis=1) > 0) term masks out classes that
        # are absent from the ground truth.
        weighted_loss = tf.reduce_sum((1 - dices) *
                                      tf.cast(tf.reduce_sum(y_true, axis=1) > 0, tf.float32) *
                                      class_weights, axis=-1)
        return weighted_loss
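# A quick eager-mode sanity check of the Dice term alone (shapes are
# illustrative, not from the real project):
#   y_true = tf.one_hot(tf.random.uniform([2, 64, 64], maxval=3, dtype=tf.int32), depth=3)
#   y_pred = tf.random.normal([2, 64, 64, 3])  # logits
#   print(CategoricalDiceLoss()(y_true, y_pred))  # scalar loss, runs fine eagerly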
class BoundaryWithCategoricalDiceLoss(CategoricalDiceLoss):
    # CR David: support for class_weights is not implemented yet.
    def __init__(self, thickness: Optional[int] = 2, alpha: float = 0.5,
                 boundary_loss_type: Literal["CE", "MSE", "GRAD"] = "CE",
                 boundary_channels: Optional[Iterable[int]] = None, **kwargs):
        super().__init__(**kwargs)
        self.thickness = thickness
        self.alpha = alpha
        self.boundary_channels = None if boundary_channels is None else list(boundary_channels)
        self.boundary_loss_type = boundary_loss_type
        if self.boundary_loss_type == "GRAD":
            self.sobel_x = tf.constant([[-1, 0, 1],
                                        [-2, 0, 2],
                                        [-1, 0, 1]], dtype=tf.float32)
            self.sobel_y = tf.constant([[-1, -2, -1],
                                        [ 0,  0,  0],
                                        [ 1,  2,  1]], dtype=tf.float32)
            # Reshape to conv2d kernel layout: [height, width, in_channels, out_channels].
            self.sobel_x = tf.reshape(self.sobel_x, [3, 3, 1, 1])
            self.sobel_y = tf.reshape(self.sobel_y, [3, 3, 1, 1])
    def _get_boundary(self, y_true):
        # Collapse the selected foreground channels into one binary map, then
        # extract a boundary band of the configured thickness around it.
        if self.boundary_channels is None:
            y_true_all_labels = tf.cast(
                tf.expand_dims(tf.reduce_any(y_true[..., 1:] > 0.0, axis=-1), axis=-1), tf.float32)
        else:
            y_true_all_labels = tf.cast(
                tf.expand_dims(tf.reduce_any(tf.gather(y_true, self.boundary_channels, axis=-1) > 0.0,
                                             axis=-1), axis=-1), tf.float32)
        return get_segmentation_boundary(y_true_all_labels, thickness=self.thickness)

    def _calc_gradient(self, inputs):
        # Sobel gradient magnitude of a single-channel NHWC tensor,
        # min-max normalized per image.
        dx = tf.nn.conv2d(inputs, self.sobel_x, strides=[1, 1, 1, 1], padding='SAME')
        dy = tf.nn.conv2d(inputs, self.sobel_y, strides=[1, 1, 1, 1], padding='SAME')
        grad = tf.sqrt(tf.square(dx) + tf.square(dy))
        grad_min = tf.reduce_min(grad, axis=[1, 2], keepdims=True)
        grad_max = tf.reduce_max(grad, axis=[1, 2], keepdims=True)
        grad = (grad - grad_min + 1e-8) / (grad_max - grad_min + 1e-8)
        return grad
    def _CE_loss(self, y_true, y_pred, roi_mask):
        # Per-pixel cross-entropy, averaged over the boundary band only.
        loss = tf.keras.losses.CategoricalCrossentropy(reduction="none")(y_true, y_pred)
        masked_loss = tf.reduce_mean(loss * roi_mask, axis=[1, 2])
        return masked_loss

    def _GRAD_loss(self, x, y_pred, roi_mask):
        # Charbonnier-style distance between the Sobel gradients of the input
        # image and of the predicted map, averaged over the boundary band.
        x_grad = self._calc_gradient(x)
        y_pred_grad = self._calc_gradient(y_pred)
        loss = tf.squeeze(tf.sqrt(tf.square(x_grad - y_pred_grad) + 1e-8), axis=-1)
        masked_loss = tf.reduce_mean(loss * roi_mask, axis=[1, 2])
        return masked_loss

    def _MSE_loss(self, y_true, y_pred, roi_mask):
        # Despite the name, this is a smooth L1 (Charbonnier) distance summed
        # over classes, averaged over the boundary band.
        loss = tf.reduce_sum(tf.sqrt(tf.square(y_true - y_pred) + 1e-8), axis=-1)
        masked_loss = tf.reduce_mean(loss * roi_mask, axis=[1, 2])
        return masked_loss
    def call(self, y_true, y_pred):
        # If the input image was concatenated onto the model output, strip it
        # before computing the loss. For "GRAD", keep the input image (x) and
        # a summed prediction map for the gradient comparison.
        if y_pred.shape[-1] > y_true.shape[-1]:
            if self.boundary_loss_type == "GRAD":
                x = y_pred[..., -1:]
                if self.boundary_channels is not None:
                    y_pred_for_grad = tf.reduce_sum(tf.gather(y_pred, self.boundary_channels, axis=-1),
                                                    axis=-1, keepdims=True)
                else:
                    y_pred_for_grad = tf.reduce_sum(y_pred[..., 1:-1], axis=-1, keepdims=True)
            y_pred = y_pred[..., :-1]
        dice_loss = super().call(y_true, y_pred)
        if self._from_logits:
            y_pred = tf.nn.softmax(y_pred, axis=-1)
        # y_true sometimes arrives with rank 5, i.e. shape [batch_size, 1, 256, 256, 2].
        # For compatibility with the boundary function, squeeze out that extra axis:
        if len(y_true.shape) == 5:
            y_true = tf.squeeze(y_true, axis=1)
        if len(y_pred.shape) == 5:
            y_pred = tf.squeeze(y_pred, axis=1)
        boundary_mask = tf.squeeze(self._get_boundary(y_true), axis=-1)
        match self.boundary_loss_type:
            case "CE":
                boundary_loss = self._CE_loss(y_true, y_pred, boundary_mask)
            case "MSE":
                boundary_loss = self._MSE_loss(y_true, y_pred, boundary_mask)
            case "GRAD":
                boundary_loss = self._GRAD_loss(x, y_pred_for_grad, boundary_mask)
        # print(f"Boundary loss: {boundary_loss}")
        losses_vector = self.alpha * dice_loss + (1 - self.alpha) * boundary_loss
        return losses_vector
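To narrow things down, this is a minimal way I can exercise the loss outside model.fit, under tf.function to mimic the graph mode that model.fit uses. Shapes are illustrative, not from the real project, and it assumes get_segmentation_boundary is importable (or the stand-in sketched after the imports is used):

loss_fn = BoundaryWithCategoricalDiceLoss(boundary_loss_type="GRAD")

@tf.function
def run_loss(y_true, y_pred):
    return loss_fn(y_true, y_pred)

y_true = tf.one_hot(tf.random.uniform([2, 64, 64], maxval=3, dtype=tf.int32), depth=3)
# The extra channel on y_pred stands in for the input image appended to the model output.
y_pred = tf.random.normal([2, 64, 64, 4])
print(run_loss(y_true, y_pred))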