Classification Metrics¶
Classification metrics quantify the performance of a model that predicts categorical labels. Unlike regression, where evaluation measures how far predictions fall from the true values, classification evaluation focuses on whether samples are assigned to their correct classes.
Elements of evaluation¶
Most classification metrics are derived from four fundamental outcomes in a binary classification task:
- True Positives (TP): The model correctly predicted the positive class.
- True Negatives (TN): The model correctly predicted the negative class.
- False Positives (FP): The model predicted the positive class, but the actual class was negative (Type I error).
- False Negatives (FN): The model predicted the negative class, but the actual class was positive (Type II error).
These values are usually aggregated in a Confusion Matrix.
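As a quick illustration, the four counts can be computed directly with NumPy boolean masks. This is a hand-rolled sketch on made-up binary labels, independent of sorix:

```python
import numpy as np

# Hypothetical binary labels (1 = positive, 0 = negative)
y_true = np.array([1, 0, 0, 1, 1, 0])
y_pred = np.array([1, 1, 0, 1, 0, 0])

tp = int(np.sum((y_true == 1) & (y_pred == 1)))  # correctly predicted positives
tn = int(np.sum((y_true == 0) & (y_pred == 0)))  # correctly predicted negatives
fp = int(np.sum((y_true == 0) & (y_pred == 1)))  # Type I errors
fn = int(np.sum((y_true == 1) & (y_pred == 0)))  # Type II errors

print(tp, tn, fp, fn)  # 2 2 1 1
```

These same arrays reappear in the precision, recall, and F1 examples below, so the four counts here feed all of those metrics.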
# Uncomment the next line and run this cell to install sorix
#!pip install 'sorix @ git+https://github.com/Mitchell-Mirano/sorix.git@main'
import numpy as np
from sorix import tensor
from sorix.metrics import (
accuracy_score,
confusion_matrix,
precision_score,
recall_score,
f1_score,
classification_report
)
Confusion Matrix¶
The confusion matrix is a table that summarizes the performance of a classification algorithm. Each row of the matrix represents the instances in an actual class, while each column represents the instances in a predicted class.
For binary classification:
$$ \begin{pmatrix} TN & FP \\ FN & TP \end{pmatrix} $$
Example¶
Let's see a confusion matrix for a 3-class problem.
y_true = tensor([2, 0, 2, 2, 0, 1])
y_pred = tensor([0, 0, 2, 2, 0, 2])
cm = confusion_matrix(y_true, y_pred)
print("Confusion Matrix:")
print(cm)
Confusion Matrix:
[[2 0 0]
 [0 0 1]
 [1 0 2]]
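To see where these counts come from, the same matrix can be rebuilt by hand with NumPy. This is a cross-check sketch, not sorix's actual implementation:

```python
import numpy as np

y_true = np.array([2, 0, 2, 2, 0, 1])
y_pred = np.array([0, 0, 2, 2, 0, 2])

n_classes = 3
cm = np.zeros((n_classes, n_classes), dtype=int)
# Rows index the actual class, columns the predicted class
for t, p in zip(y_true, y_pred):
    cm[t, p] += 1

print(cm)
# [[2 0 0]
#  [0 0 1]
#  [1 0 2]]
```

Reading the result: row 2 says that of the three actual class-2 samples, two were predicted correctly and one was misclassified as class 0.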
Accuracy Score¶
Accuracy is the ratio of correct predictions to the total number of input samples. It is suitable when classes are balanced.
$$ \text{Accuracy} = \frac{TP + TN}{TP + TN + FP + FN} = \frac{1}{n} \sum_{i=1}^{n} \mathbb{I}(y_i = \hat{y}_i) $$
y_true = tensor([0, 1, 1, 0, 1, 1])
y_pred = tensor([0, 1, 0, 0, 1, 1])
# Correct: 5, Total: 6
print(f"Accuracy Score: {accuracy_score(y_true, y_pred):.4f}")
Accuracy Score: 0.8333
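Equivalently, per the indicator form of the formula above, accuracy is just the mean of the element-wise match between predictions and labels. A NumPy cross-check on the same arrays:

```python
import numpy as np

y_true = np.array([0, 1, 1, 0, 1, 1])
y_pred = np.array([0, 1, 0, 0, 1, 1])

# (y_true == y_pred) is a boolean array; its mean is the fraction correct
accuracy = (y_true == y_pred).mean()
print(f"{accuracy:.4f}")  # 0.8333
```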
Precision¶
Precision (also called positive predictive value) measures the model's accuracy in predicting the positive class. It answers: "Of all the samples predicted as positive, how many were actually positive?"
$$ \text{Precision} = \frac{TP}{TP + FP} $$
y_true = tensor([1, 0, 0, 1, 1, 0])
y_pred = tensor([1, 1, 0, 1, 0, 0])
# Actual positives at indices 0, 3, 4
# Predicted positives at indices 0, 1, 3
# TP: 2 (indices 0 and 3), FP: 1 (index 1)
precision = precision_score(y_true, y_pred, average='binary')
print(f"Binary Precision: {precision:.4f}") # 2 / (2 + 1) = 0.6667
Binary Precision: 0.6667
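One edge case worth knowing: if the model never predicts the positive class, TP + FP = 0 and the ratio is undefined. A common convention (scikit-learn, for example, returns 0.0 and emits a warning) is to report 0.0. The sketch below hand-computes precision with that guard; it does not show sorix's behaviour, which may differ:

```python
import numpy as np

y_true = np.array([1, 0, 1])
y_pred = np.array([0, 0, 0])  # the model never predicts the positive class

tp = int(np.sum((y_true == 1) & (y_pred == 1)))
fp = int(np.sum((y_true == 0) & (y_pred == 1)))

# TP + FP == 0, so TP / (TP + FP) is undefined; fall back to 0.0
precision = tp / (tp + fp) if (tp + fp) > 0 else 0.0
print(precision)  # 0.0
```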
Recall (Sensitivity)¶
Recall measures the ability of the classifier to find all the positive samples. It answers: "Of all the actual positive samples, how many did we correctly identify?"
$$ \text{Recall} = \frac{TP}{TP + FN} $$
y_true = tensor([1, 0, 0, 1, 1, 0])
y_pred = tensor([1, 1, 0, 1, 0, 0])
# TP: 2, FN: 1 (index 4 was positive but predicted as 0)
recall = recall_score(y_true, y_pred, average='binary')
print(f"Binary Recall: {recall:.4f}") # 2 / (2 + 1) = 0.6667
Binary Recall: 0.6667
F1-Score¶
The F1-Score is the harmonic mean of precision and recall. It is useful when you need to balance both metrics and is preferred over accuracy for imbalanced datasets.
$$ \text{F1} = 2 \cdot \frac{\text{Precision} \cdot \text{Recall}}{\text{Precision} + \text{Recall}} $$
y_true = tensor([1, 0, 0, 1, 1, 0])
y_pred = tensor([1, 1, 0, 1, 0, 0])
f1 = f1_score(y_true, y_pred, average='binary')
print(f"Binary F1-Score: {f1:.4f}")
Binary F1-Score: 0.6667
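Since precision and recall both came out to 2/3 above, we can confirm this value by plugging them into the harmonic-mean formula directly:

```python
# Recomputing F1 from the precision and recall found above (both 2/3)
precision = 2 / 3
recall = 2 / 3
f1 = 2 * precision * recall / (precision + recall)
print(f"{f1:.4f}")  # 0.6667
```

The harmonic mean of two equal values is that value itself, which is why F1 matches precision and recall exactly in this example.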
Multi-class Classification Report¶
The classification_report shows all the above metrics per class in a multi-class setting, plus aggregate metrics such as the macro and weighted averages.
y_true = tensor([0, 1, 2, 2, 0, 1])
y_pred = tensor([0, 0, 2, 2, 0, 2])
print(classification_report(y_true, y_pred))
              precision    recall  f1-score   support

           0       0.67      1.00      0.80         2
           1       0.00      0.00      0.00         2
           2       0.67      1.00      0.80         2

    accuracy                           0.67         6
   macro avg       0.44      0.67      0.53         6
weighted avg       0.44      0.67      0.53         6
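The macro average is the unweighted mean of the per-class values, while the weighted average weights each class by its support. A NumPy cross-check using the per-class precisions from the report above (the supports are all 2 here, so the two averages coincide; they diverge when classes are imbalanced):

```python
import numpy as np

# Per-class precision and support taken from the report above
precision = np.array([2 / 3, 0.0, 2 / 3])
support = np.array([2, 2, 2])

macro_p = precision.mean()                           # unweighted mean over classes
weighted_p = np.average(precision, weights=support)  # mean weighted by support

print(f"macro precision:    {macro_p:.2f}")     # 0.44
print(f"weighted precision: {weighted_p:.2f}")  # 0.44
```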