This Python script visualizes a detailed classification evaluation: confusion matrices (raw and normalized) and per-class performance metrics (precision, recall, and F1 score).
from typing import List, Dict, Tuple, Union

import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns


def plot_advanced_confusion_matrix(
    cm: np.ndarray,
    class_names: List[str],
    save_path: str
) -> None:
"""
Plot an enhanced confusion matrix analysis with multiple metrics and visualizations.
Args:
cm (np.ndarray): Confusion matrix array where rows represent true labels
and columns represent predicted labels
class_names (List[str]): List of class names for labeling
save_path (str): Path where the plot will be saved
Returns:
None: Saves the plot to the specified path
The function creates a comprehensive visualization including:
1. Raw confusion matrix heatmap
2. Normalized confusion matrix heatmap
3. Per-class performance metrics (Precision, Recall, F1)
4. Performance comparison bar chart
"""
    # Calculate performance metrics for each class
    precision = np.diag(cm) / np.sum(cm, axis=0)  # Precision = TP / (TP + FP)
    recall = np.diag(cm) / np.sum(cm, axis=1)  # Recall = TP / (TP + FN)
    f1 = 2 * (precision * recall) / (precision + recall)  # F1 = 2 * (P * R) / (P + R)
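    # Caveat: if a class has an empty column (never predicted) or an empty row
    # (no true samples), the divisions above emit NaN/inf warnings. One possible
    # guard (a suggested safeguard, not part of the original gist) is to wrap
    # them in `with np.errstate(divide='ignore', invalid='ignore'):` and pass
    # the results through np.nan_to_num.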

    # Create figure with GridSpec for subplot arrangement
    fig = plt.figure(figsize=(20, 12))
    gs = plt.GridSpec(2, 3, figure=fig)
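    # Layout: the three heatmaps fill the top row of the 2x3 grid; the bar chart
    # later spans the entire bottom row via gs[1, :].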

    # 1. Raw Confusion Matrix Heatmap (Top Left)
    ax1 = fig.add_subplot(gs[0, 0])
    sns.heatmap(
        cm,
        annot=True,
        fmt='d',  # Display as integers
        cmap='Blues',
        ax=ax1,
        annot_kws={'size': 8}
    )
    ax1.set_xlabel('Predicted Label', fontsize=10)
    ax1.set_ylabel('True Label', fontsize=10)
    ax1.set_title('Confusion Matrix', fontsize=12, pad=10)
    ax1.set_xticklabels(class_names, rotation=45)
    ax1.set_yticklabels(class_names, rotation=0)

    # 2. Normalized Confusion Matrix (Top Middle)
    ax2 = fig.add_subplot(gs[0, 1])
    norm_cm = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis]  # Normalize by row
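    # With row normalization, each diagonal cell equals that class's recall, so
    # this panel shows how each true class's samples spread across predictions.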
    sns.heatmap(
        norm_cm,
        annot=True,
        fmt='.2%',  # Display as percentages
        cmap='RdYlBu_r',
        ax=ax2,
        annot_kws={'size': 8}
    )
    ax2.set_xlabel('Predicted Label', fontsize=10)
    ax2.set_ylabel('True Label', fontsize=10)
    ax2.set_title('Normalized Confusion Matrix', fontsize=12, pad=10)
    ax2.set_xticklabels(class_names, rotation=45)
    ax2.set_yticklabels(class_names, rotation=0)

    # 3. Performance Metrics Heatmap (Top Right)
    ax3 = fig.add_subplot(gs[0, 2])
    metrics_data = np.array([precision, recall, f1]).T
    sns.heatmap(
        metrics_data,
        annot=True,
        fmt='.2%',
        cmap='YlOrRd',
        xticklabels=['Precision', 'Recall', 'F1'],
        yticklabels=class_names,
        ax=ax3,
        annot_kws={'size': 8}
    )
    ax3.set_title('Performance Metrics by Class', fontsize=12, pad=10)

    # 4. Performance Metrics Bar Chart (Bottom)
    ax4 = fig.add_subplot(gs[1, :])

    # Calculate overall metrics
    overall_precision = np.mean(precision)
    overall_recall = np.mean(recall)
    overall_f1 = np.mean(f1)
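    # These unweighted means are macro averages (the same quantity sklearn
    # reports with average='macro'): every class counts equally regardless of
    # its sample count.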

    # Combine overall and per-class metrics
    all_precision = np.concatenate(([overall_precision], precision))
    all_recall = np.concatenate(([overall_recall], recall))
    all_f1 = np.concatenate(([overall_f1], f1))
    all_class_names = ['Overall'] + class_names

    # Plot grouped bar chart
    width = 0.25
    x = np.arange(len(all_class_names))
    ax4.bar(x - width, all_precision, width, label='Precision', color='skyblue')
    ax4.bar(x, all_recall, width, label='Recall', color='lightgreen')
    ax4.bar(x + width, all_f1, width, label='F1', color='salmon')
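    # Offsetting the series by one bar width keeps the three bars for each class
    # centered on its x tick.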

    def add_value_labels(ax: plt.Axes, rects: List[plt.Rectangle]) -> None:
        """Add value labels on top of each bar in the chart."""
        for rect in rects:
            height = rect.get_height()
            ax.text(
                rect.get_x() + rect.get_width() / 2.,
                height,
                f'{height:.2%}',
                ha='center',
                va='bottom',
                rotation=0,
                fontsize=8
            )

    # Configure bar chart
    ax4.set_ylabel('Score')
    ax4.set_title('Performance Metrics Comparison')
    ax4.set_xticks(x)
    ax4.set_xticklabels(all_class_names, rotation=0)
    ax4.legend()
    ax4.grid(True, linestyle='--', alpha=0.7)
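
    # ax4.patches holds the Rectangle artists from all three bar() calls, so a
    # single pass over it annotates every bar.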
    # Add value labels to bars
    add_value_labels(ax4, ax4.patches)

    # Adjust layout and save
    plt.tight_layout()
    plt.savefig(f'{save_path}.png', dpi=300, bbox_inches='tight', pad_inches=0.1)
    plt.close()


def analyze_errors(
    cm: np.ndarray,
    class_names: List[str]
) -> List[Dict[str, Union[str, int, float, List[Tuple[str, int]]]]]:
    """
    Analyze error patterns in the confusion matrix.

    Args:
        cm (np.ndarray): Confusion matrix array
        class_names (List[str]): List of class names

    Returns:
        List[Dict]: List of dictionaries containing the error analysis for each
        class that has at least one misclassified sample (classes with no errors
        are omitted). Each dictionary contains:
            - true_class: name of the true class
            - total_samples: total number of samples for this class
            - correct_predictions: number of correct predictions
            - common_mistakes: list of (predicted_class, count) tuples for the most common errors
    """
    n_classes = len(class_names)
    error_analysis = []
    for true_class in range(n_classes):
        # Find most common misclassifications
        predictions = cm[true_class]
        wrong_predictions = [
            (class_names[i], predictions[i])
            for i in range(n_classes)
            if i != true_class and predictions[i] > 0
        ]
        # Sort by error count, most frequent first
        wrong_predictions.sort(key=lambda x: x[1], reverse=True)
        if wrong_predictions:
            error_analysis.append({
                'true_class': class_names[true_class],
                'total_samples': np.sum(predictions),
                'correct_predictions': predictions[true_class],
                'common_mistakes': wrong_predictions[:3]  # Top 3 most common errors
            })
    return error_analysis


# Example usage
if __name__ == "__main__":
    # Sample confusion matrix
    cm = np.array([
        [94, 3, 0, 0, 0, 0, 0, 1],
        [6, 78, 6, 0, 1, 0, 2, 7],
        [0, 7, 86, 2, 2, 1, 1, 2],
        [0, 0, 0, 99, 1, 0, 0, 0],
        [2, 2, 3, 2, 47, 17, 13, 13],
        [2, 2, 3, 2, 22, 26, 23, 20],
        [2, 2, 2, 1, 12, 15, 29, 38],
        [2, 4, 1, 0, 1, 1, 6, 87]
    ])
    class_names = [
        'Class 0', 'Class 1', 'Class 2', 'Class 3',
        'Class 4', 'Class 5', 'Class 6', 'Class 7'
    ]

    # Generate visualizations
    plot_advanced_confusion_matrix(
        cm,
        class_names,
        '/home/zhl/regression_and_prediction/2024-10-31/advanced_confusion_matrix'
    )

    # Perform error analysis
    error_analysis = analyze_errors(cm, class_names)

    # Print error analysis results
    for analysis in error_analysis:
        print(f"\nAnalysis for {analysis['true_class']}:")
        print(f"Total samples: {analysis['total_samples']}")
        print(
            f"Correct predictions: {analysis['correct_predictions']} "
            f"({analysis['correct_predictions'] / analysis['total_samples']:.2%})"
        )
        print("Most common mistakes:")
        for pred_class, count in analysis['common_mistakes']:
            print(
                f"  - Predicted as {pred_class}: {count} times "
                f"({count / analysis['total_samples']:.2%})"
            )
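
    # Illustrative sketch (not part of the original gist): with real model
    # output, `cm` would typically come from scikit-learn, assuming integer
    # label arrays `y_true` and `y_pred` are available:
    #
    #     from sklearn.metrics import confusion_matrix
    #     cm = confusion_matrix(y_true, y_pred, labels=range(len(class_names)))
    #
    # confusion_matrix follows the same convention as this script: rows are
    # true labels, columns are predicted labels.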