This Python script visualizes a detailed classification evaluation: the confusion matrix (raw and row-normalized) together with per-class performance metrics (precision, recall, and F1 score) and a summary bar chart. It also prints a short error analysis of the most common misclassifications for each class.
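For context, here is a minimal, hypothetical usage sketch (not part of the original gist). It assumes you already have y_true and y_pred label arrays from your own classifier, builds the confusion matrix with sklearn.metrics.confusion_matrix, and passes it to the two functions defined in the script below; the label arrays, class names, and output path are placeholders.

import numpy as np
from sklearn.metrics import confusion_matrix

# Hypothetical labels; substitute the true and predicted labels from your own model.
y_true = np.array([0, 1, 2, 2, 1, 0])
y_pred = np.array([0, 1, 2, 1, 1, 0])

cm = confusion_matrix(y_true, y_pred)   # rows = true labels, columns = predicted labels
class_names = ['Class 0', 'Class 1', 'Class 2']

plot_advanced_confusion_matrix(cm, class_names, 'confusion_matrix_report')  # saves confusion_matrix_report.png
for entry in analyze_errors(cm, class_names):
    print(entry)

The full script follows.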
from typing import Dict, List, Tuple, Union

import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
def plot_advanced_confusion_matrix(
    cm: np.ndarray,
    class_names: List[str],
    save_path: str
) -> None:
    """
    Plot an enhanced confusion matrix analysis with multiple metrics and visualizations.

    Args:
        cm (np.ndarray): Confusion matrix array where rows represent true labels
            and columns represent predicted labels
        class_names (List[str]): List of class names for labeling
        save_path (str): Path where the plot will be saved

    Returns:
        None: Saves the plot to the specified path

    The function creates a comprehensive visualization including:
        1. Raw confusion matrix heatmap
        2. Normalized confusion matrix heatmap
        3. Per-class performance metrics (Precision, Recall, F1)
        4. Performance comparison bar chart
    """
    # Calculate performance metrics for each class
    precision = np.diag(cm) / np.sum(cm, axis=0)          # Precision = TP / (TP + FP)
    recall = np.diag(cm) / np.sum(cm, axis=1)             # Recall = TP / (TP + FN)
    f1 = 2 * (precision * recall) / (precision + recall)  # F1 = 2 * (P * R) / (P + R)
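    # Note (added caveat): if a class never appears in the predictions (zero column sum)
    # or in the true labels (zero row sum), the corresponding division is 0/0 and the
    # metric for that class evaluates to NaN.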
    # Create figure with GridSpec for subplot arrangement
    fig = plt.figure(figsize=(20, 12))
    gs = plt.GridSpec(2, 3, figure=fig)

    # 1. Raw Confusion Matrix Heatmap (Top Left)
    ax1 = fig.add_subplot(gs[0, 0])
    sns.heatmap(
        cm,
        annot=True,
        fmt='d',  # Display as integers
        cmap='Blues',
        ax=ax1,
        annot_kws={'size': 8}
    )
    ax1.set_xlabel('Predicted Label', fontsize=10)
    ax1.set_ylabel('True Label', fontsize=10)
    ax1.set_title('Confusion Matrix', fontsize=12, pad=10)
    ax1.set_xticklabels(class_names, rotation=45)
    ax1.set_yticklabels(class_names, rotation=0)

    # 2. Normalized Confusion Matrix (Top Middle)
    ax2 = fig.add_subplot(gs[0, 1])
    norm_cm = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis]  # Normalize by row
    sns.heatmap(
        norm_cm,
        annot=True,
        fmt='.2%',  # Display as percentages
        cmap='RdYlBu_r',
        ax=ax2,
        annot_kws={'size': 8}
    )
    ax2.set_xlabel('Predicted Label', fontsize=10)
    ax2.set_ylabel('True Label', fontsize=10)
    ax2.set_title('Normalized Confusion Matrix', fontsize=12, pad=10)
    ax2.set_xticklabels(class_names, rotation=45)
    ax2.set_yticklabels(class_names, rotation=0)

    # 3. Performance Metrics Heatmap (Top Right)
    ax3 = fig.add_subplot(gs[0, 2])
    metrics_data = np.array([precision, recall, f1]).T
    sns.heatmap(
        metrics_data,
        annot=True,
        fmt='.2%',
        cmap='YlOrRd',
        xticklabels=['Precision', 'Recall', 'F1'],
        yticklabels=class_names,
        ax=ax3,
        annot_kws={'size': 8}
    )
    ax3.set_title('Performance Metrics by Class', fontsize=12, pad=10)

    # 4. Performance Metrics Bar Chart (Bottom)
    ax4 = fig.add_subplot(gs[1, :])

    # Calculate overall metrics
    overall_precision = np.mean(precision)
    overall_recall = np.mean(recall)
    overall_f1 = np.mean(f1)
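    # Note (added clarification): these "overall" values are unweighted (macro) averages
    # of the per-class metrics, not sample-weighted averages.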
    # Combine overall and per-class metrics
    all_precision = np.concatenate(([overall_precision], precision))
    all_recall = np.concatenate(([overall_recall], recall))
    all_f1 = np.concatenate(([overall_f1], f1))
    all_class_names = ['Overall'] + class_names

    # Plot grouped bar chart
    width = 0.25
    x = np.arange(len(all_class_names))
    ax4.bar(x - width, all_precision, width, label='Precision', color='skyblue')
    ax4.bar(x, all_recall, width, label='Recall', color='lightgreen')
    ax4.bar(x + width, all_f1, width, label='F1', color='salmon')

    def add_value_labels(ax: plt.Axes, rects: List[plt.Rectangle]) -> None:
        """Add value labels on top of each bar in the chart."""
        for rect in rects:
            height = rect.get_height()
            ax.text(
                rect.get_x() + rect.get_width() / 2.,
                height,
                f'{height:.2%}',
                ha='center',
                va='bottom',
                rotation=0,
                fontsize=8
            )

    # Configure bar chart
    ax4.set_ylabel('Score')
    ax4.set_title('Performance Metrics Comparison')
    ax4.set_xticks(x)
    ax4.set_xticklabels(all_class_names, rotation=0)
    ax4.legend()
    ax4.grid(True, linestyle='--', alpha=0.7)

    # Add value labels to bars
    add_value_labels(ax4, ax4.patches)

    # Adjust layout and save
    plt.tight_layout()
    plt.savefig(f'{save_path}.png', dpi=300, bbox_inches='tight', pad_inches=0.1)
    plt.close()


def analyze_errors(
    cm: np.ndarray,
    class_names: List[str]
) -> List[Dict[str, Union[str, int, float, List[Tuple[str, int]]]]]:
    """
    Analyze error patterns in the confusion matrix.

    Args:
        cm (np.ndarray): Confusion matrix array
        class_names (List[str]): List of class names

    Returns:
        List[Dict]: List of dictionaries containing error analysis for each class
        Each dictionary contains:
            - true_class: name of the true class
            - total_samples: total number of samples for this class
            - correct_predictions: number of correct predictions
            - common_mistakes: list of tuples (predicted_class, count) for most common errors
    """
    n_classes = len(class_names)
    error_analysis = []

    for true_class in range(n_classes):
        # Find most common misclassifications
        predictions = cm[true_class]
        wrong_predictions = [
            (class_names[i], predictions[i])
            for i in range(n_classes)
            if i != true_class and predictions[i] > 0
        ]
        wrong_predictions.sort(key=lambda x: x[1], reverse=True)
        if wrong_predictions:
            error_analysis.append({
                'true_class': class_names[true_class],
                'total_samples': np.sum(predictions),
                'correct_predictions': predictions[true_class],
                'common_mistakes': wrong_predictions[:3]  # Top 3 most common errors
            })

    return error_analysis


# Example usage
if __name__ == "__main__":
    # Sample confusion matrix
    cm = np.array([
        [94,  3,  0,  0,  0,  0,  0,  1],
        [ 6, 78,  6,  0,  1,  0,  2,  7],
        [ 0,  7, 86,  2,  2,  1,  1,  2],
        [ 0,  0,  0, 99,  1,  0,  0,  0],
        [ 2,  2,  3,  2, 47, 17, 13, 13],
        [ 2,  2,  3,  2, 22, 26, 23, 20],
        [ 2,  2,  2,  1, 12, 15, 29, 38],
        [ 2,  4,  1,  0,  1,  1,  6, 87]
    ])
    class_names = [
        'Class 0', 'Class 1', 'Class 2', 'Class 3',
        'Class 4', 'Class 5', 'Class 6', 'Class 7'
    ]

    # Generate visualizations
    plot_advanced_confusion_matrix(
        cm,
        class_names,
        '/home/zhl/regression_and_prediction/2024-10-31/advanced_confusion_matrix'
    )

    # Perform error analysis
    error_analysis = analyze_errors(cm, class_names)

    # Print error analysis results
    for analysis in error_analysis:
        print(f"\nAnalysis for {analysis['true_class']}:")
        print(f"Total samples: {analysis['total_samples']}")
        print(
            f"Correct predictions: {analysis['correct_predictions']} "
            f"({analysis['correct_predictions']/analysis['total_samples']:.2%})"
        )
        print("Most common mistakes:")
        for pred_class, count in analysis['common_mistakes']:
            print(
                f"  - Predicted as {pred_class}: {count} times "
                f"({count/analysis['total_samples']:.2%})"
            )