# Training Utilities
Helper functions for training orchestration and result collection.
## Overview
The training utilities module provides reusable components for the training
workflow, including hardware detection, configuration parsing, checkpoint
management, and summary artifact generation.
## Usage

```python
from pathlib import Path

from naics_embedder.utils.training import (
    detect_hardware,
    parse_config_overrides,
    resolve_checkpoint,
    save_training_summary,
)

# Detect hardware
hardware = detect_hardware(log_info=True)
print(f"Training on {hardware.accelerator}")

# Parse overrides
overrides, invalid = parse_config_overrides(['lr=1e-4', 'epochs=10'])

# Resolve checkpoint
checkpoint_info = resolve_checkpoint('last', Path('checkpoints'), 'experiment')
```
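Taken together, these utilities cover a full training run. The sketch below is illustrative only: the `cfg`, `model`, and `datamodule` placeholders stand in for your own `Config`, `LightningModule`, and `LightningDataModule` objects, which this module does not provide.

```python
from pathlib import Path

from naics_embedder.utils.training import (
    collect_training_result,
    create_trainer,
    detect_hardware,
    save_training_summary,
)

# Placeholders -- substitute your own objects.
cfg = ...         # a Config instance for this experiment
model = ...       # a LightningModule
datamodule = ...  # a LightningDataModule

# 1. Detect hardware and pick accelerator/precision settings.
hardware = detect_hardware(log_info=True)

# 2. Build a Trainer wired up with checkpointing and early stopping.
trainer, ckpt_cb, es_cb = create_trainer(cfg, hardware, Path('checkpoints'))

# 3. Train, then gather checkpoint paths and metrics.
trainer.fit(model, datamodule)
result = collect_training_result(ckpt_cb, es_cb)

# 4. Persist a YAML/JSON summary for downstream evaluation.
paths = save_training_summary(result, cfg, hardware, Path('outputs'))
```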
## Data Classes
### HardwareInfo
Container for detected hardware configuration.
### CheckpointInfo
Resolved checkpoint path and metadata.
### TrainingResult
Structured result from a completed training run.
## API Reference
Training orchestration utilities for NAICS Embedder.
This module provides helper functions to simplify training setup by extracting common operations into reusable components. These utilities handle hardware detection, configuration parsing, checkpoint management, and trainer creation.
Functions:

| Name | Description |
|---|---|
| `detect_hardware` | Detect available accelerators and optimal precision. |
| `get_gpu_memory_info` | Query current GPU memory usage. |
| `parse_config_overrides` | Parse and validate command-line config overrides. |
| `resolve_checkpoint` | Resolve checkpoint path from user input. |
| `create_trainer` | Create a configured PyTorch Lightning Trainer. |
| `collect_training_result` | Collect results from a completed training run. |
| `save_training_summary` | Save training summary artifacts. |

Classes:

| Name | Description |
|---|---|
| `HardwareInfo` | Hardware configuration detected for training. |
| `CheckpointInfo` | Resolved checkpoint information. |
| `TrainingResult` | Structured result from a training run. |
### HardwareInfo (dataclass)
Hardware configuration detected for training.
Attributes:

| Name | Type | Description |
|---|---|---|
| `accelerator` | `str` | The accelerator type (`cuda`, `mps`, or `cpu`). |
| `precision` | `str` | Recommended precision setting (`16-mixed` or `32-true`). |
| `num_devices` | `int` | Number of available devices. |
| `gpu_memory` | `Optional[Dict[str, float]]` | Optional GPU memory information dictionary. |
### CheckpointInfo (dataclass)
Resolved checkpoint information.
Attributes:

| Name | Type | Description |
|---|---|---|
| `path` | `Optional[str]` | Resolved filesystem path to the checkpoint, or None. |
| `is_same_stage` | `bool` | Whether the checkpoint is from the same experiment stage. |
| `exists` | `bool` | Whether the checkpoint file exists. |
### TrainingResult (dataclass)
Structured result from a training run.
Provides a clean interface for accessing training outputs, metrics, and paths for downstream processing or testing.
Attributes:

| Name | Type | Description |
|---|---|---|
| `best_checkpoint_path` | `Optional[str]` | Path to the best model checkpoint. |
| `last_checkpoint_path` | `Optional[str]` | Path to the last model checkpoint. |
| `config_path` | `Optional[str]` | Path to the saved configuration file. |
| `best_loss` | `Optional[float]` | Best validation loss achieved. |
| `stopped_epoch` | `int` | Epoch at which training stopped (early stopping or max epochs). |
| `early_stopped` | `bool` | Whether early stopping was triggered. |
| `metrics` | `Dict[str, Any]` | Dictionary of final metrics. |
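For orientation, the three dataclasses can be pictured roughly as follows. This is a sketch reconstructed from the attribute tables above; field ordering and the defaults shown are assumptions.

```python
from dataclasses import dataclass, field
from typing import Any, Dict, Optional


@dataclass
class HardwareInfo:
    accelerator: str   # 'cuda', 'mps', or 'cpu'
    precision: str     # '16-mixed' or '32-true'
    num_devices: int
    gpu_memory: Optional[Dict[str, float]] = None


@dataclass
class CheckpointInfo:
    path: Optional[str]
    is_same_stage: bool
    exists: bool


@dataclass
class TrainingResult:
    best_checkpoint_path: Optional[str]
    last_checkpoint_path: Optional[str]
    config_path: Optional[str]
    best_loss: Optional[float]
    stopped_epoch: int
    early_stopped: bool
    metrics: Dict[str, Any] = field(default_factory=dict)
```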
### detect_hardware(log_info=False)
Detect available hardware and recommend training settings.
Queries the system for CUDA, MPS, or CPU availability and returns appropriate accelerator and precision settings for PyTorch Lightning.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `log_info` | `bool` | If True, log detailed hardware information. | `False` |
Returns:

| Type | Description |
|---|---|
| `HardwareInfo` | `HardwareInfo` with detected accelerator, precision, device count, and optional GPU memory information. |
Example:

```python
hw = detect_hardware(log_info=True)
print(f'Training on {hw.accelerator} with {hw.precision} precision')
```
### get_gpu_memory_info()
Query current GPU memory usage.
Returns memory statistics including total, reserved, allocated, and free memory in gigabytes, along with utilization percentage.
Returns:

| Type | Description |
|---|---|
| `Optional[Dict[str, float]]` | Dictionary with memory statistics, or None if CUDA is unavailable. |
Example:

```python
info = get_gpu_memory_info()
if info:
    print(f'GPU Memory: {info["free_gb"]:.1f} GB free')
```
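For intuition, statistics like these could be assembled from standard `torch.cuda` calls, as in the sketch below. Apart from `free_gb`, which the example above uses, the key names are assumptions; this is not the module's actual implementation.

```python
from typing import Dict, Optional

import torch


def gpu_memory_info_sketch(device: int = 0) -> Optional[Dict[str, float]]:
    """Illustrative stand-in for get_gpu_memory_info, not the module's code."""
    if not torch.cuda.is_available():
        return None
    gb = 1024 ** 3
    # mem_get_info returns (free_bytes, total_bytes) for the device.
    free_bytes, total_bytes = torch.cuda.mem_get_info(device)
    total = total_bytes / gb
    free = free_bytes / gb
    return {
        'total_gb': total,
        'reserved_gb': torch.cuda.memory_reserved(device) / gb,
        'allocated_gb': torch.cuda.memory_allocated(device) / gb,
        'free_gb': free,
        'utilization_pct': 100.0 * (total - free) / total,
    }
```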
### parse_config_overrides(overrides)
Parse and validate command-line configuration overrides.
Converts a list of key=value strings into a dictionary suitable for
use with Config.override(). Invalid overrides are collected and
returned for warning messages.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `overrides` | `Optional[List[str]]` | List of override strings like `'key=value'`. | *required* |
Returns:

| Type | Description |
|---|---|
| `Tuple[Dict[str, Any], List[str]]` | Tuple of (valid_overrides_dict, list_of_invalid_override_strings). |
Example:

```python
overrides = ['training.learning_rate=1e-4', 'invalid', 'batch_size=32']
valid, invalid = parse_config_overrides(overrides)
print(valid)    # {'training.learning_rate': 0.0001, 'batch_size': 32}
print(invalid)  # ['invalid']
```
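A minimal sketch of parsing along these lines, consistent with the example output above (note `'1e-4'` coerced to `0.0001`). The use of `ast.literal_eval` for value coercion is an assumption, not a confirmed detail of `parse_config_overrides`.

```python
import ast
from typing import Any, Dict, List, Optional, Tuple


def parse_overrides_sketch(
    overrides: Optional[List[str]],
) -> Tuple[Dict[str, Any], List[str]]:
    valid: Dict[str, Any] = {}
    invalid: List[str] = []
    for item in overrides or []:
        key, sep, raw = item.partition('=')
        if not sep or not key:
            invalid.append(item)  # no '=' present, e.g. 'invalid'
            continue
        try:
            # Coerce numeric/boolean/container literals; '1e-4' -> 0.0001.
            valid[key] = ast.literal_eval(raw)
        except (ValueError, SyntaxError):
            valid[key] = raw  # fall back to the raw string
    return valid, invalid
```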
### resolve_checkpoint(ckpt_path, checkpoint_dir, experiment_name)
Resolve a checkpoint path from user input.
Handles three cases:
1. None: No checkpoint specified, start fresh.
2. "last" or "last.ckpt": Auto-detect last checkpoint in experiment dir.
3. Explicit path: Validate and resolve the provided path.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `ckpt_path` | `Optional[str]` | User-provided checkpoint path or keyword. | *required* |
| `checkpoint_dir` | `Path` | Base directory for checkpoints. | *required* |
| `experiment_name` | `str` | Name of the current experiment. | *required* |
Returns:

| Type | Description |
|---|---|
| `CheckpointInfo` | `CheckpointInfo` with resolved path and metadata. |
Example:

```python
info = resolve_checkpoint('last', Path('checkpoints'), '01_text')
if info.exists:
    print(f'Resuming from {info.path}')
```
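The three cases can be sketched as below. The `checkpoint_dir / experiment_name / 'last.ckpt'` layout and the `is_same_stage` heuristic are assumptions about the on-disk convention, not confirmed internals, and the real function returns a `CheckpointInfo` rather than a tuple.

```python
from pathlib import Path
from typing import Optional, Tuple


def resolve_checkpoint_sketch(
    ckpt_path: Optional[str],
    checkpoint_dir: Path,
    experiment_name: str,
) -> Tuple[Optional[str], bool, bool]:
    # Case 1: nothing specified -- start training from scratch.
    if ckpt_path is None:
        return None, False, False

    # Case 2: 'last' / 'last.ckpt' -- look in the experiment's own directory.
    if ckpt_path in ('last', 'last.ckpt'):
        candidate = checkpoint_dir / experiment_name / 'last.ckpt'
        return str(candidate), True, candidate.exists()

    # Case 3: explicit path -- resolve it and check whether it belongs
    # to the current experiment's directory.
    candidate = Path(ckpt_path).resolve()
    same_stage = experiment_name in candidate.parts
    return str(candidate), same_stage, candidate.exists()
```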
### create_trainer(cfg, hardware, checkpoint_dir, callbacks=None, tb_logger=None)
Create a configured PyTorch Lightning Trainer.
Sets up the trainer with appropriate callbacks, logging, and hardware settings based on the provided configuration.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `cfg` | `Config` | Training configuration. | *required* |
| `hardware` | `HardwareInfo` | Detected hardware information. | *required* |
| `checkpoint_dir` | `Path` | Directory for saving checkpoints. | *required* |
| `callbacks` | `Optional[List[Callback]]` | Optional additional callbacks to include. | `None` |
| `tb_logger` | `Optional[TensorBoardLogger]` | Optional TensorBoard logger (created if not provided). | `None` |
Returns:

| Type | Description |
|---|---|
| `Tuple[Trainer, ModelCheckpoint, EarlyStopping]` | Tuple of (Trainer, ModelCheckpoint callback, EarlyStopping callback). |
Example:

```python
hw = detect_hardware()
trainer, ckpt_cb, es_cb = create_trainer(cfg, hw, Path('checkpoints'))
trainer.fit(model, datamodule)
```
### collect_training_result(checkpoint_callback, early_stopping, config_path=None)
Collect results from a completed training run.
Gathers checkpoint paths, metrics, and early stopping information into a structured result object for downstream processing.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `checkpoint_callback` | `ModelCheckpoint` | The ModelCheckpoint callback from training. | *required* |
| `early_stopping` | `EarlyStopping` | The EarlyStopping callback from training. | *required* |
| `config_path` | `Optional[str]` | Optional path where the config was saved. | `None` |
Returns:

| Type | Description |
|---|---|
| `TrainingResult` | `TrainingResult` with all training outputs and metrics. |
Example:

```python
trainer.fit(model, datamodule)
result = collect_training_result(ckpt_cb, es_cb, 'config.yaml')
print(f'Best loss: {result.best_loss:.4f}')
```
### save_training_summary(result, config, hardware, output_dir, format='both')
Save training summary artifacts for downstream evaluation and documentation.
Creates YAML and/or JSON summary files containing training results, configuration snapshot, and hardware information. These artifacts can be used for evaluation scripts, MkDocs documentation, or CI/CD pipelines.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `result` | `TrainingResult` | TrainingResult from the completed training run. | *required* |
| `config` | `Config` | Configuration used for training. | *required* |
| `hardware` | `HardwareInfo` | Hardware information used during training. | *required* |
| `output_dir` | `Path` | Directory to save summary files. | *required* |
| `format` | `str` | Output format: `'yaml'`, `'json'`, or `'both'`. | `'both'` |
Returns:

| Type | Description |
|---|---|
| `Dict[str, str]` | Dictionary mapping format to output file path. |
Example:

```python
result = collect_training_result(ckpt_cb, es_cb)
paths = save_training_summary(result, cfg, hw, Path('outputs'))
print(f'Summary saved to: {paths}')
```
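As a purely illustrative example, a saved YAML summary might look something like the following. The file name, keys, nesting, and values are assumptions based on the inputs described above, not the module's actual schema.

```yaml
# outputs/training_summary.yaml (illustrative layout, not the exact schema)
result:
  best_checkpoint_path: checkpoints/01_text/best.ckpt
  last_checkpoint_path: checkpoints/01_text/last.ckpt
  best_loss: 0.1234
  stopped_epoch: 12
  early_stopped: true
hardware:
  accelerator: cuda
  precision: 16-mixed
  num_devices: 1
config:
  training:
    learning_rate: 0.0001
    batch_size: 32
```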