https://stackoverflow.com/questions/60860121/plotly-how-to-make-an-annotated-confusion-matrix-using-a-heatmap
MagnificentSeaurchin79, see the plotly example here:
https://allegro.ai/clearml/docs/docs/examples/reporting/plotly_reporting.html
Thanks! That was the script I used, but for some reason making two side-by-side (sbs) plots was a bit more complicated than just stacking two.
but I was finally able to do it:
BTW, I think this should be the output of `report_confusion_matrix`
...what do you think?
I would like to be able to compare train/val confusion matrices side by side,
like what I see in the debug samples:
Hi MagnificentSeaurchin79
Unfortunately, there is currently no way to reorder the plots, but you have a valid point. I suggest opening a GitHub UX issue.
Regarding the debug samples, the difference is that the confusion matrix report is actually metadata (you can get these numbers through the API or by downloading them), whereas the debug samples are static images.
BTW: you can try to produce an interactive side-by-side confusion matrix with plotly and report it with the logger's `report_plotly` method.
FYI, in case it is useful for someone else:
import itertools

import numpy as np
import plotly.graph_objects as go
import tensorflow as tf
from clearml import Task
from plotly.subplots import make_subplots
def get_trace(z, series, classes, colorscale='blues', showscale=True, verbose=False,
              zmin=0, zmax=1):
    """Build a plotly ``heatmap`` trace dict for one confusion matrix.

    The matrix rows are reversed and the y labels reversed to match, so the
    label ``classes[k]`` stays aligned with row ``z[k]`` while the first class
    is drawn in the top row.

    Args:
        z: Square 2-D matrix supporting ``[::-1]`` slicing. In the script below
            it is row-normalized with rows = true labels (TODO confirm for
            other callers).
        series: ``'train'`` places the trace on the first subplot axes
            (``x1``/``y1``); any other value uses ``x2``/``y2``. Also used as
            the trace name.
        classes: Class labels used for both axes.
        colorscale: Plotly colorscale name.
        showscale: Whether this trace draws a colorbar.
        verbose: If true, print ``z`` before building the trace.
        zmin: Lower bound of the colour range (default fits normalized data).
        zmax: Upper bound of the colour range (default fits normalized data).

    Returns:
        A dict describing the heatmap trace, usable with ``fig.add_trace``.
    """
    if verbose:
        print(z)
    ind = '1' if series == 'train' else '2'
    return dict(
        type="heatmap",
        z=z[::-1],
        x=classes,
        y=classes[::-1],
        colorscale=colorscale,
        # Honour the caller's choice (previously hard-coded to True).
        showscale=showscale,
        reversescale=False,
        zmin=zmin,
        zmax=zmax,
        xaxis=f'x{ind}',
        yaxis=f'y{ind}',
        name=series,
    )
def add_annotations(z, series, classes, format='.2f', zmid=0.5, min_text_color='black',
                    max_text_color='white'):
    """Return one text annotation per cell of confusion matrix ``z``.

    The value ``z[i, j]`` is written at ``x=classes[j]``, ``y=classes[i]``.
    With the heatmap built by ``get_trace``, that position is the cell
    holding ``z[i, j]``.

    Args:
        z: 2-D numpy array (``z.shape`` and ``z[i, j]`` indexing are used);
            assumed square with one row/column per entry of ``classes``.
        series: ``'train'`` targets axes ``x1``/``y1``; anything else targets
            ``x2``/``y2``.
        classes: Class labels used as the annotation x/y categories.
        format: Format spec applied to each value. The name shadows the
            builtin but is kept for keyword-argument compatibility.
        zmid: Values below this use ``min_text_color``; the rest use
            ``max_text_color``.
        min_text_color: Text colour for values below ``zmid``.
        max_text_color: Text colour for values at or above ``zmid``.

    Returns:
        A list of ``go.layout.Annotation`` objects. Columns are the outer
        loop and rows the inner loop.
    """
    axis_id = '2' if series != 'train' else '1'
    size = z.shape[0]
    result = []
    for x_idx in range(size):
        for y_idx in range(size):
            cell = z[y_idx, x_idx]
            if cell < zmid:
                color = min_text_color
            else:
                color = max_text_color
            result.append(
                go.layout.Annotation(
                    text=f'{cell:{format}}',
                    x=classes[x_idx],
                    y=classes[y_idx],
                    xref='x' + axis_id,
                    yref='y' + axis_id,
                    font=dict(color=color),
                    showarrow=False,
                )
            )
    return result
