DataModule API

NAICSDataModule

Bases: LightningDataModule

DataModule for NAICS embedding training with pre-sampled or on-the-fly triplets.

on_train_epoch_start()

Update dataset epoch for on-the-fly sampling with difficulty curriculum.

prepare_data()

Build all caches before worker processes are spawned.

setup(stage=None)

Load caches and create datasets.

train_dataloader()

Create training dataloader with shuffling enabled.

val_dataloader()

Create validation dataloader.

NAICSMapDataset

Bases: Dataset

Map-style dataset for pre-sampled triplets with tokenized embeddings.

__getitem__(idx)

Get a single triplet item by index.

__init__(triplet_rows, token_cache)

Initialize the map-style dataset.

Parameters:

triplet_rows (List[Dict[str, Any]], required): List of triplet dictionaries with anchor/positive/negative info.

token_cache (Dict[int, Dict[str, Any]], required): Dictionary mapping index to tokenized embeddings.
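As a rough illustration of the constructor arguments above, the following minimal sketch shows how a map-style dataset might resolve one triplet row against the token cache. The class name TripletMapDataset, the row keys (anchor_idx, positive_idx, negative_idx) and the input_ids field are assumptions made for this example, not the actual NAICSMapDataset schema, and the sketch avoids torch so it runs on its own.

```python
from typing import Any, Dict, List


class TripletMapDataset:
    """Illustrative stand-in for a map-style triplet dataset (not the library class)."""

    def __init__(
        self,
        triplet_rows: List[Dict[str, Any]],
        token_cache: Dict[int, Dict[str, Any]],
    ) -> None:
        self.triplet_rows = triplet_rows
        self.token_cache = token_cache

    def __len__(self) -> int:
        # One item per pre-sampled triplet row.
        return len(self.triplet_rows)

    def __getitem__(self, idx: int) -> Dict[str, Any]:
        row = self.triplet_rows[idx]
        # The "*_idx" keys are assumed names for this sketch.
        return {
            "anchor": self.token_cache[row["anchor_idx"]],
            "positive": self.token_cache[row["positive_idx"]],
            "negative": self.token_cache[row["negative_idx"]],
        }


# Hypothetical cache entries keyed by integer index.
token_cache = {
    0: {"input_ids": [101, 7, 102]},
    1: {"input_ids": [101, 8, 102]},
    2: {"input_ids": [101, 9, 102]},
}
triplet_rows = [{"anchor_idx": 0, "positive_idx": 1, "negative_idx": 2}]

dataset = TripletMapDataset(triplet_rows, token_cache)
item = dataset[0]
```

Because every lookup is a plain dictionary access, the index space is fixed by triplet_rows and the dataset can be shuffled by a standard dataloader.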

Phase1MapDataset

Bases: Dataset

Map-style dataset with on-the-fly Phase 1 negative sampling and difficulty curriculum.

Instead of pre-computing all negatives for multiple epochs, this dataset:

1. Pre-computes (anchor, positive) pairs as the fixed index space
2. Samples negatives on-the-fly in __getitem__() with epoch-aware seeds
3. Applies difficulty curriculum: easy -> semi-hard -> hard across Phase 1
4. Provides oversampled candidates for Phase 2+ hard negative mining
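The easy -> semi-hard -> hard curriculum described above could look roughly like the sketch below. It assumes that each candidate negative carries a distance to the anchor (larger meaning easier), that Phase 1 is split into three equal progress bands, and that the helper names curriculum_stage and sample_negatives are hypothetical. The real dataset may define difficulty and thresholds differently.

```python
import random
from typing import List, Tuple


def curriculum_stage(epoch: int, phase1_end_epoch: int) -> str:
    """Map Phase 1 progress to a difficulty stage (equal thirds are assumed)."""
    progress = min(max(epoch / max(phase1_end_epoch, 1), 0.0), 1.0)
    if progress < 1 / 3:
        return "easy"
    if progress < 2 / 3:
        return "semi-hard"
    return "hard"


def sample_negatives(
    candidates: List[Tuple[int, float]],
    epoch: int,
    phase1_end_epoch: int,
    k: int,
    seed: int = 0,
) -> List[int]:
    """Pick up to k negatives from (index, distance_to_anchor) candidates.

    Assumption: a larger distance means an easier negative.
    """
    ranked = sorted(candidates, key=lambda c: c[1], reverse=True)  # easiest first
    third = max(len(ranked) // 3, 1)
    stage = curriculum_stage(epoch, phase1_end_epoch)
    if stage == "easy":
        pool = ranked[:third]
    elif stage == "semi-hard":
        pool = ranked[third:2 * third] or ranked
    else:
        pool = ranked[2 * third:] or ranked
    # Epoch-aware seed: the same epoch reproduces the same draw.
    rng = random.Random(seed * 1_000_003 + epoch)
    return [idx for idx, _ in rng.sample(pool, min(k, len(pool)))]


# Nine hypothetical candidates whose distance equals their index.
cands = [(i, float(i)) for i in range(9)]
early = sample_negatives(cands, epoch=0, phase1_end_epoch=9, k=2)
late = sample_negatives(cands, epoch=8, phase1_end_epoch=9, k=2)
```

In this sketch early epochs draw only from the most distant third of the candidates and late epochs only from the closest third, which is one simple way to realize a difficulty schedule.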

Attributes:

cfg: Streaming configuration

sampling_cfg: Sampling strategy configuration

token_cache: Pre-computed tokenized embeddings

phase1_end_epoch: Epoch at which Phase 1 ends

epoch: Current training epoch (updated via set_epoch)

__getitem__(idx)

Get a single triplet item by index with on-the-fly negative sampling.

Returns:

Optional[Dict[str, Any]]: Dictionary with anchor, positive, selected negatives (Phase 1), and all candidates (for Phase 2+ HNM). Returns None if embeddings are missing.

__init__(cfg, sampling_cfg, token_cache, phase1_end_epoch)

Initialize the Phase 1 map dataset.

Parameters:

cfg (StreamingConfig, required): Streaming configuration with sampling parameters.

sampling_cfg (SamplingConfig, required): Sampling strategy configuration.

token_cache (Dict[int, Dict[str, Any]], required): Dictionary mapping index to tokenized embeddings.

phase1_end_epoch (int, required): Epoch at which Phase 1 ends (for curriculum progress).

set_epoch(epoch)

Update the current epoch so that subsequent __getitem__ calls sample negatives with a different, epoch-aware seed.
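A minimal sketch of how an epoch setter can drive epoch-aware seeding is shown below. The class EpochSeededSampler, its base_seed argument, and the way the seed, epoch, and item index are combined are all assumptions for illustration, and the actual seeding scheme in Phase1MapDataset may differ.

```python
import random
from typing import List


class EpochSeededSampler:
    """Hypothetical helper: per-item negatives that change with the epoch."""

    def __init__(self, candidate_ids: List[int], base_seed: int = 42) -> None:
        self.candidate_ids = candidate_ids
        self.base_seed = base_seed
        self.epoch = 0

    def set_epoch(self, epoch: int) -> None:
        # Called once per epoch, for example from a training hook.
        self.epoch = epoch

    def negatives_for(self, idx: int, k: int) -> List[int]:
        # Fold seed, epoch, and item index into one integer seed so the draw
        # is reproducible within an epoch and varies between epochs.
        seed = (self.base_seed * 1_000_003 + self.epoch) * 1_000_003 + idx
        rng = random.Random(seed)
        return rng.sample(self.candidate_ids, k)


sampler = EpochSeededSampler(list(range(1000)))
sampler.set_epoch(1)
first = sampler.negatives_for(0, 3)
repeat = sampler.negatives_for(0, 3)  # same epoch and index: identical draw
```

Seeding from the item index rather than from shared mutable state keeps the result independent of which worker process serves the item.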

collate_fn(batch)

Collate function to batch triplets for training. Each batch item represents a single positive.
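The sketch below illustrates one plausible shape for such a collate step. It assumes that items may be None (as Phase1MapDataset.__getitem__ can return) and should be dropped, and that each item has anchor, positive, and negatives keys. These key names and the function name collate_triplets are hypothetical, and plain lists stand in for the tensors a real implementation would presumably stack.

```python
from typing import Any, Dict, List, Optional


def collate_triplets(
    batch: List[Optional[Dict[str, Any]]],
) -> Dict[str, List[Any]]:
    """Drop missing items, then group the remaining fields by key."""
    items = [item for item in batch if item is not None]
    collated: Dict[str, List[Any]] = {"anchor": [], "positive": [], "negatives": []}
    for item in items:
        for key in collated:
            collated[key].append(item[key])
    return collated


# Three hypothetical items, one of which is missing its embeddings.
batch = [
    {"anchor": [1, 2], "positive": [3, 4], "negatives": [[5, 6], [7, 8]]},
    None,
    {"anchor": [9, 10], "positive": [11, 12], "negatives": [[13, 14]]},
]
out = collate_triplets(batch)
```

Each surviving item contributes exactly one anchor and one positive, so the length of out["positive"] equals the number of positives in the batch.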