symmetrized confusion matrix #624
yanivboker started this conversation in Show and tell
I made a symmetrized confusion matrix that adds up the confusion between each pair of labels in both directions (e.g. shirt->t-shirt and t-shirt->shirt), so every pair of easily confused classes shows up as a single count.
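Written out as a formula, with $C$ the $K \times K$ confusion matrix and $S$ the symmetrized one, this is what `symmetrize_confusion_matrix` computes:

```math
S = C + C^{\top} - \operatorname{diag}(C_{11}, \dots, C_{KK}),
\qquad
S_{ij} =
\begin{cases}
C_{ij} + C_{ji}, & i \neq j,\\
C_{ii}, & i = j.
\end{cases}
```

$C + C^{\top}$ counts every diagonal entry twice, which is why one copy of the diagonal is subtracted. The result satisfies $S_{ij} = S_{ji}$, so each unordered pair of classes carries one combined mistake count.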
```python
# Assumes y_pred_tensor (predicted labels), test_data (the Fashion-MNIST test set)
# and class_names were created earlier in the notebook (not shown here)
from torchmetrics import ConfusionMatrix
from mlxtend.plotting import plot_confusion_matrix

# 2. Set up a confusion matrix instance and compare predictions to targets
confmat = ConfusionMatrix(num_classes=len(class_names), task='multiclass')
confmat_tensor = confmat(preds=y_pred_tensor,
                         target=test_data.targets)

# 3. Plot the confusion matrix
fig, ax = plot_confusion_matrix(
    conf_mat=confmat_tensor.numpy(),  # matplotlib likes working with NumPy
    class_names=class_names,  # turn the row and column labels into class names
    figsize=(10, 7)
);
```
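To make the row/column convention explicit, here is a plain-NumPy sketch of the counting a multiclass confusion matrix does: cell `[i, j]` counts samples whose true label is `i` and whose predicted label is `j` (to my understanding, the same convention torchmetrics uses). `confusion_counts` is a hypothetical helper for illustration, not torchmetrics' implementation.

```python
import numpy as np

def confusion_counts(preds, target, num_classes):
    """Hypothetical helper: count (true label, predicted label) pairs.

    Cell [i, j] holds how many samples with true label i were predicted as j.
    """
    matrix = np.zeros((num_classes, num_classes), dtype=np.int64)
    # np.add.at accumulates correctly even when the same (i, j) pair repeats
    np.add.at(matrix, (np.asarray(target), np.asarray(preds)), 1)
    return matrix

# Toy labels for 4 samples and 3 classes
target = np.array([2, 1, 0, 0])
preds = np.array([2, 1, 0, 1])
print(confusion_counts(preds, target, num_classes=3))
# [[1 1 0]
#  [0 1 0]
#  [0 0 1]]
```

The single off-diagonal 1 at `[0, 1]` is the sample with true class 0 that was predicted as class 1.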
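```python
import numpy as np

# Function to symmetrize a confusion matrix: off-diagonal cell (i, j) becomes
# C[i, j] + C[j, i]; matrix + matrix.T doubles the diagonal, so one copy is subtracted
def symmetrize_confusion_matrix(matrix):
    matrix = np.asarray(matrix)  # also accepts a CPU torch tensor
    return matrix + matrix.T - np.diag(matrix.diagonal())
```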
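A quick sanity check of the function on a small made-up 3-class matrix (the counts are invented for illustration, with rows as true labels and columns as predictions). The function is repeated so the snippet runs on its own.

```python
import numpy as np

def symmetrize_confusion_matrix(matrix):
    # Same formula as in the post: C + C^T double-counts the diagonal, so subtract it once
    matrix = np.asarray(matrix)
    return matrix + matrix.T - np.diag(matrix.diagonal())

# Made-up counts for three classes: T-shirt/top, Pullover, Shirt
C = np.array([[80,  2, 18],
              [ 1, 90,  9],
              [25,  5, 70]])

S = symmetrize_confusion_matrix(C)
print(S)
# [[80  3 43]
#  [ 3 90 14]
#  [43 14 70]]
```

T-shirt/top and Shirt are confused 18 times in one direction and 25 in the other, so the symmetrized cell is 43 while the diagonal stays unchanged. Because every off-diagonal pair is stored twice in `S`, `np.triu(S).sum()` (upper triangle plus diagonal) equals the total sample count `C.sum()`, whereas `S.sum()` does not.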
```python
import matplotlib.pyplot as plt
import numpy as np

# Function to plot the symmetrized confusion matrix
def plot_symmetrized_confusion_matrix(matrix, class_names):
    plt.figure(figsize=(10, 10))
    plt.imshow(matrix, interpolation='nearest', cmap=plt.cm.Blues)
    plt.title('Symmetrized Confusion Matrix')
    plt.colorbar()
    tick_marks = np.arange(len(class_names))
    plt.xticks(tick_marks, class_names, rotation=45, ha='right')
    plt.yticks(tick_marks, class_names)
    plt.tight_layout()
    plt.show()
```
```python
# Class names for Fashion-MNIST
class_names = ['T-shirt/top', 'Trouser', 'Pullover', 'Dress', 'Coat',
               'Sandal', 'Shirt', 'Sneaker', 'Bag', 'Ankle boot']

# Create the symmetrized confusion matrix and plot it
# (.numpy() converts the torchmetrics tensor so the NumPy-based helpers work on it)
symmetrized_conf_matrix = symmetrize_confusion_matrix(confmat_tensor.numpy())
plot_symmetrized_confusion_matrix(symmetrized_conf_matrix, class_names)
```
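Once the pairs are combined, the symmetrized matrix can also be read as a ranked list. A minimal sketch, assuming a hypothetical helper `most_confused_pairs`, that walks the strict upper triangle (each unordered pair once) and sorts by count; the toy matrix uses made-up numbers.

```python
import numpy as np

def most_confused_pairs(sym_matrix, class_names, top_k=5):
    """Hypothetical helper: return the top_k label pairs by symmetrized confusion count."""
    sym_matrix = np.asarray(sym_matrix)
    # Strict upper triangle: every unordered pair (i, j) with i < j appears once
    rows, cols = np.triu_indices(len(class_names), k=1)
    counts = sym_matrix[rows, cols]
    order = np.argsort(counts)[::-1][:top_k]  # largest counts first
    return [(class_names[rows[k]], class_names[cols[k]], int(counts[k]))
            for k in order]

# Toy symmetrized matrix (made-up counts)
names = ['T-shirt/top', 'Pullover', 'Shirt']
S = np.array([[80,  3, 43],
              [ 3, 90, 14],
              [43, 14, 70]])
for a, b, n in most_confused_pairs(S, names, top_k=2):
    print(f'{a} <-> {b}: {n}')
# T-shirt/top <-> Shirt: 43
# Pullover <-> Shirt: 14
```

On the real data this would be `most_confused_pairs(symmetrized_conf_matrix, class_names)`. Note that `np.argsort` does not guarantee a stable order when two pairs have the same count.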