DataModule API
NAICSDataModule
Bases: LightningDataModule
DataModule for NAICS embedding training with pre-sampled or on-the-fly triplets.
on_train_epoch_start()
Update the dataset's epoch so that on-the-fly sampling and its difficulty curriculum track training progress.
prepare_data()
Build all caches before worker processes are spawned.
setup(stage=None)
Load caches and create datasets.
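The split between prepare_data() and setup() can be sketched without Lightning itself: one step writes a cache to disk, the other loads it. This is only an illustration. The class name, the JSON file name, and the fake token values are assumptions, not the real NAICSDataModule internals.

```python
import json
import tempfile
from pathlib import Path
from typing import Any, Dict, Optional


class CacheLifecycleSketch:
    """Illustrative stand-in for the prepare_data()/setup() split (hypothetical class)."""

    def __init__(self, cache_dir: Path) -> None:
        self.cache_path = cache_dir / "token_cache.json"  # hypothetical file name
        self.token_cache: Optional[Dict[int, Dict[str, Any]]] = None

    def prepare_data(self) -> None:
        # Build the cache once and write it to disk; no state is kept on self here.
        if self.cache_path.exists():
            return
        fake_tokens = {0: {"input_ids": [101, 7, 102]}, 1: {"input_ids": [101, 9, 102]}}
        self.cache_path.write_text(json.dumps(fake_tokens))

    def setup(self, stage: Optional[str] = None) -> None:
        # Load the cache from disk; JSON keys are strings, so convert them back to int.
        raw = json.loads(self.cache_path.read_text())
        self.token_cache = {int(key): value for key, value in raw.items()}


with tempfile.TemporaryDirectory() as tmp:
    dm = CacheLifecycleSketch(Path(tmp))
    dm.prepare_data()
    dm.setup(stage="fit")
    loaded_cache = dm.token_cache
```

Writing to disk in prepare_data() and reading in setup() keeps the expensive build step out of the per-process path, which matches the "before worker processes are spawned" note above.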
train_dataloader()
Create training dataloader with shuffling enabled.
val_dataloader()
Create validation dataloader.
NAICSMapDataset
Bases: Dataset
Map-style dataset for pre-sampled triplets with tokenized embeddings.
__getitem__(idx)
Get a single triplet item by index.
__init__(triplet_rows, token_cache)
Initialize the map-style dataset.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| triplet_rows | List[Dict[str, Any]] | List of triplet dictionaries with anchor/positive/negative info | required |
| token_cache | Dict[int, Dict[str, Any]] | Dictionary mapping index to tokenized embeddings | required |
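As a rough illustration of how the two constructor arguments fit together, the sketch below resolves each triplet row against the token cache. The class name and the row keys (anchor_idx, positive_idx, negative_idx) are assumptions made for the example; the actual field names used by NAICSMapDataset are not documented here.

```python
from typing import Any, Dict, List


class MapDatasetSketch:
    """Illustrative map-style dataset; the row field names are assumptions."""

    def __init__(
        self,
        triplet_rows: List[Dict[str, Any]],
        token_cache: Dict[int, Dict[str, Any]],
    ) -> None:
        self.triplet_rows = triplet_rows
        self.token_cache = token_cache

    def __len__(self) -> int:
        return len(self.triplet_rows)

    def __getitem__(self, idx: int) -> Dict[str, Any]:
        row = self.triplet_rows[idx]
        # Resolve each index stored in the row to its cached tokenization.
        return {
            "anchor": self.token_cache[row["anchor_idx"]],
            "positive": self.token_cache[row["positive_idx"]],
            "negative": self.token_cache[row["negative_idx"]],
        }


token_cache = {i: {"input_ids": [i, i + 1]} for i in range(3)}
rows = [{"anchor_idx": 0, "positive_idx": 1, "negative_idx": 2}]
dataset = MapDatasetSketch(rows, token_cache)
item = dataset[0]
```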
Phase1MapDataset
Bases: Dataset
Map-style dataset with on-the-fly Phase 1 negative sampling and difficulty curriculum.
Instead of pre-computing all negatives for multiple epochs, this dataset:

1. Pre-computes (anchor, positive) pairs as the fixed index space
2. Samples negatives on-the-fly in __getitem__() with epoch-aware seeds
3. Applies difficulty curriculum: easy -> semi-hard -> hard across Phase 1
4. Provides oversampled candidates for Phase 2+ hard negative mining
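The combination of epoch-aware seeds and a difficulty curriculum might look roughly like the sketch below. Everything specific in it is an assumption for illustration: the function names, the equal-thirds stage thresholds, the per-candidate difficulty score in [0, 1], and the seed formula. The real dataset may define difficulty and seeding quite differently.

```python
import random
from typing import List, Tuple


def curriculum_stage(epoch: int, phase1_end_epoch: int) -> str:
    """Map Phase 1 progress to a difficulty stage (thresholds are assumptions)."""
    progress = min(max(epoch, 0) / max(phase1_end_epoch, 1), 1.0)
    if progress < 1 / 3:
        return "easy"
    if progress < 2 / 3:
        return "semi-hard"
    return "hard"


def sample_negatives(
    idx: int,
    epoch: int,
    candidates: List[Tuple[int, float]],  # (code index, hypothetical difficulty score)
    k: int,
    phase1_end_epoch: int,
    base_seed: int = 0,
) -> List[int]:
    """Pick up to k negatives deterministically for a given (idx, epoch) pair."""
    stage = curriculum_stage(epoch, phase1_end_epoch)
    bands = {
        "easy": (0.0, 1 / 3),
        "semi-hard": (1 / 3, 2 / 3),
        "hard": (2 / 3, float("inf")),
    }
    low, high = bands[stage]
    pool = [code for code, score in candidates if low <= score < high]
    if not pool:
        # Fall back to every candidate when the current band is empty.
        pool = [code for code, _ in candidates]
    # Epoch-aware seed: the same (idx, epoch) repeats, a new epoch draws afresh.
    rng = random.Random(base_seed * 1_000_003 + epoch * 10_007 + idx)
    return rng.sample(pool, min(k, len(pool)))
```

Seeding a local random.Random per call, rather than the global generator, keeps sampling reproducible regardless of the order in which indices are requested.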
Attributes:
| Name | Type | Description |
|---|---|---|
| cfg | | Streaming configuration |
| sampling_cfg | | Sampling strategy configuration |
| token_cache | | Pre-computed tokenized embeddings |
| phase1_end_epoch | | Epoch at which Phase 1 ends |
| epoch | | Current training epoch (updated via set_epoch) |
__getitem__(idx)
Get a single triplet item by index with on-the-fly negative sampling.
Returns:
| Type | Description |
|---|---|
| Optional[Dict[str, Any]] | Dictionary with anchor, positive, selected negatives (Phase 1), and all candidates (for Phase 2+ HNM). Returns None if embeddings are missing. |
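The Optional return type means callers have to handle None. A minimal sketch of that contract follows, assuming hypothetical dictionary keys (anchor, positive, negatives, candidates) and a hypothetical helper name; the real item layout is not specified in this reference.

```python
from typing import Any, Dict, List, Optional


def build_item(
    anchor_idx: int,
    positive_idx: int,
    selected_negatives: List[int],
    candidate_negatives: List[int],
    token_cache: Dict[int, Dict[str, Any]],
) -> Optional[Dict[str, Any]]:
    """Assemble one item, or return None when a required embedding is missing."""
    required = [anchor_idx, positive_idx, *selected_negatives]
    if any(i not in token_cache for i in required):
        return None
    return {
        "anchor": token_cache[anchor_idx],
        "positive": token_cache[positive_idx],
        "negatives": [token_cache[i] for i in selected_negatives],
        # Assumption: candidates missing from the cache are skipped, not fatal.
        "candidates": [token_cache[i] for i in candidate_negatives if i in token_cache],
    }
```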
__init__(cfg, sampling_cfg, token_cache, phase1_end_epoch)
Initialize the Phase 1 map dataset.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| cfg | StreamingConfig | Streaming configuration with sampling parameters | required |
| sampling_cfg | SamplingConfig | Sampling strategy configuration | required |
| token_cache | Dict[int, Dict[str, Any]] | Dictionary mapping index to tokenized embeddings | required |
| phase1_end_epoch | int | Epoch at which Phase 1 ends (for curriculum progress) | required |
set_epoch(epoch)
Update the current epoch so that negative sampling uses a different, epoch-specific seed.
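One plausible way set_epoch() gets called is from the data module's on_train_epoch_start() hook. The sketch below shows only that forwarding step with plain Python classes. The class names, the train_dataset attribute, and the current_epoch attribute are assumptions; in Lightning the epoch would normally be read from the trainer.

```python
class EpochAwareDatasetSketch:
    """Illustrative dataset that only tracks the epoch (hypothetical class)."""

    def __init__(self) -> None:
        self.epoch = 0

    def set_epoch(self, epoch: int) -> None:
        self.epoch = epoch


class DataModuleSketch:
    """Illustrative stand-in for the hook on the data module (hypothetical class)."""

    def __init__(self, dataset: EpochAwareDatasetSketch) -> None:
        self.train_dataset = dataset  # hypothetical attribute name
        self.current_epoch = 0  # assumption: supplied by the trainer in practice

    def on_train_epoch_start(self) -> None:
        # Forward the epoch so the dataset can derive its epoch-aware seed.
        self.train_dataset.set_epoch(self.current_epoch)
```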
collate_fn(batch)
Collate function to batch triplets for training. Each batch item represents a single positive.
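Because __getitem__ can return None, a collate step typically has to filter those items out before batching. The sketch below shows that idea using plain lists instead of tensors so it runs without PyTorch. The function name and the item keys are assumptions, and the real collate_fn presumably stacks or pads tensors rather than building lists.

```python
from typing import Any, Dict, List, Optional


def collate_sketch(batch: List[Optional[Dict[str, Any]]]) -> Dict[str, List[Any]]:
    """Drop None items and group fields by key (plain lists instead of tensors)."""
    items = [item for item in batch if item is not None]
    collated: Dict[str, List[Any]] = {"anchor": [], "positive": [], "negatives": []}
    for item in items:
        for key in collated:
            collated[key].append(item[key])
    return collated
```

Each position in the resulting lists corresponds to one (anchor, positive) pair, in line with the note that each batch item represents a single positive.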