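FYI, in case it is useful to someone else, here is a complete script that logs side-by-side training and validation confusion matrices to ClearML as a single Plotly figure: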
` import tensorflow as tf
import numpy as np
from plotly.subplots import make_subplots
import plotly.graph_objects as go
import itertools
from clearml import Task
def get_trace(z, series, classes, colorscale='blues', showscale=True, verbose=False):
    """Build a Plotly heatmap trace (as a plain dict) for one confusion matrix.

    Args:
        z: 2-D matrix of cell values. Both ``z`` and the y labels are
            reversed row-wise, so the row for ``classes[0]`` is drawn at the
            top of the heatmap.
        series: ``'train'`` places the trace on the first subplot's axes
            (``x1``/``y1``); any other value uses the second (``x2``/``y2``).
        classes: Labels for both axes, in matrix row/column order.
        colorscale: Plotly colorscale name.
        showscale: Whether to draw the colorbar for this trace.
        verbose: If true, print ``z`` before building the trace.

    Returns:
        A dict describing a ``heatmap`` trace. The colour range is fixed to
        [0, 1], so ``z`` is expected to be normalized.
    """
    if verbose:
        print(z)
    ind = '1' if series == 'train' else '2'
    return dict(
        type="heatmap",
        z=z[::-1],
        x=classes,
        y=classes[::-1],
        colorscale=colorscale,
        # Honor the caller's choice (previously hard-coded to True).
        showscale=showscale,
        reversescale=False,
        zmin=0,
        zmax=1,
        xaxis=f'x{ind}',
        yaxis=f'y{ind}',
        name=series
    )
def add_annotations(z, series, classes, format='.2f', zmid=0.5, min_text_color='black',
                    max_text_color='white'):
    """Create one text annotation per confusion-matrix cell.

    Each cell ``z[i, j]`` is rendered as ``f'{z[i, j]:{format}}'`` at
    ``x=classes[j]``, ``y=classes[i]`` on the axes of the given subplot
    (``x1``/``y1`` for ``'train'``, ``x2``/``y2`` otherwise). The text
    colour is ``min_text_color`` when the value is below ``zmid`` and
    ``max_text_color`` otherwise, to keep it readable on the heatmap.
    Only the leading ``z.shape[0] x z.shape[0]`` block of ``z`` is annotated.

    Returns:
        A list of ``go.layout.Annotation`` objects, column by column.
    """
    suffix = '1' if series == 'train' else '2'
    size = z.shape[0]
    cells = []
    # Outer loop walks matrix columns (x), inner loop walks matrix rows (y).
    for j in range(size):
        for i in range(size):
            value = z[i, j]
            if value < zmid:
                color = min_text_color
            else:
                color = max_text_color
            cells.append(
                go.layout.Annotation(
                    text=f'{value:{format}}',
                    x=classes[j],
                    y=classes[i],
                    xref=f'x{suffix}',
                    yref=f'y{suffix}',
                    font=dict(color=color),
                    showarrow=False,
                )
            )
    return cells
def add_subtitles():
    """Return the shared axis titles as paper-anchored annotation dicts.

    'Predicted label' is centred below the plotting area; 'True label' is
    placed to its left and rotated by -90 degrees.
    """
    def paper_label(text, x, y, textangle=None):
        # A fresh font dict per label, so the two annotations never share state.
        label = {'font': {'color': "black", 'size': 15},
                 'x': x,
                 'y': y,
                 'showarrow': False,
                 'text': text}
        if textangle is not None:
            label['textangle'] = textangle
        label['xref'] = 'paper'
        label['yref'] = 'paper'
        return label

    return [
        paper_label('Predicted label', 0.5, -0.2),
        paper_label('True label', -0.1, 0.5, textangle=-90),
    ]
def plot_train_val_confusion_matrices(z_train,
                                      z_val,
                                      classes_labels,
                                      width=1000,
                                      height=500
                                      ):
    """Plot training and validation confusion matrices side by side.

    Args:
        z_train: Confusion matrix for the training split (left subplot).
        z_val: Confusion matrix for the validation split (right subplot).
        classes_labels: Class names used for both axes of both heatmaps.
        width: Figure width passed to the layout.
        height: Figure height passed to the layout.

    Returns:
        The assembled Plotly figure, titled 'Confusion matrix', with subplot
        titles 'Training' / 'Validation', per-cell value annotations and
        shared 'Predicted label' / 'True label' axis titles.
    """
    fig = make_subplots(rows=1, cols=2, subplot_titles=('Training', 'Validation'))
    fig.add_trace(get_trace(z_train, 'train', classes_labels))
    fig.add_trace(get_trace(z_val, 'val', classes_labels))

    # Shared axis titles (added exactly once) plus per-cell value labels.
    annotations = add_subtitles()
    annotations.extend(add_annotations(z_train, 'train', classes_labels))
    annotations.extend(add_annotations(z_val, 'val', classes_labels))

    # make_subplots stores the subplot titles as layout annotations, and
    # update_layout(annotations=...) merges a new list element-wise into the
    # existing one, which would overwrite those titles. Append instead.
    fig.layout.annotations = tuple(fig.layout.annotations) + tuple(annotations)

    fig.update_layout(
        margin=dict(l=100, r=100, t=100, b=100),
        width=width,
        height=height,
        title_text='Confusion matrix',
    )
    return fig
# Demo: log randomly generated train/validation confusion matrices to ClearML.
task = Task.init('scripts')

num_classes = 6
classes = [str(i) for i in range(num_classes)]  # numeric names; not used below
classes_labels = ['label_' + str(i) for i in range(num_classes)]


def _random_normalized_confusion(n_classes, num_samples=100):
    """Return a row-normalized confusion matrix for random labels and predictions.

    ``num_classes`` is passed to ``tf.math.confusion_matrix`` so the result is
    always ``n_classes x n_classes``. Otherwise its size is inferred from the
    largest label seen and may not match ``classes_labels``. A class that
    never occurs as a true label yields an all-zero row, which is left at 0
    instead of becoming 0/0 = NaN.
    """
    true_labels = np.random.randint(n_classes, size=num_samples)
    predicted_labels = np.random.randint(n_classes, size=num_samples)
    conf = tf.math.confusion_matrix(true_labels, predicted_labels,
                                    num_classes=n_classes).numpy()
    row_sums = conf.sum(axis=1, keepdims=True)
    return np.divide(conf, row_sums, out=np.zeros(conf.shape, dtype=float),
                     where=row_sums != 0)


z_train = _random_normalized_confusion(num_classes)
z_val = _random_normalized_confusion(num_classes)

fig = plot_train_val_confusion_matrices(z_train, z_val, classes_labels, width=1000, height=500)
task.get_logger().report_plotly(title='Confusion matrix', series='', iteration=0, figure=fig)