Plotting
Overview
The plotting module provides visualization utilities for training, evaluating, and tuning image classifiers:
Training curves: Loss and accuracy over epochs
Confusion matrices: Classification performance heatmaps
Sample images: Example images from the dataset with their class labels
Probability distributions: Class prediction confidence
Evaluation curves: ROC and precision-recall curves
Optimization results: Optuna trial visualizations
All functions return matplotlib figure and axes objects for further customization.
Example usage
Learning curves:
from image_classification_tools.pytorch.plotting import plot_learning_curves
import matplotlib.pyplot as plt
# After training
fig, axes = plot_learning_curves(history)
plt.show()
Confusion matrix:
from image_classification_tools.pytorch.plotting import plot_confusion_matrix
fig, ax = plot_confusion_matrix(true_labels, predictions, class_names)
# Set the title on the returned axes rather than on pyplot's current axes
ax.set_title('Test Set Confusion Matrix')
plt.show()
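To make clear what a confusion matrix heatmap encodes, here is a minimal pure-Python sketch of the underlying counts. It is independent of how `plot_confusion_matrix` is implemented. It assumes labels are integer class indices and uses the common convention of true classes as rows and predicted classes as columns; `confusion_counts` is a hypothetical helper, not part of this package.

```python
def confusion_counts(true_labels, predictions, n_classes):
    """Return an n_classes x n_classes table of counts.

    Rows index the true class, columns index the predicted class
    (an assumed convention; check the plot's axis labels).
    """
    matrix = [[0] * n_classes for _ in range(n_classes)]
    for t, p in zip(true_labels, predictions):
        matrix[t][p] += 1
    return matrix


# Illustrative labels for a 3-class problem.
true_labels = [0, 0, 1, 1, 2, 2]
predictions = [0, 1, 1, 1, 2, 0]

matrix = confusion_counts(true_labels, predictions, n_classes=3)
for row in matrix:
    print(row)
# [1, 1, 0]
# [0, 2, 0]
# [1, 0, 1]
```

Diagonal entries count correct predictions; off-diagonal entries show which classes are confused with which.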
Sample images:
from image_classification_tools.pytorch.plotting import plot_sample_images
fig, axes = plot_sample_images(dataset, class_names, nrows=2, ncols=5)
plt.show()
Class probability distributions:
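import numpy as np
import torch
from image_classification_tools.pytorch.plotting import plot_class_probability_distributions
# Collect predicted class probabilities for the whole test set
model.eval()
all_probs = []
with torch.no_grad():  # no gradients needed for inference
    for images, _ in test_loader:
        outputs = model(images)  # model outputs, assumed to be unnormalized logits
        probs = torch.softmax(outputs, dim=1)  # normalize across the class dimension
        all_probs.append(probs.cpu().numpy())
# Stack per-batch arrays into one (n_samples, n_classes) array
all_probs = np.concatenate(all_probs, axis=0)
fig, axes = plot_class_probability_distributions(all_probs, class_names)
plt.show()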
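The softmax step in the loop above turns each row of model outputs into probabilities that sum to 1. The standard-library sketch below shows that computation on two illustrative rows, and how predicted labels (as used for a confusion matrix) can be derived from the probabilities. The input values and the `softmax` helper are assumptions made for illustration.

```python
import math


def softmax(logits):
    """Convert one row of logits into probabilities that sum to 1."""
    peak = max(logits)
    # Subtracting the maximum keeps exp() from overflowing; the result is unchanged.
    exps = [math.exp(x - peak) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


# Illustrative model outputs for two samples and three classes.
logits_batch = [[2.0, 1.0, 0.1], [0.5, 2.5, 0.3]]
probs_batch = [softmax(row) for row in logits_batch]

# The predicted class is the index of the largest probability in each row.
predictions = [row.index(max(row)) for row in probs_batch]
print(predictions)  # [0, 1]
```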
Evaluation curves:
from image_classification_tools.pytorch.plotting import plot_evaluation_curves
fig, (ax1, ax2) = plot_evaluation_curves(true_labels, all_probs, class_names)
plt.show()
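For readers unfamiliar with how a multi-class ROC curve is formed, a common approach is one-vs-rest: treat one class as positive and sweep a threshold over its predicted probability. The sketch below does this in pure Python for a single class. It is not necessarily how `plot_evaluation_curves` computes its curves; `roc_points`, `auc`, and the data are hypothetical, and the code assumes at least one positive and one negative sample.

```python
def roc_points(binary_labels, scores):
    """Return (false positive rate, true positive rate) pairs for each threshold."""
    positives = sum(binary_labels)
    negatives = len(binary_labels) - positives
    points = [(0.0, 0.0)]  # threshold above every score: nothing predicted positive
    for threshold in sorted(set(scores), reverse=True):
        tp = sum(1 for y, s in zip(binary_labels, scores) if s >= threshold and y == 1)
        fp = sum(1 for y, s in zip(binary_labels, scores) if s >= threshold and y == 0)
        points.append((fp / negatives, tp / positives))
    return points


def auc(points):
    """Area under the piecewise-linear curve via the trapezoidal rule."""
    return sum((x1 - x0) * (y0 + y1) / 2 for (x0, y0), (x1, y1) in zip(points, points[1:]))


# Illustrative data: true class indices and the predicted probability of class 0.
true_labels = [0, 1, 2, 0, 1, 2]
class_0_probs = [0.8, 0.3, 0.1, 0.6, 0.7, 0.2]

# One-vs-rest: class 0 is positive, every other class is negative.
is_class_0 = [1 if y == 0 else 0 for y in true_labels]
points = roc_points(is_class_0, class_0_probs)
print(auc(points))  # 0.875
```

Repeating this for each class gives one curve per class; a precision-recall curve follows the same threshold sweep but records precision and recall instead.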
Optimization results:
from image_classification_tools.pytorch.plotting import plot_optimization_results
# After Optuna study
fig, axes = plot_optimization_results(study)
plt.show()