'''Evaluation functions for models.'''
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
def evaluate_model(
    model: nn.Module,
    test_loader: DataLoader
) -> tuple[float, np.ndarray, np.ndarray]:
    '''Evaluate model on test set.

    Runs the model over every batch produced by ``test_loader`` with
    gradient tracking disabled, takes the index of the largest output
    along dimension 1 as the prediction, and compares it to the labels.

    Note: Assumes data is already on the correct device. The model is
    switched to eval mode and is left in eval mode on return.

    Args:
        model: Module called as ``model(images)`` for each batch.
        test_loader: Iterable yielding ``(images, labels)`` pairs.

    Returns:
        A tuple ``(accuracy, predictions, labels)`` where ``accuracy`` is
        the percentage (0-100) of predictions equal to their label, and
        ``predictions`` / ``labels`` are NumPy arrays gathered over all
        batches in iteration order. If the loader yields no samples,
        ``accuracy`` is ``0.0`` and both arrays are empty.
    '''
    model.eval()

    correct = 0
    total = 0
    all_predictions = []
    all_labels = []

    with torch.no_grad():
        for images, labels in test_loader:
            outputs = model(images)
            # Only the indices of the maxima are needed; no_grad already
            # detaches the outputs, so the legacy ``.data`` is unnecessary.
            _, predicted = torch.max(outputs, 1)

            total += labels.size(0)
            correct += (predicted == labels).sum().item()

            all_predictions.extend(predicted.cpu().numpy())
            all_labels.extend(labels.cpu().numpy())

    # Guard against an empty loader, which would otherwise divide by zero.
    accuracy = 100 * correct / total if total else 0.0

    return accuracy, np.array(all_predictions), np.array(all_labels)