API reference

trainloop

CLASSES
BaseTrainer

Minimal training loop that orchestrates builds, accumulation, retries, and hooks.

BaseHook

Lifecycle hooks for BaseTrainer.

CheckpointingHook

Save and optionally restore checkpoints at regular intervals.

CudaMaxMemoryHook

Record peak CUDA memory per step into trainer.step_info.

LoggingHook

Aggregate stats and forward them to trainer.log.

ProgressHook

Log progress to stdout with optional metrics, ETA, and memory.

EmaHook

Maintain an exponential moving average of model weights.

WandbHook

Log metrics and images to Weights & Biases (rank 0 only).

ImageFileLoggerHook

Persist logged images to workspace/visualizations on rank 0.

LossNoneWarning

Warning raised when forward returns None in distributed contexts.

FUNCTIONS
map_nested_tensor

Apply f to every tensor contained in a nested structure.

BaseTrainer

Minimal training loop that orchestrates builds, accumulation, retries, and hooks.

Subclasses provide component factories and a forward pass; the base class handles sequencing, mixed precision, accumulation, state management, and hook dispatch.
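
As an illustration, a minimal subclass might look like the sketch below. The trainloop import path, the assumption that the trainer exposes the built model as self.model, and ToyTrainer itself are hypothetical; only the factory names, the forward contract, and the constructor parameters come from this reference.

```python
import torch
from torch import nn

from trainloop import BaseTrainer  # import path assumed


class ToyTrainer(BaseTrainer):
    def build_model(self) -> nn.Module:
        return nn.Linear(16, 1)

    def build_optimizer(self) -> torch.optim.Optimizer:
        # Assumes the built model is available as self.model.
        return torch.optim.AdamW(self.model.parameters(), lr=1e-3)

    def build_data_loader(self):
        # Any iterable of batches works; here, an endless synthetic stream.
        while True:
            yield torch.randn(8, 16), torch.randn(8, 1)

    def forward(self, input):
        x, y = input
        loss = nn.functional.mse_loss(self.model(x), y)
        return loss, {"loss": loss.detach()}


trainer = ToyTrainer(
    max_steps=1_000,
    grad_clip=1.0,
    mixed_precision="bf16",
    gradient_accumulation_steps=4,
    workspace="runs/toy",
)
trainer.train()
```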

PARAMETERS
max_steps

Number of training steps to run.

TYPE: int

grad_clip

Max gradient norm; if set, gradients are clipped before stepping.

TYPE: float | None DEFAULT: None

max_non_finite_grad_retries

Number of retries when encountering non-finite gradients (scaler disabled).

TYPE: int | None DEFAULT: None

mixed_precision

"fp16" or "bf16" to enable autocast; None disables it.

TYPE: str | None DEFAULT: None

gradient_accumulation_steps

Number of microsteps to accumulate before stepping.

TYPE: int | None DEFAULT: None

workspace

Optional working directory used by hooks (e.g., checkpoints, logs).

TYPE: Path | str | None DEFAULT: None

device

Device for the model and tensors.

TYPE: device | str | int | None DEFAULT: None

no_sync_accumulate

Whether to call no_sync on distributed modules during accumulation.

TYPE: bool DEFAULT: True

state_dict_options

Torch distributed checkpoint options.

TYPE: StateDictOptions | None DEFAULT: None

logger

Logger instance; a default logger is created when omitted.

TYPE: Logger | None DEFAULT: None

METHODS
build_data_loader

Return the training data iterator.

build_model

Construct and return the model.

build_optimizer

Create the optimizer for the model.

build_lr_scheduler

Optionally create a learning-rate scheduler.

build_hooks

Return hooks to run during training.

build_grad_scaler

Create the gradient scaler used for mixed precision.

train

Run the training loop until max_steps are completed.

forward

Perform a forward pass and return loss plus records for logging.

log

Dispatch numeric records to hooks (e.g., trackers or stdout).

log_images

Dispatch image records to hooks.

build_data_loader

build_data_loader() -> Iterable

Return the training data iterator.

build_model

build_model() -> nn.Module

Construct and return the model.

build_optimizer

build_optimizer() -> torch.optim.Optimizer

Create the optimizer for the model.

build_lr_scheduler

build_lr_scheduler() -> torch.optim.lr_scheduler.LRScheduler | None

Optionally create a learning-rate scheduler.

build_hooks

build_hooks() -> list[BaseHook]

Return hooks to run during training.

build_grad_scaler

build_grad_scaler() -> torch.amp.GradScaler

Create the gradient scaler used for mixed precision.

train

train()

Run the training loop until max_steps are completed.

forward

forward(input: Any) -> tuple[torch.Tensor | None, Records]

Perform a forward pass and return loss plus records for logging.

PARAMETERS
input

Batch yielded by the data loader.

TYPE: Any

RETURNS
Tensor | None

The loss. Returning None skips backward and the optimizer step; under DDP/FSDP, avoid invoking the wrapped module's forward on steps that return None.

Records

A nested dict of numeric metrics that will be averaged and emitted to hooks.
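
Continuing the sketch under the class description, a forward that exercises the None escape hatch might look like this; the skip condition is purely illustrative. The guard runs before self.model(...) so that, under DDP/FSDP, no wrapped forward is invoked on a step that returns None.

```python
# Inside a BaseTrainer subclass (see the sketch under the class description).
def forward(self, input):
    x, y = input
    if not torch.isfinite(x).all():  # illustrative skip condition
        return None, {}  # None skips backward and the optimizer step
    loss = nn.functional.mse_loss(self.model(x), y)
    # Nested numeric records are averaged and emitted to hooks.
    return loss, {"loss": loss.detach(), "stats": {"batch": x.shape[0]}}
```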

log

log(records: dict[str, Any], dry_run: bool = False)

Dispatch numeric records to hooks (e.g., trackers or stdout).

PARAMETERS
records

Nested dict of numeric metrics to log.

TYPE: dict[str, Any]

dry_run

If True, hooks should avoid side effects and only report intent.

TYPE: bool DEFAULT: False

log_images

log_images(records: dict[str, Any], dry_run: bool = False)

Dispatch image records to hooks.

PARAMETERS
records

Nested dict of images to log.

TYPE: dict[str, Any]

dry_run

If True, hooks should avoid side effects and only report intent.

TYPE: bool DEFAULT: False
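
Both dispatchers take the same shape of nested dict; a usage sketch, with arbitrary key names and a hypothetical image tensor:

```python
import torch

image_tensor = torch.rand(3, 64, 64)  # hypothetical CHW image

trainer.log({"eval": {"accuracy": 0.91, "loss": 0.23}})
trainer.log_images({"val": {"reconstruction": image_tensor}})

# dry_run=True asks hooks to report what they would do without side effects.
trainer.log({"eval": {"accuracy": 0.91}}, dry_run=True)
```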

BaseHook

Lifecycle hooks for BaseTrainer.

CheckpointingHook

Bases: BaseHook

Save and optionally restore checkpoints at regular intervals.

PARAMETERS
interval

Save a checkpoint every interval steps.

TYPE: int

keep_previous

Keep the last N checkpoints in addition to the latest.

TYPE: int DEFAULT: 0

keep_interval

Keep checkpoints every keep_interval steps.

TYPE: int DEFAULT: 0

path

Directory (relative to workspace unless absolute) for checkpoints.

TYPE: Path | str DEFAULT: 'checkpoints'

load

Path to load at startup or "latest" to auto-resume.

TYPE: Path | str | Literal['latest'] | None DEFAULT: 'latest'

exit_signals

Signals that trigger a checkpoint and then exit.

TYPE: list[Signals] | Signals DEFAULT: None

exit_code

Exit code after handling an exit signal.

TYPE: int | Literal['128+signal'] DEFAULT: '128+signal'

exit_wait

Optional sleep before exit (useful for schedulers).

TYPE: timedelta | float DEFAULT: 0.0
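
A build_hooks override wiring up checkpointing might look like the following sketch; the import path is assumed, and SIGUSR1 is just one plausible choice for preemptible schedulers.

```python
import signal

from trainloop import CheckpointingHook  # import path assumed


def build_hooks(self):
    return [
        CheckpointingHook(
            interval=1_000,               # save every 1000 steps
            keep_previous=2,              # keep the two most recent extras
            keep_interval=10_000,         # also keep every 10k-step checkpoint
            load="latest",                # auto-resume when possible
            exit_signals=signal.SIGUSR1,  # checkpoint, then exit, on preemption
            exit_wait=30.0,               # give the scheduler time to react
        ),
    ]
```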

CudaMaxMemoryHook

Bases: BaseHook

Record peak CUDA memory per step into trainer.step_info.

LoggingHook

Bases: _StatsHook

Aggregate stats and forward them to trainer.log.

PARAMETERS
interval

Log every N steps.

TYPE: int DEFAULT: 10

sync

If True, aggregate across distributed ranks.

TYPE: bool DEFAULT: True

ProgressHook

Bases: _StatsHook

Log progress to stdout with optional metrics, ETA, and memory.

PARAMETERS
interval

Log every N steps.

TYPE: int DEFAULT: 1

with_records

Include per-step records in the log line.

TYPE: bool DEFAULT: False

sync

If True, aggregate across distributed ranks.

TYPE: bool DEFAULT: False

eta_warmup

Steps to warm up ETA calculation.

TYPE: int DEFAULT: 10

show_units

Whether to print units (s, GiB) alongside values.

TYPE: bool DEFAULT: True
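
The two stats hooks are commonly combined; a plausible configuration, with assumed import paths and arbitrary intervals:

```python
from trainloop import LoggingHook, ProgressHook  # import paths assumed

hooks = [
    LoggingHook(interval=50, sync=True),  # rank-synced aggregates via trainer.log
    ProgressHook(interval=10, with_records=True, eta_warmup=20),  # stdout progress
]
```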

EmaHook

Bases: BaseHook

Maintain an exponential moving average of model weights.

PARAMETERS
decay

EMA decay rate.

TYPE: float
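
For example (the decay value is illustrative):

```python
from trainloop import EmaHook  # import path assumed

# Each update: ema = 0.999 * ema + 0.001 * current_weights.
ema_hook = EmaHook(decay=0.999)
```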

WandbHook

Bases: BaseHook

Log metrics and images to Weights & Biases (rank 0 only).

PARAMETERS
project

W&B project name.

TYPE: str

config

Optional config dict or JSON file path to log.

TYPE: dict[str, Any] | str | None DEFAULT: None

tags

Optional tag list.

TYPE: Sequence[str] | None DEFAULT: None

image_format

File format for images or a callable to derive it per key.

TYPE: str | None | Callable[[str], str | None] DEFAULT: 'png'

**wandb_kwargs

Extra arguments forwarded to wandb.init.

DEFAULT: {}
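
A configuration sketch, with an assumed import path and made-up project values; the trailing name keyword is an example of an argument forwarded to wandb.init:

```python
from trainloop import WandbHook  # import path assumed

wandb_hook = WandbHook(
    project="my-project",
    config={"lr": 1e-3, "batch_size": 8},  # or a path to a JSON file
    tags=["baseline"],
    image_format="png",
    name="run-001",  # forwarded to wandb.init
)
```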

ImageFileLoggerHook

Bases: BaseHook

Persist logged images to workspace/visualizations on rank 0.

PARAMETERS
image_format

File extension, or a callable deriving it from the leaf key.

TYPE: str | Callable[[str], str] DEFAULT: 'png'
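
A sketch using a callable to choose the format per leaf key; the key convention is hypothetical:

```python
from trainloop import ImageFileLoggerHook  # import path assumed

# Store depth maps losslessly, everything else as JPEG.
hook = ImageFileLoggerHook(
    image_format=lambda key: "png" if key.startswith("depth") else "jpg",
)
```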

LossNoneWarning

Bases: UserWarning

Warning raised when forward returns None in distributed contexts.

map_nested_tensor

map_nested_tensor(f: Callable[[Tensor], Any], obj: Any)

Apply f to every tensor contained in a nested structure.
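
A usage sketch, assuming the function returns the mapped structure and passes non-tensor leaves through unchanged:

```python
import torch

from trainloop import map_nested_tensor  # import path assumed

batch = {
    "pixels": torch.rand(2, 3, 8, 8),
    "meta": {"mask": torch.ones(2, 8, 8), "id": 7},
}
# Move every tensor to the GPU; the non-tensor leaf "id" is untouched.
batch = map_nested_tensor(lambda t: t.to("cuda"), batch)
```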