# HGCN API
Main NAICS contrastive learning model combining:

- MultiChannelEncoder with LoRA fine-tuning and MoE
- Hyperbolic embeddings using the Lorentz model
- Curriculum learning with structure-aware negative sampling
- Multi-level supervision and false negative detection
The model is decomposed into functional mixins for maintainability:

- DistributedMixin: global batch sampling utilities
- LossMixin: loss computation methods
- CurriculumMixin: curriculum learning logic
- LoggingMixin: logging utilities
- ValidationMixin: validation step and evaluation
- OptimizerMixin: optimizer configuration
## GlobalNegativeContext

`dataclass`
Container for global negative embedding context in distributed training.
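The field layout lives in the source; as a rough, hypothetical sketch (these attribute names are illustrative, not the actual fields), such a container might look like:

```python
from dataclasses import dataclass
from typing import List

import torch


@dataclass
class GlobalNegativeContext:
    """Hypothetical sketch of the container; actual fields may differ."""

    embeddings: torch.Tensor   # (N_global, D) negatives gathered across ranks
    codes: List[List[str]]     # NAICS codes aligned with the embedding rows
    world_size: int = 1        # number of ranks that contributed
```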
## NAICSContrastiveModel
Bases: DistributedMixin, LossMixin, CurriculumMixin, LoggingMixin, ValidationMixin, OptimizerMixin, LightningModule
NAICS Contrastive Learning Model for learning hierarchical NAICS code embeddings.
This model combines:

- MultiChannelEncoder: LoRA-tuned transformer with Mixture of Experts
- Hyperbolic embeddings: Lorentz model for hierarchical representation
- Curriculum learning: structure-aware dynamic curriculum (SADC)
- Multiple loss functions: contrastive, hierarchy preservation, LambdaRank
The implementation is decomposed into functional mixins:

- DistributedMixin: multi-GPU global batch sampling
- LossMixin: loss computation (hierarchy, ranking, regularization)
- CurriculumMixin: hard negative mining, router-guided sampling
- LoggingMixin: training and validation metric logging
- ValidationMixin: validation step and evaluation metrics
- OptimizerMixin: optimizer and scheduler configuration
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `base_model_name` | `str` | HuggingFace model name for the base encoder | `'sentence-transformers/all-mpnet-base-v2'` |
| `lora_r` | `int` | LoRA rank | `8` |
| `lora_alpha` | `int` | LoRA alpha scaling factor | `16` |
| `lora_dropout` | `float` | LoRA dropout rate | `0.1` |
| `num_experts` | `int` | Number of MoE experts | `4` |
| `top_k` | `int` | Number of experts to select per token | `2` |
| `moe_hidden_dim` | `int` | Hidden dimension of MoE layers | `1024` |
| `temperature` | `float` | Temperature for InfoNCE loss | `0.07` |
| `curvature` | `float` | Hyperbolic space curvature | `1.0` |
| `hierarchy_weight` | `float` | Weight for hierarchy preservation loss | `0.1` |
| `rank_order_weight` | `float` | Weight for LambdaRank loss | `0.15` |
| `radius_reg_weight` | `float` | Weight for radius regularization | `0.01` |
| `level_radius_weight` | `float` | Weight for level-aware radius prior | `0.05` |
| `learning_rate` | `float` | Base learning rate | `0.0002` |
| `weight_decay` | `float` | AdamW weight decay | `0.01` |
| `warmup_steps` | `int` | Number of warmup steps | `500` |
| `use_warmup_cosine` | `bool` | Use warmup + cosine decay scheduler | `False` |
| `load_balancing_coef` | `float` | MoE load balancing coefficient | `0.01` |
| `fn_curriculum_start_epoch` | `int` | Epoch to start the false negative curriculum | `10` |
| `fn_cluster_every_n_epochs` | `int` | Clustering frequency for pseudo-labels | `5` |
| `fn_num_clusters` | `int` | Number of clusters for pseudo-labeling | `500` |
| `distance_matrix_path` | `Optional[str]` | Path to ground truth distance matrix | `None` |
| `eval_every_n_epochs` | `int` | Evaluation frequency | `1` |
| `eval_sample_size` | `int` | Number of samples for evaluation | `500` |
| `tree_distance_alpha` | `float` | Tree distance scaling factor | `1.5` |
| `base_margin` | `float` | Base margin for the adaptive margin | `0.5` |
| `curriculum_phase1_end` | `float` | End of curriculum phase 1 (fraction of training) | `0.3` |
| `curriculum_phase2_end` | `float` | End of curriculum phase 2 (fraction of training) | `0.7` |
| `curriculum_phase3_end` | `float` | End of curriculum phase 3 (fraction of training) | `1.0` |
| `sibling_distance_threshold` | `float` | Threshold for sibling relationships | `2.0` |
| `curriculum_phase_mode` | `str` | Curriculum phase mode | `'three_phase'` |
| `curriculum_anneal` | `Optional[Dict[str, float]]` | Annealing configuration for the curriculum | `None` |
| `false_negative_config` | `Optional[Union[FalseNegativeConfig, Dict[str, Any]]]` | Configuration for false negative handling | `None` |
| `relations_parquet_path` | `Optional[str]` | Path to the NAICS relations parquet | `None` |
| `parent_eval_top_k` | `int` | Top-k for parent retrieval evaluation | `1` |
| `child_eval_top_k` | `int` | Top-k for child retrieval evaluation | `5` |
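As a minimal usage sketch, the model can be constructed with the defaults listed above; the import path here is assumed and may differ in this repo:

```python
# Assumed import path -- adjust to the actual module location.
from hgcn.model import NAICSContrastiveModel

model = NAICSContrastiveModel(
    base_model_name="sentence-transformers/all-mpnet-base-v2",
    lora_r=8,
    lora_alpha=16,
    num_experts=4,
    top_k=2,
    temperature=0.07,
    curvature=1.0,
    learning_rate=2e-4,
)
```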
### forward(channel_inputs)
Forward pass through the encoder.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `channel_inputs` | `Dict[str, Dict[str, Tensor]]` | Dictionary of channel inputs with tokenized text | *required* |

Returns:

| Type | Description |
|---|---|
| `Dict[str, Tensor]` | Dictionary of output tensors |
### training_step(batch, batch_idx)
Perform a single training step.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `batch` | `Dict[str, Any]` | Training batch with anchor, positive, negatives, and metadata | *required* |
| `batch_idx` | `int` | Batch index | *required* |

Returns:

| Type | Description |
|---|---|
| `Tensor` | Total loss for optimization |
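Since `training_step` is invoked by the PyTorch Lightning `Trainer` rather than called directly, a typical loop looks like the following sketch (the datamodule is hypothetical):

```python
import pytorch_lightning as pl

# The Trainer drives training_step/validation_step; a DataModule producing
# (anchor, positive, negatives, metadata) batches is assumed to exist.
trainer = pl.Trainer(max_epochs=50, accelerator="auto", devices="auto")
trainer.fit(model, datamodule=naics_datamodule)  # naics_datamodule is hypothetical
```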
### gather_embeddings_global(local_embeddings, world_size=None)
Gather embeddings from all GPUs using all_gather with gradient support.
Issue #19 (Global Batch Sampling): collects embeddings from all ranks to enable hard negative mining across the global batch.
The gather is built on torch.distributed.all_gather and is set up so that gradients flow back through the operation during backprop.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `local_embeddings` | `Tensor` | Local embeddings tensor `(N_local, D)` with `requires_grad=True` | *required* |
| `world_size` | `Optional[int]` | Number of GPUs (auto-detected if None) | `None` |

Returns:

| Type | Description |
|---|---|
| `Tensor` | Global embeddings tensor `(N_global, D)` where `N_global = N_local * world_size`. Gradients flow back through this operation during backprop. |
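A common pattern for a gradient-preserving gather (whether it matches this repo's exact implementation is an assumption) is to all_gather non-differentiable copies and then splice the local, grad-tracking tensor back into its own slot:

```python
import torch
import torch.distributed as dist


def gather_embeddings_global_sketch(local_embeddings: torch.Tensor) -> torch.Tensor:
    """Sketch of a gradient-preserving all_gather (assumed pattern, not the repo's code)."""
    world_size = dist.get_world_size()
    if world_size == 1:
        return local_embeddings

    gathered = [torch.empty_like(local_embeddings) for _ in range(world_size)]
    dist.all_gather(gathered, local_embeddings)  # gathered copies carry no grad

    # Re-insert the local tensor so gradients flow for this rank's slice.
    gathered[dist.get_rank()] = local_embeddings
    return torch.cat(gathered, dim=0)  # (N_local * world_size, D)
```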
### gather_negative_codes_global(local_negative_codes, world_size=None)
Gather negative codes from all GPUs for false negative masking.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `local_negative_codes` | `List[List[str]]` | Local negative codes `(batch_size, k_negatives)` | *required* |
| `world_size` | `Optional[int]` | Number of GPUs (auto-detected if None) | `None` |

Returns:

| Type | Description |
|---|---|
| `List[List[str]]` | Global negative codes list |
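Because Python string lists cannot go through tensor all_gather, one straightforward implementation sketch uses torch.distributed.all_gather_object (an assumed approach, not necessarily the repo's exact code):

```python
from typing import List

import torch.distributed as dist


def gather_negative_codes_global_sketch(
    local_negative_codes: List[List[str]],
) -> List[List[str]]:
    """Sketch: gather per-rank negative-code lists with all_gather_object."""
    world_size = dist.get_world_size()
    if world_size == 1:
        return local_negative_codes

    gathered = [None] * world_size
    dist.all_gather_object(gathered, local_negative_codes)

    # Flatten rank-major so ordering matches the globally gathered embeddings.
    return [codes for rank_codes in gathered for codes in rank_codes]
```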