def add_subtitles():
    """Return the shared axis captions as paper-referenced annotation dicts.

    The horizontal caption ('Predicted label') is centred below the plotting
    area. The vertical caption ('True label') is rotated by -90 degrees and
    placed to its left. Both use paper coordinates, so they are positioned
    relative to the whole figure rather than to a single subplot.

    Returns:
        A list of two annotation dicts, each with its own font dict.
    """
    x_caption = {
        'font': {'color': "black", 'size': 15},
        'x': 0.5,
        'y': -0.2,
        'showarrow': False,
        'text': 'Predicted label',
        'xref': 'paper',
        'yref': 'paper',
    }
    y_caption = {
        'font': {'color': "black", 'size': 15},
        'x': -0.1,
        'y': 0.5,
        'showarrow': False,
        'text': 'True label',
        'textangle': -90,
        'xref': 'paper',
        'yref': 'paper',
    }
    return [x_caption, y_caption]
def plot_train_val_confusion_matrices(z_train,
                                      z_val,
                                      classes_labels,
                                      width=1000,
                                      height=500
                                      ):
    """Draw the training and validation confusion matrices side by side.

    Args:
        z_train: Training confusion matrix (numpy array, as required by
            ``add_annotations``). Assumed normalized to [0, 1] to match the
            heatmap's default colour range.
        z_val: Validation confusion matrix, same expectations as ``z_train``.
        classes_labels: Class names used for both axes of both subplots.
        width: Figure width in pixels.
        height: Figure height in pixels.

    Returns:
        A plotly figure with two heatmap subplots titled 'Training' and
        'Validation', one value annotation per cell, and a single pair of
        'Predicted label' / 'True label' captions.
    """
    fig = make_subplots(rows=1, cols=2, subplot_titles=('Training', 'Validation'))
    fig.add_trace(get_trace(z_train, 'train', classes_labels))
    # Both heatmaps share the same colorscale and colour range, so one colorbar
    # suffices; otherwise two would be drawn at the same default position.
    fig.add_trace(get_trace(z_val, 'val', classes_labels, showscale=False))

    # The axis captions are added exactly once.
    annotations = add_subtitles()
    annotations.extend(add_annotations(z_train, 'train', classes_labels))
    annotations.extend(add_annotations(z_val, 'val', classes_labels))
    # make_subplots stores the subplot titles as layout annotations. Assign the
    # concatenation explicitly: update_layout(annotations=...) merges the new
    # list into the existing entries element-wise and would overwrite the titles.
    fig.layout.annotations = tuple(fig.layout.annotations) + tuple(annotations)

    fig.update_layout(
        margin=dict(l=100, r=100, t=100, b=100),
        width=width,
        height=height,
        title_text='Confusion matrix',
    )
    return fig
def _random_normalized_confusion(num_classes, num_samples=100):
    """Return a row-normalized confusion matrix of random labels vs. predictions.

    ``num_classes`` is passed to ``tf.math.confusion_matrix`` so the result is
    always ``num_classes`` x ``num_classes``, even when a class is never drawn.
    Rows with no true samples are left at 0 instead of dividing by zero (which
    would produce NaN).
    """
    true_labels = np.random.randint(num_classes, size=num_samples)
    predicted_labels = np.random.randint(num_classes, size=num_samples)
    conf = tf.math.confusion_matrix(true_labels, predicted_labels,
                                    num_classes=num_classes).numpy()
    row_sums = conf.sum(axis=1, keepdims=True)
    return np.divide(conf, row_sums,
                     out=np.zeros(conf.shape, dtype=float),
                     where=row_sums != 0)


task = Task.init('scripts')
num_classes = 6
classes_labels = ['label_' + str(i) for i in range(num_classes)]
z_train = _random_normalized_confusion(num_classes)
z_val = _random_normalized_confusion(num_classes)
fig = plot_train_val_confusion_matrices(z_train, z_val, classes_labels, width=1000, height=500)
task.get_logger().report_plotly(title='Confusion matrix', series='', iteration=0, figure=fig)