Source code for image_classification_tools.pytorch.data

'''Data loading and preprocessing functions for image classification datasets.

This module provides utilities for loading datasets (including CIFAR-10) and creating
PyTorch DataLoaders with support for custom transforms and device preloading.
'''
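# Typical usage order (a runnable sketch is included at the bottom of this module):
#   1. load_dataset()        -> raw train/test Dataset objects
#   2. prepare_splits()      -> train/val/test Subsets
#   3. create_dataloaders()  -> DataLoaders, optionally preloaded to CPU/GPU memory
#   generate_augmented_data() can additionally write an augmented copy of the
#   training set to disk in ImageFolder format.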

import shutil
from pathlib import Path
from typing import Tuple

import torch
from torchvision import datasets, transforms
from torchvision.utils import save_image
from torch.utils.data import DataLoader, Dataset, TensorDataset, Subset


def load_dataset(
    data_source: Path | type,
    transform: transforms.Compose,
    train: bool = True,
    download: bool = False,
    **dataset_kwargs
) -> Dataset:
    '''Load a single dataset from a directory or PyTorch dataset class.

    This function provides a flexible interface for loading image classification
    datasets. It supports both PyTorch built-in datasets (CIFAR-10, CIFAR-100,
    MNIST, etc.) and custom datasets stored in directories following the
    ImageFolder structure.

    Args:
        data_source: Either a Path to a directory containing a train/ or test/
            subdirectory, or a PyTorch dataset class (e.g., datasets.CIFAR10)
        transform: Transforms to apply to the data
        train: If True, load training data. If False, load test data (default: True)
        download: Whether to download the dataset if using a PyTorch dataset class.
            Ignored for directory-based datasets.
        **dataset_kwargs: Additional keyword arguments passed to the dataset class
            (e.g., root='data/pytorch/cifar10')

    Returns:
        Dataset object

    Examples:
        # Load CIFAR-10 training data
        train_dataset = load_dataset(
            data_source=datasets.CIFAR10,
            transform=transform,
            train=True,
            root='data/cifar10'
        )

        # Load from ImageFolder
        train_dataset = load_dataset(
            data_source=Path('data/my_dataset'),
            transform=transform,
            train=True
        )
    '''
    if isinstance(data_source, Path):
        # Directory-based dataset using ImageFolder
        subdir = 'train' if train else 'test'
        data_dir = data_source / subdir
        if not data_dir.exists():
            raise ValueError(
                f'{"Training" if train else "Test"} directory not found: {data_dir}'
            )
        return datasets.ImageFolder(root=data_dir, transform=transform)
    else:
        # PyTorch dataset class (CIFAR-10, MNIST, etc.)
        dataset_class = data_source
        return dataset_class(
            train=train,
            download=download,
            transform=transform,
            **dataset_kwargs
        )


def prepare_splits(
    train_dataset: Dataset,
    test_dataset: Dataset | None = None,
    val_size: int = 10000,
    test_size: int | None = None
) -> Tuple[Dataset, Dataset, Dataset]:
    '''Split a training dataset into train/val(/test) splits.

    The splitting behavior depends on whether a separate test dataset is provided:

    - If test_dataset is provided: split train_dataset into train/val only (2-way split)
    - If test_dataset is None: split train_dataset into train/val/test (3-way split)

    Args:
        train_dataset: Training dataset to split
        test_dataset: Test dataset. If None, the test set will be split from train_dataset.
        val_size: Number of images to use for validation
        test_size: Number of images to reserve for testing when test_dataset is None.
            Only used in that case; if it is also None, a ValueError is raised.

    Returns:
        Tuple of (train_dataset, val_dataset, test_dataset)

    Examples:
        # 2-way split: pass a separate test set
        train_ds, val_ds, test_ds = prepare_splits(
            train_dataset=my_train_data,
            test_dataset=my_test_data,  # Use this for testing
            val_size=10000              # 10,000 images for validation
        )

        # 3-way split: no separate test set
        train_ds, val_ds, test_ds = prepare_splits(
            train_dataset=my_full_data,
            test_dataset=None,   # Test set will be split from train_dataset
            val_size=10000,      # 10,000 for validation
            test_size=5000       # 5,000 for testing
        )
    '''
    if test_dataset is not None:
        # 2-way split: train/val only, use the provided test set
        total_size = len(train_dataset)
        if val_size >= total_size:
            raise ValueError(
                f'val_size ({val_size}) must be less than train_dataset size ({total_size})'
            )

        indices = torch.randperm(total_size).tolist()
        val_indices = indices[:val_size]
        train_indices = indices[val_size:]

        train_dataset_final = Subset(train_dataset, train_indices)
        val_dataset_final = Subset(train_dataset, val_indices)
        test_dataset_final = test_dataset
    else:
        # 3-way split: train/val/test all from train_dataset
        if test_size is None:
            raise ValueError('test_size must be provided when test_dataset is None')

        total_size = len(train_dataset)
        if val_size + test_size >= total_size:
            raise ValueError(
                f'val_size ({val_size}) + test_size ({test_size}) must be less than '
                f'train_dataset size ({total_size})'
            )

        indices = torch.randperm(total_size).tolist()
        val_indices = indices[:val_size]
        test_indices = indices[val_size:val_size + test_size]
        train_indices = indices[val_size + test_size:]

        train_dataset_final = Subset(train_dataset, train_indices)
        val_dataset_final = Subset(train_dataset, val_indices)
        test_dataset_final = Subset(train_dataset, test_indices)

    return train_dataset_final, val_dataset_final, test_dataset_final


def create_dataloaders(
    train_dataset: Dataset,
    val_dataset: Dataset,
    test_dataset: Dataset,
    batch_size: int,
    shuffle_train: bool = True,
    num_workers: int = 0,
    preload_to_memory: bool = True,
    device: torch.device | None = None,
    **kwargs
) -> Tuple[DataLoader, DataLoader, DataLoader]:
    '''Create DataLoaders from prepared datasets with optional memory preloading.

    This function provides three memory management strategies:

    1. Lazy loading (preload_to_memory=False): Data stays on disk, loaded per batch
    2. CPU preloading (preload_to_memory=True, device=cpu): Entire dataset in RAM
    3. GPU preloading (preload_to_memory=True, device=cuda): Entire dataset in VRAM

    Args:
        train_dataset: Prepared training dataset
        val_dataset: Prepared validation dataset
        test_dataset: Prepared test dataset
        batch_size: Batch size for all DataLoaders
        shuffle_train: Whether to shuffle training data (default: True)
        num_workers: Number of subprocesses for data loading (default: 0 for a single
            process). Note: num_workers is ignored when preload_to_memory=True.
        preload_to_memory: If True, convert datasets to tensors and load them into
            memory. If False, keep them as lazy-loading Dataset objects (default: True).
        device: Device to preload tensors onto. Only used if preload_to_memory=True.
            If None with preload_to_memory=True, defaults to CPU. Common values:
            torch.device('cpu'), torch.device('cuda')
        **kwargs: Additional keyword arguments passed to DataLoader
            (e.g., pin_memory=True, persistent_workers=True)

    Returns:
        Tuple of (train_loader, val_loader, test_loader)

    Examples:
        # Strategy 1: Lazy loading (large datasets)
        train_loader, val_loader, test_loader = create_dataloaders(
            train_ds, val_ds, test_ds,
            batch_size=128,
            preload_to_memory=False,
            num_workers=4,
            pin_memory=True
        )

        # Strategy 2: CPU preloading (medium datasets)
        train_loader, val_loader, test_loader = create_dataloaders(
            train_ds, val_ds, test_ds,
            batch_size=128,
            preload_to_memory=True,
            device=torch.device('cpu')
        )

        # Strategy 3: GPU preloading (small datasets, fastest training)
        train_loader, val_loader, test_loader = create_dataloaders(
            train_ds, val_ds, test_ds,
            batch_size=128,
            preload_to_memory=True,
            device=torch.device('cuda')
        )
    '''
    if preload_to_memory:
        # Preload datasets to memory
        if device is None:
            device = torch.device('cpu')

        # Load train data
        X_train = torch.stack([img for img, _ in train_dataset]).to(device)
        y_train = torch.tensor([label for _, label in train_dataset]).to(device)
        train_dataset_final = TensorDataset(X_train, y_train)

        # Load val data
        X_val = torch.stack([img for img, _ in val_dataset]).to(device)
        y_val = torch.tensor([label for _, label in val_dataset]).to(device)
        val_dataset_final = TensorDataset(X_val, y_val)

        # Load test data
        X_test = torch.stack([img for img, _ in test_dataset]).to(device)
        y_test = torch.tensor([label for _, label in test_dataset]).to(device)
        test_dataset_final = TensorDataset(X_test, y_test)

        # When preloading, num_workers should be 0
        num_workers = 0
    else:
        # Use datasets as-is for lazy loading
        train_dataset_final = train_dataset
        val_dataset_final = val_dataset
        test_dataset_final = test_dataset

    train_loader = DataLoader(
        train_dataset_final,
        batch_size=batch_size,
        shuffle=shuffle_train,
        num_workers=num_workers,
        **kwargs
    )
    val_loader = DataLoader(
        val_dataset_final,
        batch_size=batch_size,
        shuffle=False,
        num_workers=num_workers,
        **kwargs
    )
    test_loader = DataLoader(
        test_dataset_final,
        batch_size=batch_size,
        shuffle=False,
        num_workers=num_workers,
        **kwargs
    )

    return train_loader, val_loader, test_loader


def generate_augmented_data(
    train_dataset: Dataset,
    augmentation_transforms: torch.nn.Sequential,
    augmentations_per_image: int,
    save_dir: str | Path,
    class_names: list[str] | None = None,
    chunk_size: int = 5000,
    force_reaugment: bool = False
) -> None:
    '''Generate augmented training data and save it in an ImageFolder-compatible directory structure.

    This function applies augmentation transforms to create multiple augmented
    versions of each training image and saves them to disk in ImageFolder format.
    Images are processed in chunks to avoid memory issues with large datasets.

    Args:
        train_dataset: PyTorch Dataset containing training images
        augmentation_transforms: nn.Sequential containing augmentation transforms to apply
        augmentations_per_image: Number of augmented versions to create per image
        save_dir: Directory path to save augmented images in ImageFolder format
            (creates a train/ subdirectory containing class_0/, class_1/, etc.,
            or folders named after class_names)
        class_names: Optional list of class names. If None, uses numeric class indices.
        chunk_size: Number of images to process per chunk (default: 5000)
        force_reaugment: If True, regenerate even if saved data exists

    Returns:
        None (saves images to disk)

    Example:
        >>> generate_augmented_data(
        ...     train_dataset=train_dataset,
        ...     augmentation_transforms=augmentation_transforms,
        ...     augmentations_per_image=3,
        ...     save_dir='data/cifar10_augmented',
        ...     class_names=['airplane', 'automobile', ...],
        ...     chunk_size=5000
        ... )
        >>> # Then load with the existing pipeline:
        >>> aug_dataset = load_dataset(
        ...     data_source=Path('data/cifar10_augmented'),
        ...     transform=eval_transform,
        ...     train=True
        ... )
    '''
    save_dir = Path(save_dir)

    # Check if data already exists
    if save_dir.exists() and any(save_dir.iterdir()) and not force_reaugment:
        print(f'Augmented data already exists at {save_dir}')
        print('Use force_reaugment=True to regenerate')
        return

    # Create directory structure with a 'train' subdirectory for ImageFolder compatibility
    if force_reaugment and save_dir.exists():
        shutil.rmtree(save_dir)
    save_dir.mkdir(parents=True, exist_ok=True)
    train_dir = save_dir / 'train'
    train_dir.mkdir(exist_ok=True)

    # Get unique classes from dataset
    print('Scanning dataset for classes...')
    all_labels = set()
    for _, label in train_dataset:
        all_labels.add(label if isinstance(label, int) else label.item())
    unique_classes = sorted(all_labels)

    # Create class directories inside train/
    for class_idx in unique_classes:
        if class_names and class_idx < len(class_names):
            class_dir = train_dir / class_names[class_idx]
        else:
            class_dir = train_dir / f'class_{class_idx}'
        class_dir.mkdir(exist_ok=True)

    print(f'Found {len(unique_classes)} classes')
    print(f'Saving augmented images to {save_dir}')

    original_size = len(train_dataset)
    num_chunks = (original_size + chunk_size - 1) // chunk_size
    print(f'Processing {original_size} images in {num_chunks} chunk(s)')
    print(f'Generating {augmentations_per_image} augmentations per image')

    total_saved = 0

    # Process dataset in chunks
    for chunk_idx in range(num_chunks):
        start_idx = chunk_idx * chunk_size
        end_idx = min((chunk_idx + 1) * chunk_size, original_size)
        print(f'\nChunk {chunk_idx + 1}/{num_chunks} (images {start_idx}-{end_idx - 1})...')

        # Process each image in the chunk
        for idx in range(start_idx, end_idx):
            img, label = train_dataset[idx]
            label_val = label if isinstance(label, int) else label.item()

            # Determine class directory (inside train/)
            if class_names and label_val < len(class_names):
                class_dir = train_dir / class_names[label_val]
            else:
                class_dir = train_dir / f'class_{label_val}'

            # Save original image
            img_name = f'img_{idx:06d}_orig.png'
            save_image(img, class_dir / img_name)
            total_saved += 1

            # Generate and save augmented versions
            for aug_idx in range(augmentations_per_image):
                img_aug = augmentation_transforms(img.unsqueeze(0)).squeeze(0)
                img_name = f'img_{idx:06d}_aug{aug_idx:02d}.png'
                save_image(img_aug, class_dir / img_name)
                total_saved += 1

        print(f' Chunk {chunk_idx + 1} complete')

    print('\nAugmentation complete!')
    print(f' Total images saved: {total_saved}')
    print(f' Original images: {original_size}')
    print(f' Augmented images: {total_saved - original_size}')
    print(f' Augmentation factor: {total_saved / original_size:.1f}x')
    print(f' Location: {save_dir}')
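

# --- Example usage (sketch) -------------------------------------------------
# A minimal sketch of how the functions above are intended to compose for
# CIFAR-10. The transform definition, normalization statistics, batch size, and
# the 'data/cifar10' root path are illustrative assumptions, not part of this
# module's API. The block only runs when the module is executed directly.
if __name__ == '__main__':
    # Assumed evaluation transform: convert to tensor and normalize with
    # commonly used CIFAR-10 statistics (illustrative values).
    eval_transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616)),
    ])

    # 1. Load the built-in CIFAR-10 train and test splits.
    train_data = load_dataset(
        data_source=datasets.CIFAR10,
        transform=eval_transform,
        train=True,
        download=True,
        root='data/cifar10',  # assumed download location
    )
    test_data = load_dataset(
        data_source=datasets.CIFAR10,
        transform=eval_transform,
        train=False,
        download=True,
        root='data/cifar10',
    )

    # 2. Carve a validation set out of the training split (2-way split,
    #    since a separate test set is available).
    train_ds, val_ds, test_ds = prepare_splits(
        train_dataset=train_data,
        test_dataset=test_data,
        val_size=10000,
    )

    # 3. Build DataLoaders; preload to GPU memory if available, else to CPU RAM.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    train_loader, val_loader, test_loader = create_dataloaders(
        train_ds, val_ds, test_ds,
        batch_size=128,
        preload_to_memory=True,
        device=device,
    )
    print(f'Batches per epoch: {len(train_loader)}')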