Hierarchy Preservation Loss
Bases: Module
Loss component that encourages pairwise embedding distances to match distances in the code tree. It optimizes hierarchy preservation directly by penalizing deviations from the ground-truth tree structure.
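The description does not say how tree distances between codes are obtained. The following is a minimal sketch, assuming the NAICS hierarchy is treated as a digit-prefix tree under a virtual root (2-digit sector at depth 1 down to 6-digit national industry at depth 5) and that tree distance is the number of edges on the path between two codes. The helper names are hypothetical, and NAICS sectors that span ranges (such as 31-33) are not handled.

```python
def naics_depth(code: str) -> int:
    """Depth below a virtual root: 2-digit sector -> 1, ..., 6-digit code -> 5."""
    return len(code) - 1


def common_prefix_len(a: str, b: str) -> int:
    """Number of leading digits two codes share."""
    n = 0
    for x, y in zip(a, b):
        if x != y:
            break
        n += 1
    return n


def naics_tree_distance(a: str, b: str) -> int:
    """Hypothetical helper: edge count between two codes in the prefix tree."""
    shared = common_prefix_len(a, b)
    # With fewer than 2 shared digits the only common ancestor is the virtual root.
    lca_depth = shared - 1 if shared >= 2 else 0
    return naics_depth(a) + naics_depth(b) - 2 * lca_depth
```

For example, `naics_tree_distance("541511", "541512")` is 2 (siblings under `54151`), and `naics_tree_distance("54", "541511")` is 4.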
forward(embeddings, codes, lorentz_distance_fn)
Compute the hierarchy preservation loss.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `embeddings` | `Tensor` | Hyperbolic embeddings of shape (N, D+1) | required |
| `codes` | `List[str]` | List of NAICS codes corresponding to `embeddings` | required |
| `lorentz_distance_fn` | `Callable[[Tensor, Tensor], Tensor]` | Function to compute Lorentz distances | required |
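The `(N, D+1)` shape is consistent with the Lorentz (hyperboloid) model, where each D-dimensional hyperbolic point is stored as D+1 coordinates satisfying -x0*x0 + x1*x1 + ... + xD*xD = -1 with x0 > 0. The sketch below assumes curvature -1 and uses plain Python lists instead of tensors. It shows one function with the shape of `lorentz_distance_fn` plus a hypothetical helper that lifts Euclidean vectors onto the hyperboloid. The actual `lorentz_distance_fn` takes and returns `Tensor`s and may use a different curvature or clamping.

```python
import math
from typing import List, Sequence


def lift_to_hyperboloid(v: Sequence[float]) -> List[float]:
    """Hypothetical helper: map a D-dim vector onto the unit hyperboloid (D+1 coords)."""
    x0 = math.sqrt(1.0 + sum(c * c for c in v))
    return [x0, *v]


def minkowski_inner(x: Sequence[float], y: Sequence[float]) -> float:
    """Lorentzian inner product <x, y>_L = -x0*y0 + sum_i xi*yi."""
    return -x[0] * y[0] + sum(a * b for a, b in zip(x[1:], y[1:]))


def lorentz_distance(x: Sequence[float], y: Sequence[float]) -> float:
    """Geodesic distance on the curvature -1 hyperboloid: arcosh(-<x, y>_L)."""
    # Clamp at 1.0 because rounding can push -<x, y>_L slightly below 1.
    return math.acosh(max(1.0, -minkowski_inner(x, y)))
```

Lifting `[0.0, 0.0]` gives the hyperboloid origin `[1.0, 0.0, 0.0]`, and lifting `[math.sinh(1.0), 0.0]` gives a point whose `lorentz_distance` from the origin is approximately 1.0.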
Returns:
| Type | Description |
|---|---|
| `Tensor` | Scalar loss |
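This page does not specify how embedding distances and tree distances are compared. One plausible reading, sketched here purely as an assumption, is a mean squared error over all unordered pairs of rows. The `tree_distance_fn` parameter is hypothetical (it is not part of the documented `forward` signature) and stands in for however the module derives ground-truth tree distances from `codes`. Plain Python sequences replace tensors so the example runs without dependencies.

```python
from itertools import combinations
from typing import Callable, Sequence

Point = Sequence[float]


def hierarchy_preservation_loss(
    embeddings: Sequence[Point],
    codes: Sequence[str],
    lorentz_distance_fn: Callable[[Point, Point], float],
    tree_distance_fn: Callable[[str, str], float],  # hypothetical extra argument
) -> float:
    """Assumed form: mean squared gap between embedding and tree distances over all pairs."""
    if len(embeddings) != len(codes):
        raise ValueError("embeddings and codes must have the same length")
    pairs = list(combinations(range(len(codes)), 2))
    if not pairs:
        return 0.0  # fewer than two items: nothing to compare
    total = 0.0
    for i, j in pairs:
        emb_d = lorentz_distance_fn(embeddings[i], embeddings[j])
        tree_d = tree_distance_fn(codes[i], codes[j])
        gap = emb_d - tree_d
        total += gap * gap
    return total / len(pairs)
```

Vectorizing over all pairs with tensor operations would avoid the Python loop. Whether the actual module normalizes distances, weights pairs, or samples them is not stated on this page.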