HGCN API

Main NAICS Contrastive Learning Model combining:

- MultiChannelEncoder with LoRA fine-tuning and MoE
- Hyperbolic embeddings using the Lorentz model
- Curriculum learning with structure-aware negative sampling
- Multi-level supervision and false negative detection

The model is decomposed into functional mixins for maintainability, as sketched below:

- DistributedMixin: Global batch sampling utilities
- LossMixin: Loss computation methods
- CurriculumMixin: Curriculum learning logic
- LoggingMixin: Logging utilities
- ValidationMixin: Validation step and evaluation
- OptimizerMixin: Optimizer configuration
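
The class layout implied by this decomposition looks roughly like the following sketch. The mixin and class names come from this page, but the bodies are stubbed out, and the lightning.pytorch import name is an assumption (the pytorch_lightning alias works equally well).

```python
# Sketch of the mixin composition described above; bodies are stubbed out
# purely to show the class layout and method resolution order.
import lightning.pytorch as pl  # or: import pytorch_lightning as pl

class DistributedMixin: ...   # global batch sampling utilities
class LossMixin: ...          # loss computation methods
class CurriculumMixin: ...    # curriculum learning logic
class LoggingMixin: ...       # logging utilities
class ValidationMixin: ...    # validation step and evaluation
class OptimizerMixin: ...     # optimizer configuration

class NAICSContrastiveModel(
    DistributedMixin, LossMixin, CurriculumMixin,
    LoggingMixin, ValidationMixin, OptimizerMixin,
    pl.LightningModule,  # listed last so the mixins take precedence in the MRO
):
    ...
```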

GlobalNegativeContext dataclass

Container for global negative embedding context in distributed training.

NAICSContrastiveModel

Bases: DistributedMixin, LossMixin, CurriculumMixin, LoggingMixin, ValidationMixin, OptimizerMixin, LightningModule

NAICS Contrastive Learning Model for learning hierarchical NAICS code embeddings.

This model combines:

- MultiChannelEncoder: LoRA-tuned transformer with Mixture of Experts
- Hyperbolic embeddings: Lorentz model for hierarchical representation
- Curriculum learning: Structure-aware dynamic curriculum (SADC)
- Multiple loss functions: Contrastive, hierarchy preservation, LambdaRank

The implementation is decomposed into functional mixins:

- DistributedMixin: Multi-GPU global batch sampling
- LossMixin: Loss computation (hierarchy, ranking, regularization)
- CurriculumMixin: Hard negative mining, router-guided sampling
- LoggingMixin: Training and validation metric logging
- ValidationMixin: Validation step and evaluation metrics
- OptimizerMixin: Optimizer and scheduler configuration

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| base_model_name | str | HuggingFace model name for the base encoder | 'sentence-transformers/all-mpnet-base-v2' |
| lora_r | int | LoRA rank | 8 |
| lora_alpha | int | LoRA alpha scaling factor | 16 |
| lora_dropout | float | LoRA dropout rate | 0.1 |
| num_experts | int | Number of MoE experts | 4 |
| top_k | int | Number of experts to select per token | 2 |
| moe_hidden_dim | int | Hidden dimension of MoE layers | 1024 |
| temperature | float | Temperature for InfoNCE loss | 0.07 |
| curvature | float | Hyperbolic space curvature | 1.0 |
| hierarchy_weight | float | Weight for hierarchy preservation loss | 0.1 |
| rank_order_weight | float | Weight for LambdaRank loss | 0.15 |
| radius_reg_weight | float | Weight for radius regularization | 0.01 |
| level_radius_weight | float | Weight for level-aware radius prior | 0.05 |
| learning_rate | float | Base learning rate | 0.0002 |
| weight_decay | float | AdamW weight decay | 0.01 |
| warmup_steps | int | Number of warmup steps | 500 |
| use_warmup_cosine | bool | Use warmup + cosine decay scheduler | False |
| load_balancing_coef | float | MoE load balancing coefficient | 0.01 |
| fn_curriculum_start_epoch | int | Epoch to start false negative curriculum | 10 |
| fn_cluster_every_n_epochs | int | Clustering frequency for pseudo-labels | 5 |
| fn_num_clusters | int | Number of clusters for pseudo-labeling | 500 |
| distance_matrix_path | Optional[str] | Path to ground truth distance matrix | None |
| eval_every_n_epochs | int | Evaluation frequency | 1 |
| eval_sample_size | int | Number of samples for evaluation | 500 |
| tree_distance_alpha | float | Tree distance scaling factor | 1.5 |
| base_margin | float | Base margin for adaptive margin | 0.5 |
| curriculum_phase1_end | float | End of curriculum phase 1 (fraction) | 0.3 |
| curriculum_phase2_end | float | End of curriculum phase 2 (fraction) | 0.7 |
| curriculum_phase3_end | float | End of curriculum phase 3 (fraction) | 1.0 |
| sibling_distance_threshold | float | Threshold for sibling relationships | 2.0 |
| curriculum_phase_mode | str | Curriculum phase mode | 'three_phase' |
| curriculum_anneal | Optional[Dict[str, float]] | Annealing configuration for curriculum | None |
| false_negative_config | Optional[Union[FalseNegativeConfig, Dict[str, Any]]] | Configuration for false negative handling | None |
| relations_parquet_path | Optional[str] | Path to NAICS relations parquet | None |
| parent_eval_top_k | int | Top-k for parent retrieval evaluation | 1 |
| child_eval_top_k | int | Top-k for child retrieval evaluation | 5 |

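A construction sketch follows. The import path hgcn.model is hypothetical (adjust it to wherever NAICSContrastiveModel lives in this codebase); only a handful of keyword arguments are overridden, and everything omitted falls back to the defaults in the table above.

```python
# Hypothetical import path; adjust to the actual module location.
from hgcn.model import NAICSContrastiveModel

model = NAICSContrastiveModel(
    base_model_name="sentence-transformers/all-mpnet-base-v2",
    lora_r=8,
    lora_alpha=16,
    num_experts=4,
    top_k=2,
    temperature=0.07,
    curvature=1.0,
    learning_rate=2e-4,
    use_warmup_cosine=True,               # warmup + cosine decay scheduler
    curriculum_phase_mode="three_phase",
)
```
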
forward(channel_inputs)

Forward pass through the encoder.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| channel_inputs | Dict[str, Dict[str, Tensor]] | Dictionary of channel inputs with tokenized text | required |

Returns:

Dict[str, Tensor]: Dictionary containing:

- embedding: Hyperbolic embeddings (batch_size, embed_dim + 1)
- gate_probs: MoE gate probabilities (batch_size, num_experts)
- top_k_indices: Selected expert indices (batch_size, top_k)

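A usage sketch for forward, reusing the model instance constructed above. The channel names "title" and "description" are illustrative assumptions; any tokenizer output providing input_ids and attention_mask per channel fits the expected Dict[str, Dict[str, Tensor]] shape.

```python
import torch
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("sentence-transformers/all-mpnet-base-v2")

def tokenize(texts: list[str]) -> dict[str, torch.Tensor]:
    return dict(tokenizer(texts, padding=True, truncation=True, return_tensors="pt"))

# Channel names are illustrative; the model expects one tokenized dict per channel.
channel_inputs = {
    "title": tokenize(["Soybean Farming"]),
    "description": tokenize(["Establishments primarily engaged in growing soybeans."]),
}

with torch.no_grad():
    outputs = model(channel_inputs)

print(outputs["embedding"].shape)      # (batch_size, embed_dim + 1), Lorentz coordinates
print(outputs["gate_probs"].shape)     # (batch_size, num_experts)
print(outputs["top_k_indices"].shape)  # (batch_size, top_k)
```
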
training_step(batch, batch_idx)

Perform a single training step.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| batch | Dict[str, Any] | Training batch with anchor, positive, negatives, and metadata | required |
| batch_idx | int | Batch index | required |

Returns:

Tensor: Total loss for optimization

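training_step is normally invoked by the Lightning Trainer rather than called directly. A minimal multi-GPU training sketch follows, where train_loader is a hypothetical DataLoader yielding batches with anchor, positive, negatives, and metadata keys, and model is the instance constructed earlier.

```python
import lightning.pytorch as pl  # or: import pytorch_lightning as pl

trainer = pl.Trainer(
    max_epochs=50,
    accelerator="gpu",
    devices=4,
    strategy="ddp",          # multiple ranks activate the global batch sampling path
    precision="16-mixed",
    gradient_clip_val=1.0,
)

# train_loader is a hypothetical DataLoader; each batch is the Dict[str, Any]
# described above (anchor, positive, negatives, plus NAICS-code metadata).
trainer.fit(model, train_dataloaders=train_loader)
```
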
gather_embeddings_global(local_embeddings, world_size=None)

Gather embeddings from all GPUs using all_gather with gradient support.

Issue #19 (Global Batch Sampling): collect embeddings from all ranks so that hard negatives can be mined across the global batch.

The gather is performed with an autograd-aware all_gather, so gradients flow back through the gather operation to each rank's local embeddings during backprop.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| local_embeddings | Tensor | Local embeddings tensor (N_local, D) with requires_grad=True | required |
| world_size | Optional[int] | Number of GPUs (auto-detected if None) | None |

Returns:

Tensor: Global embeddings tensor (N_global, D), where N_global = N_local * world_size. Gradients flow back through this operation during backprop.

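A minimal sketch of this kind of gather, assuming the process group is initialized. It uses torch.distributed.nn.all_gather, the autograd-aware collective, which is one standard way to obtain a gradient-preserving gather; the actual implementation in this codebase may differ in detail.

```python
from typing import Optional

import torch
import torch.distributed as dist
from torch.distributed.nn import all_gather as all_gather_with_grad

def gather_embeddings_global_sketch(
    local_embeddings: torch.Tensor,
    world_size: Optional[int] = None,
) -> torch.Tensor:
    """Concatenate per-rank embeddings into one global tensor with gradient support."""
    if not dist.is_available() or not dist.is_initialized():
        return local_embeddings                      # single-process fallback
    if world_size is None:
        world_size = dist.get_world_size()           # auto-detect the number of ranks
    if world_size <= 1:
        return local_embeddings
    # The autograd-aware all_gather returns a list of (N_local, D) tensors and
    # lets gradients flow back to every rank's local_embeddings during backprop.
    gathered = all_gather_with_grad(local_embeddings)
    return torch.cat(gathered, dim=0)                # (N_local * world_size, D)
```
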
gather_negative_codes_global(local_negative_codes, world_size=None)

Gather negative codes from all GPUs for false negative masking.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| local_negative_codes | List[List[str]] | Local negative codes (batch_size, k_negatives) | required |
| world_size | Optional[int] | Number of GPUs (auto-detected if None) | None |

Returns:

List[List[str]]: Global negative codes list
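
A sketch of the string-side counterpart, again assuming an initialized process group. torch.distributed.all_gather_object is the standard collective for Python objects such as nested code lists; no gradients are involved, so a plain object gather suffices.

```python
from typing import List, Optional

import torch.distributed as dist

def gather_negative_codes_global_sketch(
    local_negative_codes: List[List[str]],
    world_size: Optional[int] = None,
) -> List[List[str]]:
    """Collect every rank's (batch_size, k_negatives) code lists into one global list."""
    if not dist.is_available() or not dist.is_initialized():
        return local_negative_codes                  # single-process fallback
    if world_size is None:
        world_size = dist.get_world_size()
    if world_size <= 1:
        return local_negative_codes
    gathered: List[Optional[List[List[str]]]] = [None] * world_size
    dist.all_gather_object(gathered, local_negative_codes)
    # Flatten the per-rank lists into one global list aligned with the gathered embeddings.
    return [codes for rank_codes in gathered for codes in rank_codes]
```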