mini.trainer

mini.trainer is a minimalist PyTorch training loop that centralizes the routine sequencing (state builds, accumulation, retries) while leaving distributed setup and model construction under user control.

What BaseTrainer handles

BaseTrainer stays intentionally small. Out of the box it:

  • Builds and restores state for the data loader, model, optimizer, scheduler, and optional grad scaler in a DDP/FSDP-friendly order.
  • Runs a resilient step loop that supports gradient accumulation, mixed precision, scale management, NaN/Inf guarding, and data/step timing across single or distributed devices.
  • Pushes everything else into hooks. Progress output, checkpoint IO, wandb logging, CUDA memory tracking, and validation all ship as hooks, which makes it easy to compose your own.

Because nearly every behavior lives behind hooks, most real-world setups fit without forking the code. When you do need something exotic, the trainer is compact enough to copy into your project and customize.

Key capabilities

  • Subclass contract: implement the build_*() factories and forward(), then rely on the provided train() loop.
  • Distributed-ready: gradient accumulation, autocast, gradient scaling, no_sync, and non-finite checks work seamlessly with PyTorch DDP and FSDP.
  • State management: save and restore the full training state (model, optimizer, scheduler, grad scaler, current step and hooks) with correct ordering and distributed-aware handling.
  • Hook system: progress, checkpointing, logging (e.g., wandb), EMA, CUDA memory stats, and validation are all hooks you can combine or extend.
  • Instrumentation coverage: step/data timings, gradient norms, learning rates, and user metrics are captured for logging hooks to consume.

Quick start

  1. Subclass BaseTrainer and implement the required methods:

    from mini.trainer import BaseTrainer
    import torch
    import torch.nn as nn
    
    class MyTrainer(BaseTrainer):
        def build_data_loader(self):
            # Return any iterable (DataLoader, generator, etc.)
            dataset = ...
            return torch.utils.data.DataLoader(dataset, batch_size=32)
    
        def build_model(self):
            model = nn.Sequential(
                nn.Linear(784, 128),
                nn.ReLU(),
                nn.Linear(128, 10),
            )
            return model.to(self.device)
    
        def build_optimizer(self):
            return torch.optim.Adam(self.model.parameters(), lr=1e-3)
    
        def forward(self, input):
            # Implement your forward pass and loss computation
            x, y = input  # comes from data loader
            x, y = x.to(self.device), y.to(self.device)
            logits = self.model(x)
            loss = nn.functional.cross_entropy(logits, y)
    
            # Return loss and a dict of metrics to log
            records = {"accuracy": (logits.argmax(1) == y).float().mean().item()}
            return loss, records
    
  2. Create and train:

    trainer = MyTrainer(
        max_steps=10_000,
        grad_clip=1.0,
        max_non_finite_grad_retries=3,
        mixed_precision="bf16",  # or "fp16", or None
        gradient_accumulation_steps=4,
        workspace="./runs/experiment_1",
        device="cuda",
    )
    
    trainer.train()
    
  3. Add hooks for logging, checkpointing, and more:

    from mini.trainer import BaseTrainer, ProgressHook, CheckpointingHook, LoggingHook, WandbHook
    
    class MyTrainer(BaseTrainer):
        # ... (same as above)
    
        def build_hooks(self):
            return [
                ProgressHook(interval=10, with_records=True),
                CheckpointingHook(interval=1000, path=self.workspace / "checkpoints"),
                LoggingHook(interval=100),
                WandbHook(project="my-project"),  # requires LoggingHook
            ]
    

Why mini.trainer

  • Lightning takes a framework-style approach: the user defines modules and callbacks while Lightning constructs the training loop, manages distributed wrappers, and coordinates evaluation. This removes boilerplate but makes it harder to diverge from the framework’s lifecycle or debug lower-level issues when the defaults do not match your setup.
  • Hugging Face Accelerate sits in the middle: you still write the loop, yet the library configures devices, mixed precision, and distributed collectives through helper objects. That consolidation keeps the API uniform across hardware, but the indirection can obscure which PyTorch primitives run at each stage. Debugging becomes hard, as documentation is sparse and the code is complex.
  • mini.trainer assumes the complementary responsibilities. You hold on to device setup, process groups, and any model wrapping, while the trainer focuses on sequencing: it builds components in the correct order, runs the step loop with accumulation and retry policies, and surfaces metrics to hooks. The implementation stays small enough to audit or copy when a project needs to diverge.

Core concepts

The training loop

BaseTrainer.train() orchestrates the training process:

  1. Calls _build() to construct the model, optimizer, data loader, and hooks.
  2. Iterates over max_steps, calling _run_step() for each step.
  3. Handles gradient accumulation automatically.
  4. Invokes hooks at key points (before/after train, before/after step).

You typically don't override train() itself; implement the build_*() and forward() methods instead.

Required methods

Subclasses must implement:

  • build_data_loader() -> Iterable: Return an iterable that yields training data.
  • build_model() -> nn.Module: Construct and return the model.
  • build_optimizer() -> torch.optim.Optimizer: Create the optimizer.
  • forward(input) -> tuple[torch.Tensor | None, dict]: Perform a forward pass and return (loss, records).
      • loss: The scalar loss tensor. If None, the backward pass is skipped and the step is retried.
      • records: A nested dict of numeric metrics that you want to log (e.g., {"accuracy": 0.95, "metrics": {"f1": 0.9}}).

Optional methods

You can override these to customize behavior:

  • build_lr_scheduler() -> torch.optim.lr_scheduler.LRScheduler | None: Return a learning rate scheduler (default: None).
  • build_grad_scaler() -> torch.amp.GradScaler: Customize the gradient scaler for mixed precision (default: enabled for FP16).
  • build_hooks() -> list[BaseHook]: Return a list of hooks (default: []).
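
For example, a subclass might attach a cosine learning-rate schedule and a couple of hooks. A minimal sketch, assuming self.optimizer and self.max_steps are populated by the time the scheduler factory runs (the attribute names are assumptions, not part of the documented API):

import torch
from mini.trainer import BaseTrainer, CheckpointingHook, ProgressHook

class SchedulerTrainer(BaseTrainer):
    # build_data_loader / build_model / build_optimizer / forward as in the quick start

    def build_lr_scheduler(self):
        # Cosine decay over the whole run; assumes the optimizer is already built
        # when this factory is called, per the build order described above.
        return torch.optim.lr_scheduler.CosineAnnealingLR(
            self.optimizer, T_max=self.max_steps
        )

    def build_hooks(self):
        return [
            ProgressHook(interval=10, with_records=True),
            CheckpointingHook(interval=1000),
        ]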

Gradient accumulation

Set gradient_accumulation_steps to accumulate gradients over multiple forward/backward passes before updating parameters. The trainer automatically:

  • Scales the loss by 1 / gradient_accumulation_steps.
  • Calls model.no_sync() during accumulation steps (if using DDP/FSDP) to skip gradient synchronization until the final step.
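
Conceptually, one optimizer update then looks like the sketch below. This is a simplification, not the library's actual code (the real loop also handles autocast, the grad scaler, retries, and timing), and trainer.optimizer is assumed to be the optimizer returned by build_optimizer():

import contextlib

def accumulated_update(trainer, micro_batches):
    # Sketch: N forward/backward passes feed a single optimizer step.
    n = len(micro_batches)
    for i, batch in enumerate(micro_batches):
        is_last = i == n - 1
        # On non-final micro-steps, skip DDP/FSDP gradient synchronization.
        sync_ctx = (
            trainer.model.no_sync()
            if not is_last and hasattr(trainer.model, "no_sync")
            else contextlib.nullcontext()
        )
        with sync_ctx:
            loss, records = trainer.forward(batch)
            (loss / n).backward()       # scale so gradients match one large batch
    trainer.optimizer.step()            # parameters update once per n micro-batches
    trainer.optimizer.zero_grad()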

Mixed precision

Pass mixed_precision="fp16" or "bf16" to enable automatic mixed precision training with torch.autocast. The trainer:

  • Uses torch.amp.GradScaler for FP16 to handle gradient scaling.
  • Disables the scaler for BF16 (which doesn't need gradient scaling).
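
In plain PyTorch terms, an fp16 micro-step follows roughly the standard autocast + GradScaler recipe sketched below. This is not the trainer's exact code; trainer.optimizer is an assumed attribute name, and FSDP gradient clipping has its own subtleties that the sketch ignores:

import torch

def fp16_micro_step(trainer, batch, scaler, grad_clip=1.0):
    # Forward under autocast so matmuls run in fp16 while reductions stay in fp32.
    with torch.autocast(device_type="cuda", dtype=torch.float16):
        loss, records = trainer.forward(batch)
    scaler.scale(loss).backward()                  # scale to avoid fp16 gradient underflow
    scaler.unscale_(trainer.optimizer)             # unscale before clipping real gradients
    torch.nn.utils.clip_grad_norm_(trainer.unwrapped_model.parameters(), grad_clip)
    scaler.step(trainer.optimizer)                 # skipped internally if grads are non-finite
    scaler.update()
    trainer.optimizer.zero_grad()

With bf16, the same structure runs with the scaler disabled, so the scale/unscale calls become no-ops.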

Non-finite gradient handling

If gradients become NaN or Inf, the trainer can retry the step:

  • Set max_non_finite_grad_retries to a positive integer to enable retries.
  • The trainer will reset gradients and re-run the forward/backward pass.
  • If retries are exhausted, a RuntimeError is raised.
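
In pseudocode, the policy amounts to the sketch below (ignoring accumulation and the grad scaler; run_forward_backward stands in for the trainer's own forward/backward, which draws a fresh batch on each call):

import torch

def step_with_retries(run_forward_backward, optimizer, max_non_finite_grad_retries):
    # Re-run the forward/backward pass while the total gradient norm is NaN/Inf.
    retries = 0
    while True:
        run_forward_backward()          # populates .grad on the optimizer's parameters
        params = [p for group in optimizer.param_groups for p in group["params"]]
        grad_norm = torch.nn.utils.clip_grad_norm_(params, max_norm=float("inf"))
        if torch.isfinite(grad_norm):
            optimizer.step()
            optimizer.zero_grad()
            return
        if max_non_finite_grad_retries is None or retries >= max_non_finite_grad_retries:
            raise RuntimeError("Gradients are still non-finite after retries")
        retries += 1
        optimizer.zero_grad()           # drop the bad gradients and try a fresh batch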

State management

Save and load the full (possibly distributed) training state (model, optimizer, scheduler, grad scaler, current step, and hook state):

# Save checkpoint
state = trainer.state_dict()

# Resume training
trainer.load_state_dict(state)
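
If you manage persistence yourself, the same state round-trips through plain torch.save/torch.load. A sketch; the file name is arbitrary, and weights_only=False is used here on the assumption that hook state may contain non-tensor objects:

import torch

# Write the checkpoint to disk
torch.save(trainer.state_dict(), "checkpoint.pt")

# ...and restore it later
trainer.load_state_dict(
    torch.load("checkpoint.pt", map_location=trainer.device, weights_only=False)
)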

The CheckpointingHook automates this for you, including saving to and loading from disk.


Hooks

Hooks let you inject custom logic at key points in the training loop. All hooks inherit from BaseHook and can override:

  • on_before_train(trainer): called once before training starts.
  • on_after_train(trainer): called once after training finishes.
  • on_before_step(trainer): called before each training step.
  • on_after_step(trainer): called after each training step.
  • on_log(trainer, records, dry_run): called when the trainer logs metrics.
  • on_log_images(trainer, records, dry_run): called when the trainer logs images.
  • on_state_dict(trainer, state_dict): called when saving a checkpoint.
  • on_load_state_dict(trainer, state_dict): called when loading a checkpoint.

Built-in hooks

ProgressHook

Prints training progress to the console.

ProgressHook(
    interval=10,          # log every N steps
    with_records=True,    # include extra metrics in the output
    sync=False,           # synchronize metrics across ranks
    eta_warmup=10,        # steps to warm up ETA calculation
)

Example output:

Step   100/10000: step 0.2500s data 0.0100s eta 6:00:00 loss 0.5234 grad_norm 2.3456 lr_0 1.00e-03

LoggingHook

Aggregates metrics and calls trainer.log(records) at regular intervals. Use this to send metrics to experiment trackers.

LoggingHook(
    interval=100,  # aggregate and log every N steps
    sync=True,     # synchronize metrics across ranks
)

Override trainer.log() or add a hook that implements on_log (e.g., WandbHook) to consume the aggregated records.
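
Such a hook can be as small as the sketch below, where printing stands in for a real tracker:

from mini.trainer import BaseHook

class PrintLogHook(BaseHook):
    def on_log(self, trainer, records, dry_run):
        # Receives the aggregated records that LoggingHook passes to trainer.log().
        if not dry_run:
            print(f"step {trainer.step}: {records}")

Register it alongside LoggingHook in build_hooks() so it receives the aggregated records.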

CheckpointingHook

Saves checkpoints at regular intervals and handles automatic resuming.

CheckpointingHook(
    interval=1000,              # save every N steps
    keep_previous=2,            # keep the last N checkpoints
    keep_interval=5000,         # keep checkpoints every N steps (in addition to the last N)
    path="checkpoints",         # directory to save checkpoints (relative to workspace)
    load="latest",              # load the latest checkpoint in the workspace on startup ("latest", a specific path, or None)
    exit_signals=[signal.SIGTERM, signal.SIGINT],  # save on these signals before exiting
    exit_code="128+signal",     # exit code to use after signal handling
    exit_wait=60.0,             # wait time before exiting (useful to get TIMEOUT instead of FAILED slurm job status)
)

Checkpoints are saved as checkpoint_step_{step}.pt, and the latest is symlinked as checkpoint_latest.pt.

CudaMaxMemoryHook

Tracks and logs the maximum GPU memory allocated during training.

CudaMaxMemoryHook()

Adds max_memory (in GiB) to trainer.step_info for other hooks to use.
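
Other hooks can then consume the value in on_after_step. A sketch, assuming trainer.step_info supports dict-style access (an implementation detail) and using an arbitrary 70 GiB threshold:

from mini.trainer import BaseHook

class MemoryWarningHook(BaseHook):
    def on_after_step(self, trainer):
        # Assumes CudaMaxMemoryHook appears earlier in the hook list,
        # so max_memory has already been recorded for this step.
        max_memory = trainer.step_info.get("max_memory")
        if max_memory is not None and max_memory > 70.0:
            print(f"step {trainer.step}: peak GPU memory {max_memory:.1f} GiB")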

EmaHook

Maintains an exponential moving average (EMA) of model weights.

EmaHook(
    decay=0.9999,               # EMA decay rate
)

Access the EMA model via hook.ema_model.
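
Because hooks are ordinary objects, you can keep a reference to the EmaHook when you build it and use the EMA weights elsewhere, e.g. for evaluation. A sketch; evaluate() is a hypothetical helper, not part of the trainer API:

import torch
from mini.trainer import BaseTrainer, EmaHook

class MyTrainer(BaseTrainer):
    # ... build_* and forward as before

    def build_hooks(self):
        self.ema_hook = EmaHook(decay=0.9999)
        return [self.ema_hook]

    @torch.no_grad()
    def evaluate(self, batch):
        # Hypothetical helper: run the smoothed EMA weights instead of the live model.
        return self.ema_hook.ema_model(batch)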

WandbHook

Logs metrics and images to Weights & Biases.

WandbHook(
    project="my-project",
    name="experiment-1",
    config={"lr": 1e-3, "batch_size": 32},
    image_format="png",
    # ... (other wandb.init arguments)
)

Call trainer.log() and trainer.log_images() to send data to wandb.

ImageFileLoggerHook

Saves images to disk (useful for debugging or visualization).

ImageFileLoggerHook(
    image_format="png",
)

Call trainer.log_images({"image_name": pil_image}) to save images.


Advanced features

Distributed training

The trainer integrates with PyTorch's DDP and FSDP. Just wrap your model with DistributedDataParallel or FullyShardedDataParallel in build_model().
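
For DDP, that wrapping is a one-liner inside build_model(). The trainer does not create process groups, so torch.distributed must already be initialized (as in the integration example at the bottom of this page); the sketch also assumes self.device is a CUDA device with an explicit index:

import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel as DDP

class MyDDPTrainer(MyTrainer):
    def build_model(self):
        model = nn.Sequential(
            nn.Linear(784, 128),
            nn.ReLU(),
            nn.Linear(128, 10),
        ).to(self.device)
        # Wrap after moving to the device; gradient accumulation then uses
        # model.no_sync() automatically on non-final micro-steps.
        return DDP(model, device_ids=[self.device])

FSDP works the same way: build, wrap, and return the wrapped module from build_model().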

Unwrapping models

The trainer provides a helper to unwrap compiled, distributed, or EMA-wrapped models:

unwrapped = trainer.unwrap(trainer.model)
# or use the property:
unwrapped = trainer.unwrapped_model

This is useful for accessing the base model's methods or parameters.

Custom hooks

Create your own hooks by subclassing BaseHook:

from mini.trainer import BaseHook

class CustomHook(BaseHook):
    def on_after_step(self, trainer):
        if trainer.step % 100 == 0:
            print(f"Custom hook triggered at step {trainer.step}")

Add it to your trainer:

def build_hooks(self):
    return [CustomHook(), ProgressHook(interval=10)]

Utilities

map_nested_tensor

Apply a function to all tensors in a nested structure:

from mini.trainer import map_nested_tensor

input = {"x": torch.randn(2, 3), "y": [torch.randn(4, 5)]}
output = map_nested_tensor(lambda t: t.to("cuda"), input)

Useful for moving data to devices or converting dtypes.

key_average

Average numeric values across a list of nested dicts:

from mini.trainer.utils import key_average

records = [
    {"loss": 0.5, "metrics": {"acc": 0.9}},
    {"loss": 0.6, "metrics": {"acc": 0.85}},
]
avg = key_average(records)
# => {"loss": 0.55, "metrics": {"acc": 0.875}}

Used internally by hooks to aggregate metrics.


API reference

CLASSES

  • BaseTrainer: Minimal training loop that orchestrates builds, accumulation, retries, and hooks.
  • BaseHook: Lifecycle hooks for BaseTrainer.
  • CheckpointingHook: Save and optionally restore checkpoints at regular intervals.
  • CudaMaxMemoryHook: Record peak CUDA memory per step into trainer.step_info.
  • LoggingHook: Aggregate stats and forward them to trainer.log.
  • ProgressHook: Log progress to stdout with optional metrics, ETA, and memory.
  • EmaHook: Maintain an exponential moving average of model weights.
  • WandbHook: Log metrics and images to Weights & Biases (rank 0 only).
  • ImageFileLoggerHook: Persist logged images to workspace/visualizations on rank 0.
  • LossNoneWarning: Warning raised when forward returns None in distributed contexts.

FUNCTIONS

  • map_nested_tensor: Apply f to every tensor contained in a nested structure.

BaseTrainer

Minimal training loop that orchestrates builds, accumulation, retries, and hooks.

Subclasses provide component factories and a forward pass; the base class handles sequencing, mixed precision, accumulation, state management, and hook dispatch.

PARAMETERS

  • max_steps (int): Number of training steps to run.
  • grad_clip (float | None, default None): Max gradient norm; if set, gradients are clipped before stepping.
  • max_non_finite_grad_retries (int | None, default None): Number of retries when encountering non-finite gradients (scaler disabled).
  • mixed_precision (str | None, default None): "fp16" or "bf16" to enable autocast; None disables it.
  • gradient_accumulation_steps (int | None, default None): Number of microsteps to accumulate before stepping.
  • workspace (Path | str | None, default None): Optional working directory used by hooks (e.g., checkpoints, logs).
  • device (device | str | int | None, default None): Device for the model and tensors.
  • no_sync_accumulate (bool, default True): Whether to call no_sync on distributed modules during accumulation.
  • state_dict_options (StateDictOptions | None, default None): Torch distributed checkpoint options.
  • logger (Logger | None, default None): Logger instance; a default logger is created when omitted.

METHODS

  • build_data_loader: Return the training data iterator.
  • build_model: Construct and return the model.
  • build_optimizer: Create the optimizer for the model.
  • build_lr_scheduler: Optionally create a learning-rate scheduler.
  • build_hooks: Return hooks to run during training.
  • build_grad_scaler: Create the gradient scaler used for mixed precision.
  • train: Run the training loop until max_steps are completed.
  • forward: Perform a forward pass and return loss plus records for logging.
  • log: Dispatch numeric records to hooks (e.g., trackers or stdout).
  • log_images: Dispatch image records to hooks.

build_data_loader

build_data_loader() -> Iterable

Return the training data iterator.

build_model

build_model() -> nn.Module

Construct and return the model.

build_optimizer

build_optimizer() -> torch.optim.Optimizer

Create the optimizer for the model.

build_lr_scheduler

build_lr_scheduler() -> torch.optim.lr_scheduler.LRScheduler | None

Optionally create a learning-rate scheduler.

build_hooks

build_hooks() -> list[BaseHook]

Return hooks to run during training.

build_grad_scaler

build_grad_scaler() -> torch.amp.GradScaler

Create the gradient scaler used for mixed precision.

train

train()

Run the training loop until max_steps are completed.

forward

forward(input: Any) -> tuple[torch.Tensor | None, Records]

Perform a forward pass and return loss plus records for logging.

PARAMETERS

  • input (Any): Batch yielded by the data loader.

RETURNS

  • Tensor | None: The loss (None skips backward/step; if using DDP/FSDP, avoid invoking the wrapped module's forward in that case).
  • Records: A nested dict of numeric metrics that will be averaged and emitted to hooks.

log

log(records: dict[str, Any], dry_run: bool = False)

Dispatch numeric records to hooks (e.g., trackers or stdout).

PARAMETERS

  • records (dict[str, Any]): Nested dict of numeric metrics to log.
  • dry_run (bool, default False): If True, hooks should avoid side effects and only report intent.

log_images

log_images(records: dict[str, Any], dry_run: bool = False)

Dispatch image records to hooks.

PARAMETERS

  • records (dict[str, Any]): Nested dict of images to log.
  • dry_run (bool, default False): If True, hooks should avoid side effects and only report intent.

BaseHook

Lifecycle hooks for BaseTrainer.

CheckpointingHook

Bases: BaseHook

Save and optionally restore checkpoints at regular intervals.

PARAMETERS

  • interval (int): Save every interval steps.
  • keep_previous (int, default 0): Keep the last N checkpoints in addition to the latest.
  • keep_interval (int, default 0): Keep checkpoints every keep_interval steps.
  • path (Path | str, default 'checkpoint'): Directory (relative to workspace unless absolute) for checkpoints.
  • load (Path | str | Literal['latest'] | None, default 'latest'): Path to load at startup or "latest" to auto-resume.
  • exit_signals (list[Signals] | Signals, default None): Signals that trigger a checkpoint then exit.
  • exit_code (int | Literal['128+signal'], default '128+signal'): Exit code after handling an exit signal.
  • exit_wait (timedelta | float, default 0.0): Optional sleep before exit (useful for schedulers).

CudaMaxMemoryHook

Bases: BaseHook

Record peak CUDA memory per step into trainer.step_info.

LoggingHook

Bases: _StatsHook

Aggregate stats and forward them to trainer.log.

PARAMETERS

  • interval (int, default 10): Log every N steps.
  • sync (bool, default True): If True, aggregate across distributed ranks.

ProgressHook

Bases: _StatsHook

Log progress to stdout with optional metrics, ETA, and memory.

PARAMETERS

  • interval (int, default 1): Log every N steps.
  • with_records (bool, default False): Include per-step records in the log line.
  • sync (bool, default False): If True, aggregate across distributed ranks.
  • eta_warmup (int, default 10): Steps to warm up ETA calculation.
  • show_units (bool, default True): Whether to print units (s, GiB) alongside values.

EmaHook

Bases: BaseHook

Maintain an exponential moving average of model weights.

PARAMETERS

  • decay (float): EMA decay rate.

WandbHook

Bases: BaseHook

Log metrics and images to Weights & Biases (rank 0 only).

PARAMETERS

  • project (str): W&B project name.
  • config (dict[str, Any] | str | None, default None): Optional config dict or JSON file path to log.
  • tags (Sequence[str] | None, default None): Optional tag list.
  • image_format (str | None | Callable[[str], str | None], default 'png'): File format for images or a callable to derive it per key.
  • **wandb_kwargs (default {}): Extra arguments forwarded to wandb.init.

ImageFileLoggerHook

Bases: BaseHook

Persist logged images to workspace/visualizations on rank 0.

PARAMETERS

  • image_format (str | Callable[[str], str], default 'png'): File extension or callable taking the leaf key.

LossNoneWarning

Bases: UserWarning

Warning raised when forward returns None in distributed contexts.

map_nested_tensor

map_nested_tensor(f: Callable[[Tensor], Any], obj: Any)

Apply f to every tensor contained in a nested structure.


Tips

  • Use hooks for side effects: logging, checkpointing, and validation are best handled via hooks.
  • Combine with mini.config and mini.builder: define your training setup in config files and use the builder to instantiate the trainer.

Integration example

Using mini.trainer with mini.config and mini.builder:

# configs/train.py
config = {
    "type": "MyTrainer",
    "max_steps": 10_000,
    "grad_clip": 1.0,
    "mixed_precision": "bf16",
    "gradient_accumulation_steps": 4,
    "workspace": "./runs/experiment_1",
    "data": {
        "type": "torch.utils.data.DataLoader",
        "dataset": {"type": "ToyDataset"},
        "batch_size": 32,
    },
    "model": {
        "type": "MyModel",
        "in_dim": 784,
        "hidden_dim": 128,
        "out_dim": 10,
    },
    "optimizer": {
        "type": "torch.optim.AdamW",
        "lr": 3e-4,
    },
}
# train.py
import logging

import torch
from torch import nn
from torch.utils.data import IterableDataset

from mini.builder import register, build
from mini.config import load
from mini.trainer import BaseTrainer, CheckpointingHook, ProgressHook

logging.basicConfig(level=logging.INFO)


@register()
class MyTrainer(BaseTrainer):
    def __init__(self, data, model, optimizer, **kwargs):
        super().__init__(**kwargs)
        self.data_cfg = data
        self.model_cfg = model
        self.optimizer_cfg = optimizer

    def build_data_loader(self):
        return build(self.data_cfg)

    def build_model(self):
        model = build(self.model_cfg)
        return model.to(self.device)

    def build_optimizer(self):
        return build(self.optimizer_cfg | {"params": self.model.parameters()})

    def forward(self, input):
        x, y = input
        x, y = x.to(self.device), y.to(self.device)
        logits = self.model(x)
        loss = torch.nn.functional.cross_entropy(logits, y)
        acc = (logits.argmax(1) == y).float().mean().item()
        return loss, {"accuracy": acc}

    def build_hooks(self):
        return [
            ProgressHook(interval=1, with_records=True),
            CheckpointingHook(interval=1000, path=self.workspace / "checkpoints"),
        ]


@register()
class ToyDataset(IterableDataset):
    def __iter__(self):
        while True:
            x = torch.randn(1, 28, 28)
            y = torch.randint(0, 10, (1,)).item()
            yield x.view(-1), y


@register()
class MyModel(nn.Module):
    def __init__(self, in_dim: int, hidden_dim: int, out_dim: int):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(in_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, out_dim),
        )

    def forward(self, x):
        return self.net(x)


try:
    device = torch.device(0)
    torch.distributed.init_process_group(
        world_size=1, rank=0, store=torch.distributed.HashStore(), device_id=device
    )

    cfg = load("configs/train.py")
    trainer = build(cfg | {"device": device}, recursive=False)
    trainer.train()
finally:
    torch.distributed.destroy_process_group()

This approach keeps your training code declarative and easy to modify via config files.