The project is many thousands of lines long, and it fails during the TensorFlow model.fit call. The only difference from other versions that work is the loss function, which I share below. The relevant class is BoundaryWithCategoricalDiceLoss, which is called with boundary_loss_type="GRAD". When I use the loss with boundary_loss_type="MSE", everything works fine. The class is a subclass of CategoricalDiceLoss, which in turn subclasses keras.losses.Loss.
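For context, the loss reaches training roughly like this (the model, dataset, and optimizer here are hypothetical placeholders; the real wiring is spread across the project):

model.compile(
    optimizer="adam",
    loss=BoundaryWithCategoricalDiceLoss(boundary_loss_type="GRAD"),
)
model.fit(train_dataset, epochs=10)  # fails here with "GRAD"; runs fine with "MSE"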
from typing import Dict, Iterable, Any, Optional, Union, Literal
import tensorflow as tf
import keras
from src.utils.arrays.spatial import get_segmentation_boundary
class CategoricalDiceLoss(keras.losses.Loss):
    def __init__(self, include_background: bool = False, from_logits: bool = True,
                 class_weights: Union[Literal['uniform', 'linear', 'square'], Dict[Any, float]] = 'uniform',
                 name: str = 'categorical_dice', **kwargs):
        super().__init__(name=name, **kwargs)
        self._include_background = include_background
        self._from_logits = from_logits
        if isinstance(class_weights, dict):
            self._class_weights = list(class_weights.values())
        else:
            self._class_weights = class_weights
    def call(self, y_true, y_pred):
        ## If the input is added to the output, we need to remove it before calculating the loss
        if y_pred.shape[-1] > y_true.shape[-1]:
            y_pred = y_pred[..., :-1]
        if self._from_logits:
            y_pred = tf.nn.softmax(y_pred, axis=-1)

        num_classes = tf.shape(y_true)[-1]
        batch_size = tf.shape(y_true)[0]
        y_true = tf.reshape(y_true, [batch_size, -1, num_classes])
        y_pred = tf.reshape(y_pred, [batch_size, -1, num_classes])

        match self._class_weights:
            case "uniform":
                class_weights = tf.tile(tf.constant([1.0], tf.float32), tf.reshape(num_classes, [-1]))
                if not self._include_background:
                    class_weights = tf.concat([[0.0], class_weights[1:]], axis=0)
                class_weights = class_weights / tf.reduce_sum(class_weights)
            case "linear":
                raise Exception("Not implemented yet")
                # class_weights = tf.reduce_sum(y_true, axis=1)
                # if not self._include_background:
                #     class_weights = tf.concat([tf.zeros((batch_size, 1)), class_weights[:, 1:]], axis=-1)
                # class_weights = class_weights / tf.reduce_sum(class_weights)
            case "square":
                raise Exception("Not implemented yet")
            case _:
                class_weights = self._class_weights
                class_weights = class_weights / tf.reduce_sum(class_weights)

        products = tf.reduce_sum(y_true * y_pred, axis=1)
        sums = tf.reduce_sum(y_true, axis=1) + tf.reduce_sum(y_pred, axis=1)
        dices = (2.0 * products) / (sums + 1e-8)
        ## The term (tf.reduce_sum(y_true, axis=1) > 0) ignores classes that are not present in the ground truth
        weighted_loss = tf.reduce_sum((1 - dices)
                                      * tf.cast(tf.reduce_sum(y_true, axis=1) > 0, tf.float32)
                                      * class_weights, axis=-1)
        return weighted_loss
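For what it's worth, the base dice loss can be exercised on its own with toy tensors, independent of model.fit (the shapes below are made up for illustration):

y_true = tf.one_hot(tf.random.uniform([2, 8, 8], 0, 2, dtype=tf.int32), depth=2)  # one-hot, 2 classes
logits = tf.random.normal([2, 8, 8, 2])
print(CategoricalDiceLoss()(y_true, logits))  # keras reduces the per-sample vector from call() to a scalar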
class BoundaryWithCategoricalDiceLoss(CategoricalDiceLoss):
    ## CR David - Support for class_weights not implemented yet
    def __init__(self, thickness: Optional[int] = 2, alpha: float = 0.5,
                 boundary_loss_type: Literal["CE", "MSE", "GRAD"] = "CE",
                 boundary_channels: Optional[Iterable[int]] = None, **kwargs):
        super().__init__(**kwargs)
        self.thickness = thickness
        self.alpha = alpha
        self.boundary_channels = None if boundary_channels is None else list(boundary_channels)
        self.boundary_loss_type = boundary_loss_type
        if self.boundary_loss_type == "GRAD":
            self.sobel_x = tf.constant([[-1, 0, 1],
                                        [-2, 0, 2],
                                        [-1, 0, 1]], dtype=tf.float32)
            self.sobel_y = tf.constant([[-1, -2, -1],
                                        [0, 0, 0],
                                        [1, 2, 1]], dtype=tf.float32)
            # Reshape the filters to [height, width, in_channels, out_channels] for tf.nn.conv2d
            self.sobel_x = tf.reshape(self.sobel_x, [3, 3, 1, 1])
            self.sobel_y = tf.reshape(self.sobel_y, [3, 3, 1, 1])
    def _get_boundary(self, y_true):
        if self.boundary_channels is None:
            y_true_all_labels = tf.cast(tf.expand_dims(
                tf.reduce_any(y_true[..., 1:] > 0.0, axis=-1), axis=-1), tf.float32)
        else:
            y_true_all_labels = tf.cast(tf.expand_dims(
                tf.reduce_any(tf.gather(y_true, self.boundary_channels, axis=-1) > 0.0, axis=-1),
                axis=-1), tf.float32)
        return get_segmentation_boundary(y_true_all_labels, thickness=self.thickness)
    def _calc_gradient(self, input):
        dx = tf.nn.conv2d(input, self.sobel_x, strides=[1, 1, 1, 1], padding='SAME')
        dy = tf.nn.conv2d(input, self.sobel_y, strides=[1, 1, 1, 1], padding='SAME')
        grad = tf.sqrt(tf.square(dx) + tf.square(dy))
        # Normalize to [0, 1]; note the min is taken per image (axis=[1, 2]) while the max is over the whole batch
        grad = ((grad - tf.reduce_min(grad, axis=[1, 2], keepdims=True)) + 1e-8) / \
               ((tf.reduce_max(grad) - tf.reduce_min(grad, axis=[1, 2], keepdims=True)) + 1e-8)
        return grad
    def _CE_loss(self, y_true, y_pred, roi_mask):
        loss = tf.keras.losses.CategoricalCrossentropy(reduction=None)(y_true, y_pred)
        masked_loss = tf.reduce_mean(loss * roi_mask, axis=[1, 2])
        return masked_loss

    def _GRAD_loss(self, x, y_pred, roi_mask):
        x_grad = self._calc_gradient(x)
        y_pred_grad = self._calc_gradient(y_pred)
        loss = tf.squeeze(tf.sqrt(tf.square(x_grad - y_pred_grad) + 1e-8), axis=-1)
        masked_loss = tf.reduce_mean(loss * roi_mask, axis=[1, 2])
        return masked_loss

    def _MSE_loss(self, y_true, y_pred, roi_mask):
        loss = tf.reduce_sum(tf.sqrt(tf.square(y_true - y_pred) + 1e-8), axis=-1)
        masked_loss = tf.reduce_mean(loss * roi_mask, axis=[1, 2])
        return masked_loss
    def call(self, y_true, y_pred):
        ## If the input is added to the output, we need to remove it before calculating the loss
        if y_pred.shape[-1] > y_true.shape[-1]:
            if self.boundary_loss_type == "GRAD":
                x = y_pred[..., -1:]
                if self.boundary_channels is not None:
                    y_pred_for_grad = tf.reduce_sum(tf.gather(y_pred, self.boundary_channels, axis=-1),
                                                    axis=-1, keepdims=True)
                else:
                    y_pred_for_grad = tf.reduce_sum(y_pred[..., 1:-1], axis=-1, keepdims=True)
            y_pred = y_pred[..., :-1]

        dice_loss = super().call(y_true, y_pred)
        if self._from_logits:
            y_pred = tf.nn.softmax(y_pred, axis=-1)
        # y_true sometimes has rank 5, i.e. shape [batch_size, 1, 256, 256, 2];
        # for compatibility with the boundary function we remove the singleton axis if it exists:
        if len(y_true.shape) == 5:
            y_true = tf.squeeze(y_true, axis=1)
        if len(y_pred.shape) == 5:
            y_pred = tf.squeeze(y_pred, axis=1)
        boundary_mask = tf.squeeze(self._get_boundary(y_true), axis=-1)
        match self.boundary_loss_type:
            case "CE":
                boundary_loss = self._CE_loss(y_true, y_pred, boundary_mask)
            case "MSE":
                boundary_loss = self._MSE_loss(y_true, y_pred, boundary_mask)
            case "GRAD":
                boundary_loss = self._GRAD_loss(x, y_pred_for_grad, boundary_mask)

        # print(f"Boundary loss: {boundary_loss}")
        losses_vector = self.alpha * dice_loss + (1 - self.alpha) * boundary_loss
        return losses_vector
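Similarly, a minimal way to exercise the GRAD path directly, again with hypothetical shapes (two one-hot classes, plus the input image appended to y_pred as an extra last channel, which call() strips off):

y_true = tf.one_hot(tf.random.uniform([4, 256, 256], 0, 2, dtype=tf.int32), depth=2)
y_pred = tf.random.normal([4, 256, 256, 3])  # 2 logit channels + the appended input image
loss_fn = BoundaryWithCategoricalDiceLoss(boundary_loss_type="GRAD")
print(loss_fn(y_true, y_pred))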