# API reference

## trainloop
| CLASS | DESCRIPTION |
|---|---|
| `BaseTrainer` | Minimal training loop that orchestrates builds, accumulation, retries, and hooks. |
| `BaseHook` | Lifecycle hooks for `BaseTrainer`. |
| `CheckpointingHook` | Save and optionally restore checkpoints at regular intervals. |
| `CudaMaxMemoryHook` | Record peak CUDA memory per step into the step records. |
| `LoggingHook` | Aggregate stats and forward them to trackers. |
| `ProgressHook` | Log progress to stdout with optional metrics, ETA, and memory. |
| `EmaHook` | Maintain an exponential moving average of model weights. |
| `WandbHook` | Log metrics and images to Weights & Biases (rank 0 only). |
| `ImageFileLoggerHook` | Persist logged images to disk. |
| `LossNoneWarning` | Warning raised when `forward` returns `None` in distributed contexts. |

| FUNCTION | DESCRIPTION |
|---|---|
| `map_nested_tensor` | Apply `f` to every tensor contained in a nested structure. |
### BaseTrainer
Minimal training loop that orchestrates builds, accumulation, retries, and hooks.
Subclasses provide component factories and a forward pass; the base class handles sequencing, mixed precision, accumulation, state management, and hook dispatch.
| PARAMETER | DESCRIPTION |
|---|---|
| `max_steps` | Number of training steps to run. |
| `grad_clip` | Max gradient norm; if set, gradients are clipped before stepping. |
| `max_non_finite_grad_retries` | Number of retries when encountering non-finite gradients (scaler disabled). |
| `mixed_precision` | Mixed-precision setting for the forward/backward pass. |
| `gradient_accumulation_steps` | Number of microsteps to accumulate before stepping. |
| `workspace` | Optional working directory used by hooks (e.g., checkpoints, logs). |
| `device` | Device for the model and tensors. |
| `no_sync_accumulate` | Whether to call `no_sync()` during accumulation microsteps. |
| `state_dict_options` | Torch distributed checkpoint options. |
| `logger` | Logger instance; a default logger is created when omitted. |
| METHOD | DESCRIPTION |
|---|---|
| `build_data_loader` | Return the training data iterator. |
| `build_model` | Construct and return the model. |
| `build_optimizer` | Create the optimizer for the model. |
| `build_lr_scheduler` | Optionally create a learning-rate scheduler. |
| `build_hooks` | Return hooks to run during training. |
| `build_grad_scaler` | Create the gradient scaler used for mixed precision. |
| `train` | Run the training loop until `max_steps` is reached. |
| `forward` | Perform a forward pass and return loss plus records for logging. |
| `log` | Dispatch numeric records to hooks (e.g., trackers or stdout). |
| `log_images` | Dispatch image records to hooks. |
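A minimal subclass might look like the sketch below. Only the overridden method names and the constructor keywords shown in the tables above come from this reference; the toy data, model, exact method signatures (e.g. whether `build_optimizer` receives the model), the `self.model` attribute, and the `from trainloop import ...` path are assumptions.

```python
# Illustrative sketch only: the toy model/data, hyperparameters, and the exact
# build_* signatures are assumptions, not guaranteed by this reference.
import torch
from torch import nn

from trainloop import BaseTrainer  # assumed import path


class MyTrainer(BaseTrainer):
    def build_data_loader(self):
        # Any iterator of batches works; here an endless stream of random regression data.
        def batches():
            while True:
                x = torch.randn(32, 16)
                yield {"x": x, "y": x.sum(dim=1, keepdim=True)}

        return batches()

    def build_model(self):
        return nn.Linear(16, 1)

    def build_optimizer(self, model):  # assumed to receive the built model
        return torch.optim.AdamW(model.parameters(), lr=1e-3)

    def forward(self, input):
        # Return the loss plus a nested dict of numeric records for logging.
        pred = self.model(input["x"])  # assumes the trainer exposes self.model
        loss = nn.functional.mse_loss(pred, input["y"])
        return loss, {"loss": {"mse": loss.detach()}}


trainer = MyTrainer(
    max_steps=1_000,
    grad_clip=1.0,
    gradient_accumulation_steps=2,
    workspace="runs/example",
)
trainer.train()
```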
#### build_lr_scheduler
Optionally create a learning-rate scheduler.
#### build_grad_scaler
Create the gradient scaler used for mixed precision.
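A hedged sketch of overriding these two factories follows; whether `build_lr_scheduler` receives the optimizer, the `self.max_steps` attribute, and the import path are assumptions, and the cosine schedule and initial scale are arbitrary choices.

```python
# Sketch only: the signatures (optimizer argument, self.max_steps attribute) are
# assumptions; the schedule and scaler settings are arbitrary.
import torch

from trainloop import BaseTrainer  # assumed import path


class MyTrainer(BaseTrainer):
    def build_lr_scheduler(self, optimizer):
        # Cosine decay over the whole run.
        return torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=self.max_steps)

    def build_grad_scaler(self):
        # A standard CUDA gradient scaler with a custom initial scale.
        return torch.cuda.amp.GradScaler(init_scale=2.0**14)
```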
#### forward
Perform a forward pass and return loss plus records for logging.
| PARAMETER | DESCRIPTION |
|---|---|
| `input` | Batch yielded by the data loader. |

| RETURNS | DESCRIPTION |
|---|---|
| `Tensor \| None` | The loss (may be `None`). |
| `Records` | A nested dict of numeric metrics that will be averaged and emitted to hooks. |
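Concretely, an implementation might look like this sketch; the batch keys, the `self.model` attribute, and the record layout are illustrative assumptions, while the return contract (a loss tensor or `None`, plus a nested dict of numeric records) follows the tables above.

```python
import torch
import torch.nn.functional as F


# Sketch: batch keys and self.model are assumptions; only the (loss, records)
# return contract comes from this reference.
def forward(self, input):
    logits = self.model(input["image"])
    if not torch.isfinite(logits).all():
        # Skipping the step: a None loss in distributed contexts triggers LossNoneWarning (see below).
        return None, {}
    loss = F.cross_entropy(logits, input["label"])
    records = {
        "loss": {"ce": loss.detach()},
        "acc": {"top1": (logits.argmax(-1) == input["label"]).float().mean()},
    }
    return loss, records
```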
#### log

Dispatch numeric records to hooks (e.g., trackers or stdout).
### BaseHook

Lifecycle hooks for `BaseTrainer`.
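The individual callback names are not enumerated in this section, so the method below is purely hypothetical; it only illustrates the general pattern of subclassing `BaseHook`.

```python
from trainloop import BaseHook  # assumed import path


class PrintEveryHundredSteps(BaseHook):
    # "after_step" is a hypothetical callback name used for illustration only;
    # consult the BaseHook source for the actual lifecycle methods.
    def after_step(self, trainer, step, records):
        if step % 100 == 0:
            print(f"step {step}: {records}")
```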
### CheckpointingHook

Bases: `BaseHook`
Save and optionally restore checkpoints at regular intervals.
| PARAMETER | DESCRIPTION |
|---|---|
| `interval` | Save a checkpoint every `interval` steps. |
| `keep_previous` | Keep the last N checkpoints in addition to the latest. |
| `keep_interval` | Keep checkpoints every `keep_interval` steps. |
| `path` | Directory (relative to workspace unless absolute) for checkpoints. |
| `load` | Path to load at startup. |
| `exit_signals` | Signals that trigger a checkpoint then exit. |
| `exit_code` | Exit code after handling an exit signal. |
| `exit_wait` | Optional sleep before exit (useful for schedulers). |
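A sketch of wiring the hook up via `build_hooks`; the keyword names follow the parameter table above, while the concrete values, the list-of-signals form of `exit_signals`, and the import path are assumptions.

```python
import signal

from trainloop import BaseTrainer, CheckpointingHook, ProgressHook  # assumed import path


class MyTrainer(BaseTrainer):
    def build_hooks(self):
        return [
            CheckpointingHook(
                interval=1_000,                 # save every 1000 steps (illustrative)
                keep_previous=3,                # keep the 3 most recent checkpoints
                path="checkpoints",             # relative to the trainer workspace
                exit_signals=[signal.SIGTERM],  # checkpoint, then exit, on SIGTERM
            ),
            ProgressHook(interval=50),
        ]
```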
### LoggingHook

Aggregate stats and forward them to trackers.
### ProgressHook

Bases: `_StatsHook`
Log progress to stdout with optional metrics, ETA, and memory.
| PARAMETER | DESCRIPTION |
|---|---|
| `interval` | Log every N steps. |
| `with_records` | Include per-step records in the log line. |
| `sync` | If True, aggregate across distributed ranks. |
| `eta_warmup` | Steps to warm up the ETA calculation. |
| `show_units` | Whether to print units (s, GiB) alongside values. |
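For reference, a fully keyword-specified construction; the values are arbitrary and the import path is an assumption.

```python
from trainloop import ProgressHook  # assumed import path

progress = ProgressHook(
    interval=10,        # log every 10 steps
    with_records=True,  # append per-step metrics to the log line
    sync=True,          # aggregate across distributed ranks
    eta_warmup=50,      # ignore the first 50 steps when estimating the ETA
    show_units=True,    # print "s" / "GiB" next to timings and memory
)
```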
### EmaHook

Maintain an exponential moving average of model weights.
### WandbHook

Bases: `BaseHook`
Log metrics and images to Weights & Biases (rank 0 only).
| PARAMETER | DESCRIPTION |
|---|---|
| `project` | W&B project name. |
| `config` | Optional config dict or JSON file path to log. |
| `tags` | Optional tag list. |
| `image_format` | File format for images or a callable to derive it per key. |
| `**wandb_kwargs` | Extra keyword arguments forwarded to `wandb.init`. |
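A hedged construction example; the project name, config, tags, and the extra `name` keyword (passed through `**wandb_kwargs`) are illustrative, and the import path is an assumption.

```python
from trainloop import WandbHook  # assumed import path

wandb_hook = WandbHook(
    project="my-project",
    config={"lr": 1e-3, "batch_size": 32},
    tags=["baseline"],
    image_format="png",
    name="run-001",  # example extra keyword forwarded via **wandb_kwargs
)
```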
### ImageFileLoggerHook

Persist logged images to disk.
### LossNoneWarning

Bases: `UserWarning`

Warning raised when `forward` returns `None` in distributed contexts.
### map_nested_tensor

Apply `f` to every tensor contained in a nested structure.
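A usage sketch under the assumption that the function is called as `map_nested_tensor(f, structure)`; the argument order, the batch layout, and the import path are not specified in this section.

```python
import torch

from trainloop import map_nested_tensor  # assumed import path

batch = {
    "image": torch.randn(4, 3, 32, 32),
    "meta": {"mask": torch.ones(4, 32, 32)},
}

# Assumed argument order (function first, then the nested structure).
batch_on_device = map_nested_tensor(lambda t: t.to("cuda"), batch)
```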