And also a log would be interesting to look at 🙂
This is what I get when running on ClearML. Notice the NaN values in the loss:
Epoch 1/150
WARNING: All log messages before absl::InitializeLog() is called are written to STDERR
I0000 00:00:1739804333.538008 890492 service.cc:145] XLA service 0x7f19b80029d0 initialized for platform CUDA (this does not guarantee that XLA will be used). Devices:
I0000 00:00:1739804333.538068 890492 service.cc:153] StreamExecutor device (0): NVIDIA GeForce RTX 2080 Ti, Compute Capability 7.5
2025-02-17 14:58:54.021777: I tensorflow/compiler/mlir/tensorflow/utils/dump_mlir_util.cc:268] disabling MLIR crash reproducer, set env var MLIR_CRASH_REPRODUCER_DIRECTORY to enable.
2025-02-17 14:58:54.823385: I external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:465] Loaded cuDNN version 8906
I0000 00:00:1739804355.312793 890492 device_compiler.h:188] Compiled cluster using XLA! This line is logged at most once for the lifetime of the process.
ClearML Monitor: Could not detect iteration reporting, falling back to iterations as seconds-from-start
2025-02-17 15:02:43,896 - clearml - INFO - NaN value encountered. Reporting it as '0.0'. Use clearml.Logger.set_reporting_nan_value to assign another value
Saved artifact at '/tmp/tmpmfu4qumg/model'. The following endpoints are available:
- Endpoint 'serve'
args_0 (POSITIONAL_ONLY): TensorSpec(shape=(None, 1, 256, 256, 1), dtype=tf.float32, name='keras_tensor_51')
Output Type:
TensorSpec(shape=(None, 256, 256, 3), dtype=tf.float32, name=None)
Captures:
139737631767920: TensorSpec(shape=(1, 1, 1, 1, 1), dtype=tf.float32, name=None)
139737631767040: TensorSpec(shape=(1, 1, 1, 1, 1), dtype=tf.float32, name=None)
139737631768272: TensorSpec(shape=(1, 1, 256, 256, 1), dtype=tf.float32, name=None)
139753012177808: TensorSpec(shape=(), dtype=tf.resource, name=None)
139753012178864: TensorSpec(shape=(), dtype=tf.resource, name=None)
139753012176224: TensorSpec(shape=(), dtype=tf.resource, name=None)
139753012180624: TensorSpec(shape=(), dtype=tf.resource, name=None)
139753012178512: TensorSpec(shape=(), dtype=tf.resource, name=None)
139753012179744: TensorSpec(shape=(), dtype=tf.resource, name=None)
139753054298080: TensorSpec(shape=(), dtype=tf.resource, name=None)
139753012174288: TensorSpec(shape=(), dtype=tf.resource, name=None)
139753012176400: TensorSpec(shape=(), dtype=tf.resource, name=None)
139753012181328: TensorSpec(shape=(), dtype=tf.resource, name=None)
139753012180976: TensorSpec(shape=(), dtype=tf.resource, name=None)
139753012180800: TensorSpec(shape=(), dtype=tf.resource, name=None)
139753012182208: TensorSpec(shape=(), dtype=tf.resource, name=None)
139753012182560: TensorSpec(shape=(), dtype=tf.resource, name=None)
139753012181680: TensorSpec(shape=(), dtype=tf.resource, name=None)
139753012181504: TensorSpec(shape=(), dtype=tf.resource, name=None)
139737633002192: TensorSpec(shape=(), dtype=tf.resource, name=None)
139737633002368: TensorSpec(shape=(), dtype=tf.resource, name=None)
139737633003424: TensorSpec(shape=(), dtype=tf.resource, name=None)
139737633003600: TensorSpec(shape=(), dtype=tf.resource, name=None)
139737633004832: TensorSpec(shape=(), dtype=tf.resource, name=None)
139737633005184: TensorSpec(shape=(), dtype=tf.resource, name=None)
139737633003952: TensorSpec(shape=(), dtype=tf.resource, name=None)
139737633003776: TensorSpec(shape=(), dtype=tf.resource, name=None)
139737633006064: TensorSpec(shape=(), dtype=tf.resource, name=None)
139737633006240: TensorSpec(shape=(), dtype=tf.resource, name=None)
139737633006768: TensorSpec(shape=(), dtype=tf.resource, name=None)
139737633006592: TensorSpec(shape=(), dtype=tf.resource, name=None)
139737633008000: TensorSpec(shape=(), dtype=tf.resource, name=None)
139737633008352: TensorSpec(shape=(), dtype=tf.resource, name=None)
139737633007472: TensorSpec(shape=(), dtype=tf.resource, name=None)
139737633007296: TensorSpec(shape=(), dtype=tf.resource, name=None)
139737633009056: TensorSpec(shape=(), dtype=tf.resource, name=None)
139737633009232: TensorSpec(shape=(), dtype=tf.resource, name=None)
139737633009760: TensorSpec(shape=(), dtype=tf.resource, name=None)
139737633009584: TensorSpec(shape=(), dtype=tf.resource, name=None)
139737633009408: TensorSpec(shape=(), dtype=tf.resource, name=None)
139737633011168: TensorSpec(shape=(), dtype=tf.resource, name=None)
139737633008704: TensorSpec(shape=(), dtype=tf.resource, name=None)
139737633010464: TensorSpec(shape=(), dtype=tf.resource, name=None)
139737633012400: TensorSpec(shape=(), dtype=tf.resource, name=None)
139737633012576: TensorSpec(shape=(), dtype=tf.resource, name=None)
139737633012928: TensorSpec(shape=(), dtype=tf.resource, name=None)
139737633013104: TensorSpec(shape=(), dtype=tf.resource, name=None)
139737631753312: TensorSpec(shape=(), dtype=tf.resource, name=None)
139737631753664: TensorSpec(shape=(), dtype=tf.resource, name=None)
139737631752432: TensorSpec(shape=(), dtype=tf.resource, name=None)
139737631752256: TensorSpec(shape=(), dtype=tf.resource, name=None)
139737631754720: TensorSpec(shape=(), dtype=tf.resource, name=None)
139737631754896: TensorSpec(shape=(), dtype=tf.resource, name=None)
139737631755424: TensorSpec(shape=(), dtype=tf.resource, name=None)
139737631755248: TensorSpec(shape=(), dtype=tf.resource, name=None)
139737631755072: TensorSpec(shape=(), dtype=tf.resource, name=None)
139737631756832: TensorSpec(shape=(), dtype=tf.resource, name=None)
139737631754544: TensorSpec(shape=(), dtype=tf.resource, name=None)
139737631756128: TensorSpec(shape=(), dtype=tf.resource, name=None)
139737631760704: TensorSpec(shape=(), dtype=tf.resource, name=None)
139737631760880: TensorSpec(shape=(), dtype=tf.resource, name=None)
139737631761408: TensorSpec(shape=(), dtype=tf.resource, name=None)
139737631761232: TensorSpec(shape=(), dtype=tf.resource, name=None)
139737631761056: TensorSpec(shape=(), dtype=tf.resource, name=None)
139737631762816: TensorSpec(shape=(), dtype=tf.resource, name=None)
139737631760528: TensorSpec(shape=(), dtype=tf.resource, name=None)
139737631762112: TensorSpec(shape=(), dtype=tf.resource, name=None)
139737631763872: TensorSpec(shape=(), dtype=tf.resource, name=None)
139737631764048: TensorSpec(shape=(), dtype=tf.resource, name=None)
139737631764576: TensorSpec(shape=(), dtype=tf.resource, name=None)
139737631764400: TensorSpec(shape=(), dtype=tf.resource, name=None)
139737631764224: TensorSpec(shape=(), dtype=tf.resource, name=None)
139737631765984: TensorSpec(shape=(), dtype=tf.resource, name=None)
139737631763696: TensorSpec(shape=(), dtype=tf.resource, name=None)
139737631765280: TensorSpec(shape=(), dtype=tf.resource, name=None)
139737631766688: TensorSpec(shape=(), dtype=tf.resource, name=None)
139737631766864: TensorSpec(shape=(), dtype=tf.resource, name=None)
2025-02-17 15:02:46,852 - clearml.Task - INFO - Completed model upload to None
1939/1939 - 242s - 125ms/step - dice_Bg: nan - dice_Lumen: nan - dice_all: nan - loss: nan - val_dice_Bg: nan - val_dice_Lumen: nan - val_dice_all: nan - val_loss: nan
Epoch 2/150
The code that generates this is the fit method in TF:
model.fit(train_dataset, validation_data=val_dataset, epochs=cfg.fit.epochs, callbacks=callbacks, verbose=2)
ClearML is activated in the usual way:
task = Task.init(project_name=project_name, task_name=name, output_uri=True, auto_connect_frameworks={'tensorflow': False}, **kwargs)
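Putting the two together, the entry point looks roughly like this (a sketch; build_model, cfg, the datasets, and callbacks stand in for our project code):

from clearml import Task

task = Task.init(project_name=project_name, task_name=name, output_uri=True,
                 auto_connect_frameworks={'tensorflow': False})

model = build_model(cfg)  # placeholder for the code that builds and compiles the Keras model
model.fit(train_dataset,
          validation_data=val_dataset,
          epochs=cfg.fit.epochs,
          callbacks=callbacks,
          verbose=2)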
This is what I get when running the exact same training session without ClearML:
Epoch 1/150
WARNING: All log messages before absl::InitializeLog() is called are written to STDERR
I0000 00:00:1739806371.262488 897794 service.cc:145] XLA service 0x7fc058066d20 initialized for platform CUDA (this does not guarantee that XLA will be used). Devices:
I0000 00:00:1739806371.262578 897794 service.cc:153] StreamExecutor device (0): NVIDIA GeForce RTX 2080 Ti, Compute Capability 7.5
2025-02-17 15:32:51.772357: I tensorflow/compiler/mlir/tensorflow/utils/dump_mlir_util.cc:268] disabling MLIR crash reproducer, set env var MLIR_CRASH_REPRODUCER_DIRECTORY to enable.
2025-02-17 15:32:52.532296: I external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:465] Loaded cuDNN version 8906
I0000 00:00:1739806392.856744 897794 device_compiler.h:188] Compiled cluster using XLA! This line is logged at most once for the lifetime of the process.
1939/1939 - 233s - 120ms/step - dice_Bg: 0.9607 - dice_Lumen: 0.8216 - dice_all: 0.8216 - loss: 0.1238 - val_dice_Bg: 0.9862 - val_dice_Lumen: 0.8429 - val_dice_all: 0.8429 - val_loss: 0.0908
Epoch 2/150
1939/1939 - 203s - 105ms/step - dice_Bg: 0.9752 - dice_Lumen: 0.8748 - dice_all: 0.8748 - loss: 0.0816 - val_dice_Bg: 0.9889 - val_dice_Lumen: 0.8836 - val_dice_all: 0.8836 - val_loss: 0.0679
Hi DizzyButterfly4, can you please put the logs in a more readable format, like a code block? That is much nicer for logs.
Also, please add a code snippet that reproduces this.
Hi DizzyButterfly4, can you provide a snippet that reproduces this behaviour?
The only difference between the two runs is that in one run project_name is an empty string (in which case all is OK), and in the other run project_name has a value.
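Concretely, the two invocations differ only in that argument (a sketch; 'my_project' stands in for the real project name):

# All metrics come out fine with an empty project name
task = Task.init(project_name='', task_name=name, output_uri=True,
                 auto_connect_frameworks={'tensorflow': False})

# Loss and dice metrics come out NaN with a real project name
task = Task.init(project_name='my_project', task_name=name, output_uri=True,
                 auto_connect_frameworks={'tensorflow': False})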
Hi DizzyButterfly4, can you provide a standalone code snippet that reproduces this behaviour?
The project is many thousands of lines long. It fails in the TF model.fit call. The only thing different from other versions that work is the loss function, which I share below. The relevant class is BoundaryWithCategoricalDiceLoss, which is called with boundary_loss_type="GRAD". When I use the loss with boundary_loss_type="MSE", all works fine. This class is a subclass of CategoricalDiceLoss, which is in turn a subclass of keras.losses.Loss.
from typing import Dict, Iterable, Any, Optional, Union, Literal

import tensorflow as tf
import keras

from src.utils.arrays.spatial import get_segmentation_boundary


class CategoricalDiceLoss(keras.losses.Loss):
    def __init__(self, include_background: bool = False, from_logits: bool = True,
                 class_weights: Union[Literal['uniform', 'linear', 'square'], Dict[Any, float]] = 'uniform',
                 name: str = 'categorical_dice', **kwargs):
        super().__init__(name=name, **kwargs)
        self._include_background = include_background
        self._from_logits = from_logits
        if isinstance(class_weights, dict):
            self._class_weights = list(class_weights.values())
        else:
            self._class_weights = class_weights

    def call(self, y_true, y_pred):
        # If the input is added to the output, we need to remove it before calculating the loss
        if y_pred.shape[-1] > y_true.shape[-1]:
            y_pred = y_pred[..., :-1]
        if self._from_logits:
            y_pred = tf.nn.softmax(y_pred, axis=-1)
        num_classes = tf.shape(y_true)[-1]
        batch_size = tf.shape(y_true)[0]
        y_true = tf.reshape(y_true, [batch_size, -1, num_classes])
        y_pred = tf.reshape(y_pred, [batch_size, -1, num_classes])
        match self._class_weights:
            case "uniform":
                class_weights = tf.tile(tf.constant([1.0], tf.float32), tf.reshape(num_classes, [-1]))
                if not self._include_background:
                    class_weights = tf.concat([[0.0], class_weights[1:]], axis=0)
                class_weights = class_weights / tf.reduce_sum(class_weights)
            case "linear":
                raise NotImplementedError("'linear' class weights are not implemented yet")
                # class_weights = tf.reduce_sum(y_true, axis=1)
                # if not self._include_background:
                #     class_weights = tf.concat([tf.zeros((batch_size, 1)), class_weights[:, 1:]], axis=-1)
                # class_weights = class_weights / tf.reduce_sum(class_weights)
            case "square":
                raise NotImplementedError("'square' class weights are not implemented yet")
            case _:
                class_weights = self._class_weights
                class_weights = class_weights / tf.reduce_sum(class_weights)
        products = tf.reduce_sum(y_true * y_pred, axis=1)
        sums = tf.reduce_sum(y_true, axis=1) + tf.reduce_sum(y_pred, axis=1)
        dices = (2.0 * products) / (sums + 1e-8)
        # The term (tf.reduce_sum(y_true, axis=1) > 0) ignores classes that are not present in the ground truth
        weighted_loss = tf.reduce_sum((1 - dices) *
                                      tf.cast((tf.reduce_sum(y_true, axis=1) > 0), tf.float32) *
                                      class_weights, axis=-1)
        return weighted_loss


class BoundaryWithCategoricalDiceLoss(CategoricalDiceLoss):
    # CR David - Support for class_weights not implemented yet
    def __init__(self, thickness: Optional[int] = 2, alpha: float = 0.5,
                 boundary_loss_type: Literal["CE", "MSE", "GRAD"] = "CE",
                 boundary_channels: Optional[Iterable[int]] = None, **kwargs):
        super().__init__(**kwargs)
        self.thickness = thickness
        self.alpha = alpha
        self.boundary_channels = None if boundary_channels is None else list(boundary_channels)
        self.boundary_loss_type = boundary_loss_type
        if self.boundary_loss_type == "GRAD":
            self.sobel_x = tf.constant([[-1, 0, 1],
                                        [-2, 0, 2],
                                        [-1, 0, 1]], dtype=tf.float32)
            self.sobel_y = tf.constant([[-1, -2, -1],
                                        [ 0,  0,  0],
                                        [ 1,  2,  1]], dtype=tf.float32)
            # Reshape filters to the conv2d kernel layout [H, W, in_channels, out_channels]
            self.sobel_x = tf.reshape(self.sobel_x, [3, 3, 1, 1])
            self.sobel_y = tf.reshape(self.sobel_y, [3, 3, 1, 1])

    def _get_boundary(self, y_true):
        if self.boundary_channels is None:
            y_true_all_labels = tf.cast(tf.expand_dims(tf.reduce_any(y_true[..., 1:] > 0.0, axis=-1), axis=-1), tf.float32)
        else:
            y_true_all_labels = tf.cast(tf.expand_dims(tf.reduce_any(tf.gather(y_true, self.boundary_channels, axis=-1) > 0.0, axis=-1), axis=-1), tf.float32)
        return get_segmentation_boundary(y_true_all_labels, thickness=self.thickness)

    def _calc_gradient(self, x):
        # Sobel gradient magnitude, normalized to roughly [0, 1]
        # (note: the min is reduced per sample but the max is reduced globally)
        dx = tf.nn.conv2d(x, self.sobel_x, strides=[1, 1, 1, 1], padding='SAME')
        dy = tf.nn.conv2d(x, self.sobel_y, strides=[1, 1, 1, 1], padding='SAME')
        grad = tf.sqrt(tf.square(dx) + tf.square(dy))
        grad = ((grad - tf.reduce_min(grad, axis=[1, 2], keepdims=True)) + 1e-8) / (
            (tf.reduce_max(grad) - tf.reduce_min(grad, axis=[1, 2], keepdims=True)) + 1e-8)
        return grad

    def _CE_loss(self, y_true, y_pred, roi_mask):
        loss = tf.keras.losses.CategoricalCrossentropy(reduction=None)(y_true, y_pred)
        masked_loss = tf.reduce_mean(loss * roi_mask, axis=[1, 2])
        return masked_loss

    def _GRAD_loss(self, x, y_pred, roi_mask):
        x_grad = self._calc_gradient(x)
        y_pred_grad = self._calc_gradient(y_pred)
        loss = tf.squeeze(tf.sqrt(tf.square(x_grad - y_pred_grad) + 1e-8), axis=-1)
        masked_loss = tf.reduce_mean(loss * roi_mask, axis=[1, 2])
        return masked_loss

    def _MSE_loss(self, y_true, y_pred, roi_mask):
        loss = tf.reduce_sum(tf.sqrt(tf.square(y_true - y_pred) + 1e-8), axis=-1)
        masked_loss = tf.reduce_mean(loss * roi_mask, axis=[1, 2])
        return masked_loss

    def call(self, y_true, y_pred):
        # If the input is added to the output, we need to remove it before calculating the loss
        if y_pred.shape[-1] > y_true.shape[-1]:
            if self.boundary_loss_type == "GRAD":
                x = y_pred[..., -1:]
                if self.boundary_channels is not None:
                    y_pred_for_grad = tf.reduce_sum(tf.gather(y_pred, self.boundary_channels, axis=-1), axis=-1, keepdims=True)
                else:
                    y_pred_for_grad = tf.reduce_sum(y_pred[..., 1:-1], axis=-1, keepdims=True)
            y_pred = y_pred[..., :-1]
        dice_loss = super().call(y_true, y_pred)
        if self._from_logits:
            y_pred = tf.nn.softmax(y_pred, axis=-1)
        # y_true sometimes has rank 5, i.e. shape [batch_size, 1, 256, 256, 2].
        # For compatibility with the boundary function, squeeze out that axis if present:
        if len(y_true.shape) == 5:
            y_true = tf.squeeze(y_true, axis=1)
        if len(y_pred.shape) == 5:
            y_pred = tf.squeeze(y_pred, axis=1)
        boundary_mask = tf.squeeze(self._get_boundary(y_true), axis=-1)
        match self.boundary_loss_type:
            case "CE":
                boundary_loss = self._CE_loss(y_true, y_pred, boundary_mask)
            case "MSE":
                boundary_loss = self._MSE_loss(y_true, y_pred, boundary_mask)
            case "GRAD":
                boundary_loss = self._GRAD_loss(x, y_pred_for_grad, boundary_mask)
        # print(f"Boundary loss: {boundary_loss}")
        losses_vector = self.alpha * dice_loss + (1 - self.alpha) * boundary_loss
        return losses_vector
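For a quick standalone check, something along these lines exercises the GRAD path on random tensors (a sketch, not the real training code; shapes mirror the log above, with a rank-5 y_true and the input image appended to the model output as an extra channel, and it assumes the class above and its get_segmentation_boundary import are available):

import tensorflow as tf

batch, h, w, n_classes = 2, 256, 256, 2
labels = tf.random.uniform([batch, 1, h, w], 0, n_classes, dtype=tf.int32)
y_true = tf.one_hot(labels, n_classes)                    # [2, 1, 256, 256, 2]
y_pred = tf.random.normal([batch, h, w, n_classes + 1])   # [2, 256, 256, 3], input appended as last channel

loss_fn = BoundaryWithCategoricalDiceLoss(boundary_loss_type="GRAD")
# Calling .call() directly bypasses Keras' reduction and shape handling
print(loss_fn.call(y_true, y_pred))  # per-sample loss vector of shape [2]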