Image Classification Tools Documentation
image-classification-tools is a lightweight PyTorch toolkit for building and training image classification models.
The package provides utilities for:
- Loading and preprocessing image datasets
- Training models with validation tracking
- Evaluating model performance
- Visualizing results and metrics
- Optimizing hyperparameters with Optuna
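To illustrate what "training with validation tracking" involves, the sketch below runs one training pass and one validation pass per epoch in plain PyTorch and records per-epoch metrics. It is not the package's train_model implementation. The run_epoch helper and the layout of the history dictionary are hypothetical, and random synthetic tensors stand in for a real dataset.

```python
import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset


def run_epoch(model, train_loader, val_loader, criterion, optimizer, device):
    """Hypothetical helper: one training pass, then one validation pass."""
    model.train()
    train_loss = 0.0
    for images, labels in train_loader:
        images, labels = images.to(device), labels.to(device)
        optimizer.zero_grad()
        loss = criterion(model(images), labels)
        loss.backward()
        optimizer.step()
        # Weight by batch size so the epoch mean is exact even for a short last batch
        train_loss += loss.item() * labels.size(0)

    model.eval()
    val_loss, correct = 0.0, 0
    with torch.no_grad():  # no gradients needed during validation
        for images, labels in val_loader:
            images, labels = images.to(device), labels.to(device)
            logits = model(images)
            val_loss += criterion(logits, labels).item() * labels.size(0)
            correct += (logits.argmax(dim=1) == labels).sum().item()

    n_train, n_val = len(train_loader.dataset), len(val_loader.dataset)
    return {
        'train_loss': train_loss / n_train,
        'val_loss': val_loss / n_val,
        'val_accuracy': correct / n_val,
    }


# Synthetic stand-in for MNIST-shaped data: 256 random 1x28x28 "images", 10 classes
torch.manual_seed(0)
device = torch.device('cpu')
x = torch.randn(256, 1, 28, 28)
y = torch.randint(0, 10, (256,))
train_loader = DataLoader(TensorDataset(x[:200], y[:200]), batch_size=32, shuffle=True)
val_loader = DataLoader(TensorDataset(x[200:], y[200:]), batch_size=32)

model = nn.Sequential(nn.Flatten(), nn.Linear(784, 128), nn.ReLU(), nn.Linear(128, 10))
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters())

# Accumulate one entry per epoch for each tracked metric
history = {'train_loss': [], 'val_loss': [], 'val_accuracy': []}
for epoch in range(3):
    metrics = run_epoch(model, train_loader, val_loader, criterion, optimizer, device)
    for key, value in metrics.items():
        history[key].append(value)
```

Tracking validation metrics every epoch is what makes it possible to spot overfitting, for example when training loss keeps falling while validation loss rises.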
Who should use this
This package is for developers who need to:
- Build image classifiers for custom datasets
- Prototype and compare different model architectures
- Automate hyperparameter tuning
- Evaluate and visualize model performance
The API works with any image classification task, from small datasets like MNIST to larger custom collections.
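Comparing architectures usually starts by checking that each candidate accepts the same input shape and produces one logit per class. The plain-PyTorch sketch below does this for 28x28 grayscale inputs such as MNIST. The MLP matches the one in the quick example; the small CNN and the count_parameters helper are hypothetical additions for illustration, not part of the package.

```python
import torch
from torch import nn


def count_parameters(model):
    """Count trainable parameters only."""
    return sum(p.numel() for p in model.parameters() if p.requires_grad)


candidates = {
    # Same MLP as in the quick example
    'mlp': nn.Sequential(
        nn.Flatten(),
        nn.Linear(784, 128),
        nn.ReLU(),
        nn.Linear(128, 10),
    ),
    # Hypothetical small CNN for 1x28x28 inputs
    'small_cnn': nn.Sequential(
        nn.Conv2d(1, 16, kernel_size=3, padding=1),   # -> 16x28x28
        nn.ReLU(),
        nn.MaxPool2d(2),                              # -> 16x14x14
        nn.Conv2d(16, 32, kernel_size=3, padding=1),  # -> 32x14x14
        nn.ReLU(),
        nn.MaxPool2d(2),                              # -> 32x7x7
        nn.Flatten(),
        nn.Linear(32 * 7 * 7, 10),
    ),
}

# Push a dummy batch through each candidate to confirm the output shape
dummy_batch = torch.randn(8, 1, 28, 28)
for name, model in candidates.items():
    logits = model(dummy_batch)
    print(f'{name}: output {tuple(logits.shape)}, {count_parameters(model):,} parameters')
# mlp: output (8, 10), 101,770 parameters
# small_cnn: output (8, 10), 20,490 parameters
```

Parameter count is only a rough proxy; the meaningful comparison is training each candidate on the same splits and comparing validation metrics.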
Installation
pip install image-classification-tools
Quick example
A minimal example that trains a small classifier on MNIST digits:
import torch
from pathlib import Path
from torchvision import datasets, transforms
from image_classification_tools.pytorch.data import (
    load_datasets, prepare_splits, create_dataloaders
)
from image_classification_tools.pytorch.training import train_model

# Preprocessing: convert images to tensors in [0, 1], then normalize to [-1, 1]
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))
])

# Load, split, and create dataloaders
train_dataset, test_dataset = load_datasets(
    data_source=datasets.MNIST,
    train_transform=transform,
    eval_transform=transform,
    download=True,
    root=Path('./data/mnist')
)

train_dataset, val_dataset, test_dataset = prepare_splits(
    train_dataset, test_dataset, train_val_split=0.8
)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

train_loader, val_loader, test_loader = create_dataloaders(
    train_dataset, val_dataset, test_dataset,
    batch_size=64,
    preload_to_memory=True,
    device=device
)

# Define model (flattened 28x28 image -> 10 digit classes) on the same device as the data
model = torch.nn.Sequential(
    torch.nn.Flatten(),
    torch.nn.Linear(784, 128),
    torch.nn.ReLU(),
    torch.nn.Linear(128, 10)
).to(device)

# Train
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters())

history = train_model(
    model=model,
    train_loader=train_loader,
    val_loader=val_loader,
    criterion=criterion,
    optimizer=optimizer,
    epochs=10
)
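The quick example builds a test loader but does not use it. The package's own evaluation utilities are not shown on this page, so the sketch below computes test-set accuracy in plain PyTorch instead. The evaluate_accuracy name is hypothetical, and the sketch assumes each batch is an (images, labels) pair, as torchvision's MNIST dataset yields.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset


def evaluate_accuracy(model, loader, device):
    """Hypothetical helper: fraction of correctly classified samples in loader."""
    model.eval()  # disable dropout, use running batch-norm statistics
    correct, total = 0, 0
    with torch.no_grad():  # no gradients needed for evaluation
        for images, labels in loader:
            images, labels = images.to(device), labels.to(device)
            predictions = model(images).argmax(dim=1)
            correct += (predictions == labels).sum().item()
            total += labels.size(0)
    return correct / total


# Self-check on synthetic data: an Identity "model" returns its inputs as logits,
# so the predictions are the row-wise argmax [0, 1, 0, 1] against labels [0, 1, 1, 1]
logits = torch.tensor([[2.0, 1.0], [0.0, 3.0], [5.0, 1.0], [1.0, 4.0]])
labels = torch.tensor([0, 1, 1, 1])
loader = DataLoader(TensorDataset(logits, labels), batch_size=3)
print(evaluate_accuracy(torch.nn.Identity(), loader, torch.device('cpu')))  # 0.75
```

With the quick example's variables, the equivalent call would be evaluate_accuracy(model, test_loader, device). If the loaders already hold tensors on the target device, the .to(device) calls are no-ops.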
Demo project
For a complete example, see the CIFAR-10 classification demo: https://github.com/gperdrizet/CIFAR10
Documentation contents
User guide
API reference
Project links