Hierarchy Preservation Loss

Bases: Module

Loss component that encourages pairwise distances between the hyperbolic embeddings to match the corresponding distances in the NAICS code tree. It optimizes hierarchy preservation directly by penalizing deviations from the ground-truth tree structure.
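A minimal sketch of the idea in plain Python, assuming that tree distance is the number of edges between two codes in the NAICS hierarchy and that deviations are penalized with mean squared error over all pairs. Both choices, and the names `naics_tree_distance` and `hierarchy_preservation_loss`, are hypothetical illustrations; the class documented here may normalize, weight, or sample pairs differently, and it operates on tensors rather than lists.

```python
from itertools import combinations
from typing import Callable, List, Sequence

Vector = Sequence[float]


def naics_tree_distance(a: str, b: str) -> int:
    """Hypothetical helper: number of edges between two NAICS codes.

    Assumes a root node above the 2-digit sectors and one tree level per
    additional digit, so a code of length L sits at depth L - 1. Sector
    ranges such as 31-33 are treated as separate sectors.
    """
    lcp = 0  # length of the longest common prefix
    for ca, cb in zip(a, b):
        if ca != cb:
            break
        lcp += 1
    depth_a = len(a) - 1
    depth_b = len(b) - 1
    # A shared prefix shorter than two digits means the only common ancestor is the root.
    depth_lca = max(lcp - 1, 0)
    return depth_a + depth_b - 2 * depth_lca


def hierarchy_preservation_loss(
    embeddings: List[Vector],
    codes: List[str],
    distance_fn: Callable[[Vector, Vector], float],
) -> float:
    """Hypothetical loss: mean squared gap between embedding and tree distances."""
    if len(embeddings) != len(codes):
        raise ValueError("embeddings and codes must have the same length")
    pairs = list(combinations(range(len(codes)), 2))
    if not pairs:
        return 0.0
    total = 0.0
    for i, j in pairs:
        gap = distance_fn(embeddings[i], embeddings[j]) - naics_tree_distance(codes[i], codes[j])
        total += gap * gap
    return total / len(pairs)


# Toy 1-D "embeddings" with absolute difference standing in for a hyperbolic distance.
loss = hierarchy_preservation_loss(
    [[0.0], [1.0], [2.0]], ["31", "311", "32"], lambda x, y: abs(x[0] - y[0])
)
# Tree distances are 1, 2, 3; embedding distances are 1, 2, 1, so loss = (0 + 0 + 4) / 3.
```

Squared error pulls every pairwise distance toward its exact tree distance; a ranking or margin formulation would instead only enforce that closer codes stay closer, which is a looser constraint.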

forward(embeddings, codes, lorentz_distance_fn)

Compute the hierarchy preservation loss.

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| embeddings | Tensor | Hyperbolic embeddings of shape (N, D+1) | required |
| codes | List[str] | NAICS codes corresponding to the embeddings, one per row (length N) | required |
| lorentz_distance_fn | Callable[[Tensor, Tensor], Tensor] | Function that computes Lorentz distances between embeddings | required |
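The `Callable[[Tensor, Tensor], Tensor]` signature leaves the distance function to the caller. Below is a sketch of one common choice, assuming the hyperboloid (Lorentz) model with curvature -1, where d(x, y) = arccosh(-<x, y>_L) and <x, y>_L = -x0*y0 + sum_i xi*yi. In this model a D-dimensional hyperbolic space sits inside R^(D+1), which would account for the (N, D+1) embedding shape. The function names are hypothetical, the code handles single points as plain lists rather than batched tensors, and the project's own function may use a different curvature or numerical safeguards.

```python
import math
from typing import List, Sequence


def lorentz_inner(x: Sequence[float], y: Sequence[float]) -> float:
    """Minkowski inner product <x, y>_L = -x0*y0 + sum_i xi*yi."""
    return -x[0] * y[0] + sum(a * b for a, b in zip(x[1:], y[1:]))


def lift_to_hyperboloid(v: Sequence[float]) -> List[float]:
    """Map a point of R^D onto the hyperboloid <x, x>_L = -1 in R^(D+1)."""
    return [math.sqrt(1.0 + sum(c * c for c in v)), *v]


def lorentz_distance(x: Sequence[float], y: Sequence[float]) -> float:
    """Geodesic distance on the curvature -1 hyperboloid: arccosh(-<x, y>_L)."""
    # Rounding can push -<x, y>_L slightly below 1, where acosh is undefined.
    return math.acosh(max(-lorentz_inner(x, y), 1.0))


origin = lift_to_hyperboloid([0.0, 0.0])       # [1.0, 0.0, 0.0]
p = lift_to_hyperboloid([math.sinh(1.5), 0.0])  # first coordinate is cosh(1.5)
d = lorentz_distance(origin, p)                 # approximately 1.5
```

Lifting a vector whose norm is sinh(r) places it at hyperbolic distance r from the origin, which makes the example easy to check by hand.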

Returns:

| Type | Description |
|---|---|
| Tensor | Scalar loss value |