Training Utilities

Helper functions for training orchestration and result collection.

Overview

The training utilities module provides reusable components for the training workflow, including hardware detection, configuration parsing, checkpoint management, and summary artifact generation.

Usage

from pathlib import Path

from naics_embedder.utils.training import (
    detect_hardware,
    parse_config_overrides,
    resolve_checkpoint,
    save_training_summary,
)

# Detect hardware
hardware = detect_hardware(log_info=True)
print(f"Training on {hardware.accelerator}")

# Parse overrides
overrides, invalid = parse_config_overrides(['lr=1e-4', 'epochs=10'])

# Resolve checkpoint
checkpoint_info = resolve_checkpoint('last', Path('checkpoints'), 'experiment')
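
A fuller, hedged sketch of how these helpers chain together end to end; cfg, model, and datamodule are assumed to be built elsewhere, and the experiment name '01_text' is purely illustrative.

from pathlib import Path

from naics_embedder.utils.training import (
    collect_training_result,
    create_trainer,
    detect_hardware,
    resolve_checkpoint,
    save_training_summary,
)

# cfg, model, and datamodule are assumed to be constructed elsewhere.
hardware = detect_hardware(log_info=True)
ckpt_info = resolve_checkpoint('last', Path('checkpoints'), '01_text')

trainer, ckpt_cb, es_cb = create_trainer(cfg, hardware, Path('checkpoints'))
trainer.fit(
    model,
    datamodule=datamodule,
    ckpt_path=ckpt_info.path if ckpt_info.exists else None,
)

result = collect_training_result(ckpt_cb, es_cb)
save_training_summary(result, cfg, hardware, Path('outputs'))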

Data Classes

HardwareInfo

Container for detected hardware configuration.

CheckpointInfo

Resolved checkpoint path and metadata.

TrainingResult

Structured result from a completed training run.

API Reference

Training orchestration utilities for NAICS Embedder.

This module provides helper functions to simplify training setup by extracting common operations into reusable components. These utilities handle hardware detection, configuration parsing, checkpoint management, and trainer creation.

Functions:

Name                     Description
detect_hardware          Detect available accelerators and optimal precision.
get_gpu_memory_info      Query current GPU memory usage.
parse_config_overrides   Parse and validate command-line config overrides.
resolve_checkpoint       Resolve checkpoint path from user input.
create_trainer           Create a configured PyTorch Lightning Trainer.
collect_training_result  Collect results from a completed training run.
save_training_summary    Save training summary artifacts for downstream use.
TrainingResult           Structured result from a training run.

HardwareInfo dataclass

Hardware configuration detected for training.

Attributes:

Name         Type                        Description
accelerator  str                         The accelerator type ('cuda', 'mps', or 'cpu').
precision    str                         Recommended precision setting ('16-mixed' or '32-true').
num_devices  int                         Number of available devices.
gpu_memory   Optional[Dict[str, float]]  Optional GPU memory information dictionary.
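
A minimal sketch of feeding these fields directly into a PyTorch Lightning Trainer when create_trainer is not used; it assumes the precision strings above ('16-mixed', '32-true') are passed through unchanged.

import pytorch_lightning as pl

from naics_embedder.utils.training import detect_hardware

hw = detect_hardware()

# The detected fields map onto the standard Trainer arguments.
trainer = pl.Trainer(
    accelerator=hw.accelerator,
    devices=hw.num_devices,
    precision=hw.precision,
)

if hw.gpu_memory is not None:
    print(f"GPU memory at startup: {hw.gpu_memory}")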

CheckpointInfo dataclass

Resolved checkpoint information.

Attributes:

Name           Type           Description
path           Optional[str]  Resolved filesystem path to the checkpoint, or None.
is_same_stage  bool           Whether the checkpoint is from the same experiment stage.
exists         bool           Whether the checkpoint file exists.
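
A short sketch of how these fields might gate a resume decision; passing the resolved path to Trainer.fit via ckpt_path is an assumption about the surrounding training script, not something this module mandates.

from pathlib import Path

from naics_embedder.utils.training import resolve_checkpoint

info = resolve_checkpoint('last', Path('checkpoints'), 'experiment')

# Only resume when the checkpoint exists and belongs to the same stage.
resume_path = info.path if (info.exists and info.is_same_stage) else None

# trainer.fit(model, datamodule=datamodule, ckpt_path=resume_path)  # assumed downstream use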

TrainingResult dataclass

Structured result from a training run.

Provides a clean interface for accessing training outputs, metrics, and paths for downstream processing or testing.

Attributes:

Name                  Type             Description
best_checkpoint_path  Optional[str]    Path to the best model checkpoint.
last_checkpoint_path  Optional[str]    Path to the last model checkpoint.
config_path           Optional[str]    Path to the saved configuration file.
best_loss             Optional[float]  Best validation loss achieved.
stopped_epoch         int              Epoch at which training stopped (early stopping or max epochs).
early_stopped         bool             Whether early stopping was triggered.
metrics               Dict[str, Any]   Dictionary of final metrics.
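
A hedged sketch of reading the fields above after a run; ckpt_cb and es_cb are assumed to be the callbacks returned by create_trainer, and the key names inside metrics are not specified here.

from naics_embedder.utils.training import collect_training_result

# ckpt_cb and es_cb are assumed to come from an earlier create_trainer() call.
result = collect_training_result(ckpt_cb, es_cb)

if result.early_stopped:
    print(f"Early stopping triggered at epoch {result.stopped_epoch}")
else:
    print(f"Training ran through epoch {result.stopped_epoch}")

if result.best_loss is not None:
    print(f"Best validation loss: {result.best_loss:.4f}")

# metrics is a plain dict; its keys depend on the run.
for name, value in result.metrics.items():
    print(f"{name}: {value}")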

detect_hardware(log_info=False)

Detect available hardware and recommend training settings.

Queries the system for CUDA, MPS, or CPU availability and returns appropriate accelerator and precision settings for PyTorch Lightning.

Parameters:

Name      Type  Description                                  Default
log_info  bool  If True, log detailed hardware information.  False

Returns:

Type          Description
HardwareInfo  HardwareInfo with detected accelerator, precision, device count, and optional GPU memory information.

Example

hw = detect_hardware(log_info=True)
print(f'Training on {hw.accelerator} with {hw.precision} precision')

get_gpu_memory_info()

Query current GPU memory usage.

Returns memory statistics including total, reserved, allocated, and free memory in gigabytes, along with utilization percentage.

Returns:

Type                        Description
Optional[Dict[str, float]]  Dictionary with memory statistics, or None if CUDA is unavailable.

Example

info = get_gpu_memory_info()
if info:
    print(f'GPU Memory: {info["free_gb"]:.1f} GB free')
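
A small sketch that dumps every reported statistic rather than assuming key names beyond free_gb; it relies only on the Dict[str, float] return type documented above.

from naics_embedder.utils.training import get_gpu_memory_info

stats = get_gpu_memory_info()
if stats is None:
    print("CUDA unavailable; no GPU memory statistics.")
else:
    # Keys cover total, reserved, allocated, free, and utilization,
    # but are printed generically here rather than hard-coded.
    for name, value in sorted(stats.items()):
        print(f"{name}: {value:.2f}")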

parse_config_overrides(overrides)

Parse and validate command-line configuration overrides.

Converts a list of key=value strings into a dictionary suitable for use with Config.override(). Invalid overrides are collected and returned for warning messages.

Parameters:

Name       Type                 Description                                                   Default
overrides  Optional[List[str]]  List of override strings like 'training.learning_rate=1e-4'.  required

Returns:

Type                              Description
Tuple[Dict[str, Any], List[str]]  Tuple of (valid_overrides_dict, list_of_invalid_override_strings).

Example

overrides = ['training.learning_rate=1e-4', 'invalid', 'batch_size=32']
valid, invalid = parse_config_overrides(overrides)
print(valid)    # {'training.learning_rate': 0.0001, 'batch_size': 32}
print(invalid)  # ['invalid']
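
The valid overrides are meant for Config.override(); the sketch below assumes that method accepts the parsed dictionary, which this page does not spell out, so treat the call as illustrative.

import logging

from naics_embedder.utils.training import parse_config_overrides

logger = logging.getLogger(__name__)

valid, invalid = parse_config_overrides(['training.learning_rate=1e-4', 'epochs=10'])
for bad in invalid:
    logger.warning("Ignoring invalid override: %s", bad)

# Assumed usage; adjust to Config.override()'s actual signature.
# cfg = cfg.override(valid)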

resolve_checkpoint(ckpt_path, checkpoint_dir, experiment_name)

Resolve a checkpoint path from user input.

Handles three cases:

1. None: no checkpoint specified, start fresh.
2. 'last' or 'last.ckpt': auto-detect the last checkpoint in the experiment directory.
3. Explicit path: validate and resolve the provided path.

Parameters:

Name             Type           Description                                Default
ckpt_path        Optional[str]  User-provided checkpoint path or keyword.  required
checkpoint_dir   Path           Base directory for checkpoints.            required
experiment_name  str            Name of the current experiment.            required

Returns:

Type            Description
CheckpointInfo  CheckpointInfo with resolved path and metadata.

Example

info = resolve_checkpoint('last', Path('checkpoints'), '01_text')
if info.exists:
    print(f'Resuming from {info.path}')
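
The three input forms side by side, as a sketch; the explicit checkpoint filename is hypothetical.

from pathlib import Path

from naics_embedder.utils.training import resolve_checkpoint

ckpt_dir = Path('checkpoints')

# Case 1: None -> no checkpoint, start fresh.
fresh = resolve_checkpoint(None, ckpt_dir, '01_text')

# Case 2: 'last' -> auto-detect the last checkpoint in the experiment directory.
resumed = resolve_checkpoint('last', ckpt_dir, '01_text')

# Case 3: explicit path (hypothetical filename) -> validated and resolved as given.
pinned = resolve_checkpoint('checkpoints/01_text/epoch=04.ckpt', ckpt_dir, '01_text')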

create_trainer(cfg, hardware, checkpoint_dir, callbacks=None, tb_logger=None)

Create a configured PyTorch Lightning Trainer.

Sets up the trainer with appropriate callbacks, logging, and hardware settings based on the provided configuration.

Parameters:

Name            Type                         Description                                             Default
cfg             Config                       Training configuration.                                 required
hardware        HardwareInfo                 Detected hardware information.                          required
checkpoint_dir  Path                         Directory for saving checkpoints.                       required
callbacks       Optional[List[Callback]]     Optional additional callbacks to include.               None
tb_logger       Optional[TensorBoardLogger]  Optional TensorBoard logger (created if not provided).  None

Returns:

Type                                            Description
Tuple[Trainer, ModelCheckpoint, EarlyStopping]  Tuple of (Trainer, ModelCheckpoint callback, EarlyStopping callback).

Example

hw = detect_hardware()
trainer, ckpt_cb, es_cb = create_trainer(cfg, hw, Path('checkpoints'))
trainer.fit(model, datamodule)
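
A hedged sketch of the callbacks parameter; LearningRateMonitor is a standard PyTorch Lightning callback, and cfg, model, and datamodule are assumed to exist.

from pathlib import Path

from pytorch_lightning.callbacks import LearningRateMonitor

from naics_embedder.utils.training import create_trainer, detect_hardware

hw = detect_hardware()

# Extra callbacks ride alongside the defaults that create_trainer configures;
# cfg, model, and datamodule are assumed to be defined elsewhere.
trainer, ckpt_cb, es_cb = create_trainer(
    cfg,
    hw,
    Path('checkpoints'),
    callbacks=[LearningRateMonitor(logging_interval='step')],
)
trainer.fit(model, datamodule=datamodule)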

collect_training_result(checkpoint_callback, early_stopping, config_path=None)

Collect results from a completed training run.

Gathers checkpoint paths, metrics, and early stopping information into a structured result object for downstream processing.

Parameters:

Name                 Type             Description                                  Default
checkpoint_callback  ModelCheckpoint  The ModelCheckpoint callback from training.  required
early_stopping       EarlyStopping    The EarlyStopping callback from training.    required
config_path          Optional[str]    Optional path where the config was saved.    None

Returns:

Type            Description
TrainingResult  TrainingResult with all training outputs and metrics.

Example

trainer.fit(model, datamodule)
result = collect_training_result(ckpt_cb, es_cb, 'config.yaml')
print(f'Best loss: {result.best_loss:.4f}')

save_training_summary(result, config, hardware, output_dir, format='both')

Save training summary artifacts for downstream evaluation and documentation.

Creates YAML and/or JSON summary files containing training results, configuration snapshot, and hardware information. These artifacts can be used for evaluation scripts, MkDocs documentation, or CI/CD pipelines.

Parameters:

Name        Type            Description                                       Default
result      TrainingResult  TrainingResult from the completed training run.   required
config      Config          Configuration used for training.                  required
hardware    HardwareInfo    Hardware information used during training.        required
output_dir  Path            Directory to save summary files.                  required
format      str             Output format: 'yaml', 'json', or 'both'.         'both'

Returns:

Type            Description
Dict[str, str]  Dictionary mapping format to output file path.

Example

result = collect_training_result(ckpt_cb, es_cb)
paths = save_training_summary(result, cfg, hw, Path('outputs'))
print(f'Summary saved to: {paths}')
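
A sketch of reloading the artifact in a downstream script; the 'json' key in the returned paths dict and the summary's top-level keys are assumptions based on the description above.

import json
from pathlib import Path

from naics_embedder.utils.training import save_training_summary

# result, cfg, and hw are assumed to come from the preceding steps.
paths = save_training_summary(result, cfg, hw, Path('outputs'), format='json')

# Reload the JSON artifact (e.g. in an evaluation or docs-build step);
# the exact schema is not documented here, so just inspect the keys.
summary = json.loads(Path(paths['json']).read_text())
print(sorted(summary.keys()))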