mini.trainer¶
mini.trainer is a minimalist PyTorch training loop that centralizes the routine sequencing (state builds, accumulation, retries) while leaving distributed setup and model construction under user control.
What BaseTrainer handles¶
BaseTrainer stays intentionally small. Out of the box it:
- Builds and restores state for the data loader, model, optimizer, scheduler, and optional grad scaler in a DDP/FSDP-friendly order.
- Runs a resilient step loop that supports gradient accumulation, mixed precision, scale management, NaN/Inf guarding, and data/step timing across single or distributed devices.
- Pushes everything else into hooks. Progress output, checkpoint IO, wandb logging, CUDA memory tracking, and validation all ship as hooks, which makes it easy to compose your own.
Because nearly every behavior lives behind hooks, most real-world setups fit without forking the code. When you do need something exotic, the trainer is compact enough to copy into your project and customize.
Key capabilities¶
- Subclass contract: implement the `build_*()` factories and `forward()`, then rely on the provided `train()` loop.
- Distributed-ready: gradient accumulation, autocast, gradient scaling, `no_sync`, and non-finite checks work seamlessly with PyTorch DDP and FSDP.
- State management: save and restore the full training state (model, optimizer, scheduler, grad scaler, current step, and hooks) with correct ordering and distributed-aware handling.
- Hook system: progress, checkpointing, logging (e.g., wandb), EMA, CUDA memory stats, and validation are all hooks you can combine or extend.
- Instrumentation coverage: step/data timings, gradient norms, learning rates, and user metrics are captured for logging hooks to consume.
Quick start¶
1. Subclass `BaseTrainer` and implement the required methods:

```python
from mini.trainer import BaseTrainer
import torch
import torch.nn as nn


class MyTrainer(BaseTrainer):
    def build_data_loader(self):
        # Return any iterable (DataLoader, generator, etc.)
        dataset = ...
        return torch.utils.data.DataLoader(dataset, batch_size=32)

    def build_model(self):
        model = nn.Sequential(
            nn.Linear(784, 128),
            nn.ReLU(),
            nn.Linear(128, 10),
        )
        return model.to(self.device)

    def build_optimizer(self):
        return torch.optim.Adam(self.model.parameters(), lr=1e-3)

    def forward(self, input):
        # Implement your forward pass and loss computation
        x, y = input  # comes from data loader
        x, y = x.to(self.device), y.to(self.device)
        logits = self.model(x)
        loss = nn.functional.cross_entropy(logits, y)
        # Return loss and a dict of metrics to log
        records = {"accuracy": (logits.argmax(1) == y).float().mean().item()}
        return loss, records
```
2. Create and train:
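A minimal sketch, using the constructor arguments documented in the API reference below (`max_steps`, `device`, `workspace`):

```python
import torch

# Construct the trainer and run the loop; the arguments mirror the BaseTrainer
# parameters listed in the API reference (max_steps, device, workspace, ...).
trainer = MyTrainer(
    max_steps=10_000,
    device=torch.device("cuda"),
    workspace="./runs/quickstart",
)
trainer.train()
```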
3. Add hooks for logging, checkpointing, and more:
```python
from mini.trainer import (
    BaseTrainer,
    CheckpointingHook,
    LoggingHook,
    ProgressHook,
    WandbHook,
)


class MyTrainer(BaseTrainer):
    # ... (same as above)

    def build_hooks(self):
        return [
            ProgressHook(interval=10, with_records=True),
            CheckpointingHook(interval=1000, path=self.workspace / "checkpoints"),
            LoggingHook(interval=100),
            WandbHook(project="my-project"),  # requires LoggingHook
        ]
```
Why mini.trainer¶
- Lightning takes a framework-style approach: the user defines modules and callbacks while Lightning constructs the training loop, manages distributed wrappers, and coordinates evaluation. This removes boilerplate but makes it harder to diverge from the framework's lifecycle or debug lower-level issues when the defaults do not match your setup.
- Hugging Face Accelerate sits in the middle: you still write the loop, yet the library configures devices, mixed precision, and distributed collectives through helper objects. That consolidation keeps the API uniform across hardware, but the indirection can obscure which PyTorch primitives run at each stage. Debugging becomes hard, as documentation is sparse and the code is complex.
- mini.trainer assumes the complementary responsibilities. You hold on to device setup, process groups, and any model wrapping, while the trainer focuses on sequencing: it builds components in the correct order, runs the step loop with accumulation and retry policies, and surfaces metrics to hooks. The implementation stays small enough to audit or copy when a project needs to diverge.
Core concepts¶
The training loop¶
BaseTrainer.train() orchestrates the training process:
- Calls `_build()` to construct the model, optimizer, data loader, and hooks.
- Iterates over `max_steps`, calling `_run_step()` for each step.
- Handles gradient accumulation automatically.
- Invokes hooks at key points (before/after train, before/after step).
You typically don't override train() itself; implement the build_*() and forward() methods instead.
Required methods¶
Subclasses must implement:
- `build_data_loader() -> Iterable`: Return an iterable that yields training data.
- `build_model() -> nn.Module`: Construct and return the model.
- `build_optimizer() -> torch.optim.Optimizer`: Create the optimizer.
- `forward(input) -> tuple[torch.Tensor | None, dict]`: Perform a forward pass and return `(loss, records)`.
    - `loss`: The scalar loss tensor. If `None`, the backward pass is skipped and the step is retried.
    - `records`: A nested dict of numeric metrics that you want to log (e.g., `{"accuracy": 0.95, "metrics": {"f1": 0.9}}`).
Optional methods¶
You can override these to customize behavior:
- `build_lr_scheduler() -> torch.optim.lr_scheduler.LRScheduler | None`: Return a learning rate scheduler (default: `None`).
- `build_grad_scaler() -> torch.amp.GradScaler`: Customize the gradient scaler for mixed precision (default: enabled for FP16).
- `build_hooks() -> list[BaseHook]`: Return a list of hooks (default: `[]`).
Gradient accumulation¶
Set gradient_accumulation_steps to accumulate gradients over multiple forward/backward passes before updating parameters. The trainer automatically:
- Scales the loss by `1 / gradient_accumulation_steps`.
- Calls `model.no_sync()` during accumulation steps (if using DDP/FSDP) to skip gradient synchronization until the final step.
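For example, with the `MyTrainer` subclass from the quick start (loader batch size 32) and four accumulation steps, each optimizer update corresponds to an effective batch of 128; a sketch:

```python
import torch

# Accumulate gradients over 4 microsteps before each optimizer step.
# Each microstep's loss is scaled by 1/4 by the trainer.
trainer = MyTrainer(
    max_steps=10_000,
    gradient_accumulation_steps=4,
    device=torch.device("cuda"),
)
trainer.train()
```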
Mixed precision¶
Pass mixed_precision="fp16" or "bf16" to enable automatic mixed precision training with torch.autocast. The trainer:
- Uses `torch.amp.GradScaler` for FP16 to handle gradient scaling.
- Disables the scaler for BF16 (which doesn't need gradient scaling).
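For example, pass `mixed_precision="bf16"` at construction, or override the optional `build_grad_scaler()` to tune FP16 scaling; a sketch (the `GradScaler` arguments shown are standard `torch.amp` options, not mini.trainer defaults):

```python
import torch

# BF16 autocast needs no gradient scaler.
trainer = MyTrainer(max_steps=10_000, mixed_precision="bf16", device=torch.device("cuda"))


# For FP16, customize the scaler via the optional build_grad_scaler() override.
class MyFp16Trainer(MyTrainer):
    def build_grad_scaler(self):
        # init_scale and growth_interval are regular torch.amp.GradScaler arguments.
        return torch.amp.GradScaler(init_scale=2.0**14, growth_interval=1000)
```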
Non-finite gradient handling¶
If gradients become NaN or Inf, the trainer can retry the step:
- Set `max_non_finite_grad_retries` to a positive integer to enable retries.
- The trainer will reset gradients and re-run the forward/backward pass.
- If retries are exhausted, a `RuntimeError` is raised.
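For example, allow up to three retries per step; a sketch:

```python
import torch

# Re-run a step up to 3 times if its gradients come back NaN/Inf,
# then raise RuntimeError if the problem persists.
trainer = MyTrainer(
    max_steps=10_000,
    max_non_finite_grad_retries=3,
    device=torch.device("cuda"),
)
```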
State management¶
Save and load the full (possibly distributed) training state (model, optimizer, scheduler, hooks):
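A sketch, assuming the trainer exposes `state_dict()` / `load_state_dict()` counterparts to the `on_state_dict` / `on_load_state_dict` hook callbacks (names and signatures may differ in your version):

```python
import torch

# Assumption: BaseTrainer provides state_dict()/load_state_dict() mirroring
# the on_state_dict/on_load_state_dict hook callbacks.
state = trainer.state_dict()  # model, optimizer, scheduler, step, hooks
torch.save(state, "checkpoint.pt")

trainer.load_state_dict(torch.load("checkpoint.pt"))
```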
The CheckpointingHook automates this for you, including saving to and loading from disk.
Hooks¶
Hooks let you inject custom logic at key points in the training loop. All hooks inherit from BaseHook and can override:
- `on_before_train(trainer)`: called once before training starts.
- `on_after_train(trainer)`: called once after training finishes.
- `on_before_step(trainer)`: called before each training step.
- `on_after_step(trainer)`: called after each training step.
- `on_log(trainer, records, dry_run)`: called when the trainer logs metrics.
- `on_log_images(trainer, records, dry_run)`: called when the trainer logs images.
- `on_state_dict(trainer, state_dict)`: called when saving a checkpoint.
- `on_load_state_dict(trainer, state_dict)`: called when loading a checkpoint.
Built-in hooks¶
ProgressHook¶
Prints training progress to the console.
ProgressHook(
interval=10, # log every N steps
with_records=True, # include extra metrics in the output
sync=False, # synchronize metrics across ranks
eta_warmup=10, # steps to warm up ETA calculation
)
Example output:
LoggingHook¶
Aggregates metrics and calls trainer.log(records) at regular intervals. Use this to send metrics to experiment trackers.
LoggingHook(
interval=100, # aggregate and log every N steps
sync=True, # synchronize metrics across ranks
)
Implement trainer.log() yourself or add a hook that handles on_log (e.g. WandbHook) to consume the aggregated records.
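For example, a minimal sketch of a hook that prints the aggregated records instead of sending them to a tracker:

```python
from mini.trainer import BaseHook


class PrintLogHook(BaseHook):
    # Receives the records that LoggingHook aggregates and passes to trainer.log().
    def on_log(self, trainer, records, dry_run):
        if not dry_run:
            print(f"step {trainer.step}: {records}")
```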
CheckpointingHook¶
Saves checkpoints at regular intervals and handles automatic resuming.
CheckpointingHook(
interval=1000, # save every N steps
keep_previous=2, # keep the last N checkpoints
keep_interval=5000, # keep checkpoints every N steps (in addition to the last N)
path="checkpoints", # directory to save checkpoints (relative to workspace)
load="latest", # load the latest checkpoint in the workspace on startup ("latest", a specific path, or None)
exit_signals=[signal.SIGTERM, signal.SIGINT], # save on these signals before exiting
exit_code="128+signal", # exit code to use after signal handling
exit_wait=60.0, # wait time before exiting (useful to get TIMEOUT instead of FAILED slurm job status)
)
Checkpoints are saved as checkpoint_step_{step}.pt, and the latest is symlinked as checkpoint_latest.pt.
CudaMaxMemoryHook¶
Tracks and logs the maximum GPU memory allocated during training.
Adds max_memory (in GiB) to trainer.step_info for other hooks to use.
EmaHook¶
Maintains an exponential moving average (EMA) of model weights.
Access the EMA model via hook.ema_model.
WandbHook¶
Logs metrics and images to Weights & Biases.
WandbHook(
project="my-project",
name="experiment-1",
config={"lr": 1e-3, "batch_size": 32},
image_format="png",
# ... (other wandb.init arguments)
)
Call trainer.log() and trainer.log_images() to send data to wandb.
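For example (a sketch; `pil_image` stands in for any `PIL.Image` you produced):

```python
# Numeric records go through trainer.log(); image records through trainer.log_images().
trainer.log({"val": {"loss": 0.42}})
trainer.log_images({"reconstruction": pil_image})
```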
ImageFileLoggerHook¶
Saves images to disk (useful for debugging or visualization).
Call trainer.log_images({"image_name": pil_image}) to save images.
Advanced features¶
Distributed training¶
The trainer integrates with PyTorch's DDP and FSDP. Just wrap your model with DistributedDataParallel or FullyShardedDataParallel in build_model().
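For example, a sketch of wrapping the model with DDP inside `build_model()` (assumes the process group has already been initialized by your launcher):

```python
import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel as DDP


class MyDDPTrainer(MyTrainer):
    def build_model(self):
        model = nn.Sequential(nn.Linear(784, 128), nn.ReLU(), nn.Linear(128, 10))
        # self.device is expected to name a specific CUDA device, e.g. cuda:0.
        return DDP(model.to(self.device), device_ids=[self.device])
```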
Unwrapping models¶
The trainer provides a helper to unwrap compiled, distributed, or EMA-wrapped models:
unwrapped = trainer.unwrap(trainer.model)
# or use the property:
unwrapped = trainer.unwrapped_model
This is useful for accessing the base model's methods or parameters.
Custom hooks¶
Create your own hooks by subclassing BaseHook:
from mini.trainer import BaseHook
class CustomHook(BaseHook):
def on_after_step(self, trainer):
if trainer.step % 100 == 0:
print(f"Custom hook triggered at step {trainer.step}")
Add it to your trainer:
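For example, return it from `build_hooks()` alongside the built-in hooks (a sketch, reusing `ProgressHook` from above):

```python
from mini.trainer import BaseTrainer, ProgressHook


class MyTrainer(BaseTrainer):
    # ... build_* and forward as before ...

    def build_hooks(self):
        return [
            ProgressHook(interval=10),
            CustomHook(),
        ]
```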
Utilities¶
map_nested_tensor¶
Apply a function to all tensors in a nested structure:
import torch

from mini.trainer import map_nested_tensor

input = {"x": torch.randn(2, 3), "y": [torch.randn(4, 5)]}
output = map_nested_tensor(lambda t: t.to("cuda"), input)
Useful for moving data to devices or converting dtypes.
key_average¶
Average numeric values across a list of nested dicts:
from mini.trainer.utils import key_average
records = [
{"loss": 0.5, "metrics": {"acc": 0.9}},
{"loss": 0.6, "metrics": {"acc": 0.85}},
]
avg = key_average(records)
# => {"loss": 0.55, "metrics": {"acc": 0.875}}
Used internally by hooks to aggregate metrics.
API reference ¶
| CLASS | DESCRIPTION |
|---|---|
| `BaseTrainer` | Minimal training loop that orchestrates builds, accumulation, retries, and hooks. |
| `BaseHook` | Lifecycle hooks for `BaseTrainer`. |
| `CheckpointingHook` | Save and optionally restore checkpoints at regular intervals. |
| `CudaMaxMemoryHook` | Record peak CUDA memory per step into `trainer.step_info`. |
| `LoggingHook` | Aggregate stats and forward them to `trainer.log()`. |
| `ProgressHook` | Log progress to stdout with optional metrics, ETA, and memory. |
| `EmaHook` | Maintain an exponential moving average of model weights. |
| `WandbHook` | Log metrics and images to Weights & Biases (rank 0 only). |
| `ImageFileLoggerHook` | Persist logged images to disk. |
| `LossNoneWarning` | Warning raised when `forward` returns `None` in distributed contexts. |

| FUNCTION | DESCRIPTION |
|---|---|
| `map_nested_tensor` | Apply `f` to every tensor contained in a nested structure. |
BaseTrainer ¶
Minimal training loop that orchestrates builds, accumulation, retries, and hooks.
Subclasses provide component factories and a forward pass; the base class handles sequencing, mixed precision, accumulation, state management, and hook dispatch.
| PARAMETER | DESCRIPTION |
|---|---|
| `max_steps` | Number of training steps to run. |
| `grad_clip` | Max gradient norm; if set, gradients are clipped before stepping. |
| `max_non_finite_grad_retries` | Number of retries when encountering non-finite gradients (scaler disabled). |
| `mixed_precision` | Mixed precision mode (`"fp16"`, `"bf16"`, or `None`). |
| `gradient_accumulation_steps` | Number of microsteps to accumulate before stepping. |
| `workspace` | Optional working directory used by hooks (e.g., checkpoints, logs). |
| `device` | Device for the model and tensors. |
| `no_sync_accumulate` | Whether to call `no_sync()` on the model during accumulation steps. |
| `state_dict_options` | Torch distributed checkpoint options. |
| `logger` | Logger instance; a default logger is created when omitted. |
| METHOD | DESCRIPTION |
|---|---|
| `build_data_loader` | Return the training data iterator. |
| `build_model` | Construct and return the model. |
| `build_optimizer` | Create the optimizer for the model. |
| `build_lr_scheduler` | Optionally create a learning-rate scheduler. |
| `build_hooks` | Return hooks to run during training. |
| `build_grad_scaler` | Create the gradient scaler used for mixed precision. |
| `train` | Run the training loop until `max_steps` is reached. |
| `forward` | Perform a forward pass and return loss plus records for logging. |
| `log` | Dispatch numeric records to hooks (e.g., trackers or stdout). |
| `log_images` | Dispatch image records to hooks. |
build_lr_scheduler ¶
Optionally create a learning-rate scheduler.
build_grad_scaler ¶
Create the gradient scaler used for mixed precision.
forward ¶
Perform a forward pass and return loss plus records for logging.
| PARAMETER | DESCRIPTION |
|---|---|
| `input` | Batch yielded by the data loader. |

| RETURNS | DESCRIPTION |
|---|---|
| `Tensor \| None` | The loss; if `None`, the backward pass is skipped and the step is retried. |
| `Records` | A nested dict of numeric metrics that will be averaged and emitted to hooks. |
log ¶
Dispatch numeric records to hooks (e.g., trackers or stdout).
BaseHook ¶
Lifecycle hooks for BaseTrainer.
CheckpointingHook ¶
Bases: BaseHook
Save and optionally restore checkpoints at regular intervals.
| PARAMETER | DESCRIPTION |
|---|---|
| `interval` | Save every `interval` steps. |
| `keep_previous` | Keep the last N checkpoints in addition to the latest. |
| `keep_interval` | Keep checkpoints every `keep_interval` steps (in addition to the last N). |
| `path` | Directory (relative to workspace unless absolute) for checkpoints. |
| `load` | Checkpoint to load at startup: `"latest"`, a specific path, or `None`. |
| `exit_signals` | Signals that trigger a checkpoint then exit. |
| `exit_code` | Exit code after handling an exit signal. |
| `exit_wait` | Optional sleep before exit (useful for schedulers). |
LoggingHook ¶
Aggregate stats and forward them to trainer.log().
ProgressHook ¶
Bases: _StatsHook
Log progress to stdout with optional metrics, ETA, and memory.
| PARAMETER | DESCRIPTION |
|---|---|
| `interval` | Log every N steps. |
| `with_records` | Include per-step records in the log line. |
| `sync` | If True, aggregate across distributed ranks. |
| `eta_warmup` | Steps to warm up ETA calculation. |
| `show_units` | Whether to print units (s, GiB) alongside values. |
EmaHook ¶
Maintain an exponential moving average of model weights.
WandbHook ¶
Bases: BaseHook
Log metrics and images to Weights & Biases (rank 0 only).
| PARAMETER | DESCRIPTION |
|---|---|
| `project` | W&B project name. |
| `config` | Optional config dict or JSON file path to log. |
| `tags` | Optional tag list. |
| `image_format` | File format for images or a callable to derive it per key. |
| `**wandb_kwargs` | Extra arguments forwarded to `wandb.init`. |
ImageFileLoggerHook ¶
Persist logged images to disk.
LossNoneWarning ¶
Bases: UserWarning
Warning raised when forward returns None in distributed contexts.
map_nested_tensor ¶
Apply f to every tensor contained in a nested structure.
Tips¶
- Use hooks for side effects: logging, checkpointing, and validation are best handled via hooks.
- Combine with `mini.config` and `mini.builder`: define your training setup in config files and use the builder to instantiate the trainer.
Integration example¶
Using mini.trainer with mini.config and mini.builder:
# configs/train.py
config = {
"type": "MyTrainer",
"max_steps": 10_000,
"grad_clip": 1.0,
"mixed_precision": "bf16",
"gradient_accumulation_steps": 4,
"workspace": "./runs/experiment_1",
"data": {
"type": "torch.utils.data.DataLoader",
"dataset": {"type": "ToyDataset"},
"batch_size": 32,
},
"model": {
"type": "MyModel",
"in_dim": 784,
"hidden_dim": 128,
"out_dim": 10,
},
"optimizer": {
"type": "torch.optim.AdamW",
"lr": 3e-4,
},
}
# train.py
import logging
import torch
from torch import nn
from torch.utils.data import IterableDataset
from mini.builder import register, build
from mini.config import load
from mini.trainer import BaseTrainer, CheckpointingHook, ProgressHook
logging.basicConfig(level=logging.INFO)
@register()
class MyTrainer(BaseTrainer):
def __init__(self, data, model, optimizer, **kwargs):
super().__init__(**kwargs)
self.data_cfg = data
self.model_cfg = model
self.optimizer_cfg = optimizer
def build_data_loader(self):
return build(self.data_cfg)
def build_model(self):
model = build(self.model_cfg)
return model.to(self.device)
def build_optimizer(self):
return build(self.optimizer_cfg | {"params": self.model.parameters()})
def forward(self, input):
x, y = input
x, y = x.to(self.device), y.to(self.device)
logits = self.model(x)
loss = torch.nn.functional.cross_entropy(logits, y)
acc = (logits.argmax(1) == y).float().mean().item()
return loss, {"accuracy": acc}
def build_hooks(self):
return [
ProgressHook(interval=1, with_records=True),
CheckpointingHook(interval=1000, path=self.workspace / "checkpoints"),
]
@register()
class ToyDataset(IterableDataset):
def __iter__(self):
while True:
x = torch.randn(1, 28, 28)
y = torch.randint(0, 10, (1,)).item()
yield x.view(-1), y
@register()
class MyModel(nn.Module):
def __init__(self, in_dim: int, hidden_dim: int, out_dim: int):
super().__init__()
self.net = nn.Sequential(
nn.Linear(in_dim, hidden_dim),
nn.ReLU(),
nn.Linear(hidden_dim, out_dim),
)
def forward(self, x):
return self.net(x)
try:
device = torch.device(0)
torch.distributed.init_process_group(
world_size=1, rank=0, store=torch.distributed.HashStore(), device_id=device
)
cfg = load("configs/train.py")
trainer = build(cfg | {"device": device}, recursive=False)
trainer.train()
finally:
torch.distributed.destroy_process_group()
This approach keeps your training code declarative and easy to modify via config files.