Source code for image_classification_tools.pytorch.evaluation

'''Evaluation functions for models.'''

import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import DataLoader


def evaluate_model(
    model: nn.Module,
    test_loader: DataLoader
) -> tuple[float, np.ndarray, np.ndarray]:
    '''Evaluate model on test set.

    Switches the model to evaluation mode (it is left in that mode on
    return), runs it over every batch yielded by ``test_loader`` with
    gradient tracking disabled, and takes the index of the largest output
    along dimension 1 as the predicted class.

    Note: Assumes data is already on the correct device.

    Args:
        model: Model to evaluate. Its output is assumed to be laid out as
            (batch, classes) - only dimension 1 is reduced here.
        test_loader: Iterable yielding ``(images, labels)`` batches.

    Returns:
        A tuple ``(accuracy, predictions, labels)`` where ``accuracy`` is
        the percentage (0-100) of predictions equal to their label, and
        ``predictions`` / ``labels`` are NumPy arrays holding the values
        from all batches in iteration order. If the loader yields no
        samples, ``accuracy`` is ``0.0`` and both arrays are empty.
    '''
    model.eval()

    correct = 0
    total = 0
    prediction_batches = []
    label_batches = []

    with torch.no_grad():
        for images, labels in test_loader:
            outputs = model(images)
            predicted = outputs.argmax(dim=1)

            total += labels.size(0)
            correct += (predicted == labels).sum().item()

            # Keep whole per-batch arrays and join them once at the end,
            # rather than boxing every element into a Python list.
            prediction_batches.append(predicted.cpu().numpy())
            label_batches.append(labels.cpu().numpy())

    # An empty test set would otherwise divide by zero below (and
    # np.concatenate rejects an empty list of arrays).
    if total == 0:
        return 0.0, np.array([], dtype=np.int64), np.array([], dtype=np.int64)

    accuracy = 100 * correct / total

    return (
        accuracy,
        np.concatenate(prediction_batches),
        np.concatenate(label_batches),
    )