Hyperbolic API

HyperbolicProjection

Bases: Module

Projects Euclidean embeddings to the Lorentz model of hyperbolic space.

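The Lorentz model represents points as (x₀, x₁, ..., xₙ), where:

- x₀ is the time coordinate, which grows monotonically with the hyperbolic radius (distance from the origin)
- x₁, ..., xₙ are the spatial coordinates
- every point satisfies the constraint -x₀² + x₁² + ... + xₙ² = -1/c, i.e. its Lorentz inner product with itself equals -1/c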
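Because the constraint fixes x₀ once the spatial coordinates are known, any spatial vector can be placed on the hyperboloid by solving x₀ = √(1/c + x₁² + ... + xₙ²). The sketch below illustrates this in PyTorch; `lift_to_hyperboloid` is a hypothetical helper for illustration, not a function of this API.

```python
import torch


def lift_to_hyperboloid(x_spatial: torch.Tensor, c: float = 1.0) -> torch.Tensor:
    """Hypothetical helper: prepend x0 so that -x0^2 + ||x_spatial||^2 = -1/c."""
    # Solve the constraint for the time coordinate: x0 = sqrt(1/c + ||x_spatial||^2)
    x0 = torch.sqrt(1.0 / c + (x_spatial ** 2).sum(dim=-1, keepdim=True))
    return torch.cat([x0, x_spatial], dim=-1)


x = lift_to_hyperboloid(torch.tensor([[0.0, 0.0], [3.0, 4.0]]), c=1.0)
# Rows are (1, 0, 0) and (sqrt(26), 3, 4); both should have Lorentz norm -1
lorentz_norms = -x[:, 0] ** 2 + (x[:, 1:] ** 2).sum(dim=-1)
```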

exp_map_zero(v)

Exponential map from tangent space at origin to Lorentz hyperboloid.

The output satisfies the Lorentz constraint: ‖x_spatial‖² - x₀² = -1/c

Uses compiled operations when torch.compile is enabled.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| v | Tensor | Tangent vector of shape (batch_size, input_dim + 1) | required |

Returns:

| Type | Description |
| --- | --- |
| Tensor | Point on Lorentz hyperboloid of shape (batch_size, input_dim + 1) |

forward(euclidean_embedding)

Project Euclidean embedding to Lorentz hyperboloid.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| euclidean_embedding | Tensor | Euclidean embedding of shape (batch_size, input_dim) | required |

Returns:

| Type | Description |
| --- | --- |
| Tensor | Hyperbolic embedding in Lorentz model of shape (batch_size, input_dim + 1) |

LorentzDistance

Bases: Module

Computes distances in the Lorentz model of hyperbolic space.

Distance between two points u, v on the hyperboloid: d(u, v) = √c * arccosh(-⟨u, v⟩_L)

where ⟨u, v⟩_L = u₁v₁ + ... + uₙvₙ - u₀v₀ (Lorentz inner product)
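A minimal PyTorch sketch of this formula for unit curvature (c = 1), where √c = 1 and the origin is (1, 0, ..., 0). The helper names are illustrative; the clamp guards against round-off pushing -⟨u, v⟩_L slightly below 1, where arccosh is undefined.

```python
import math

import torch


def lorentz_inner(u: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
    # <u, v>_L = u1*v1 + ... + un*vn - u0*v0, reduced over the last dimension
    return (u[..., 1:] * v[..., 1:]).sum(dim=-1) - u[..., 0] * v[..., 0]


def lorentz_distance_unit(u: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
    # c = 1: d(u, v) = arccosh(-<u, v>_L)
    return torch.acosh((-lorentz_inner(u, v)).clamp_min(1.0))


origin = torch.tensor([[1.0, 0.0, 0.0]])
t = 1.5
point = torch.tensor([[math.cosh(t), math.sinh(t), 0.0]])  # lies at geodesic distance t from the origin
d = lorentz_distance_unit(origin, point)  # ≈ 1.5
```

Training code often clamps to 1 + ε instead of 1, because the derivative of arccosh is unbounded at 1.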

batched_forward(u, v)

Batched Lorentz distance computation with broadcasting support.

Uses compiled operations when torch.compile is enabled.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| u | Tensor | Tensor of shape (batch_size, 1, embedding_dim+1) or (batch_size, embedding_dim+1) | required |
| v | Tensor | Tensor of shape (batch_size, k, embedding_dim+1) | required |

Returns:

| Type | Description |
| --- | --- |
| Tensor | Tensor of shape (batch_size, k) with distances |
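The broadcasting described above can be sketched as follows, again assuming c = 1 for brevity. A 2-D `u` is unsqueezed to (batch_size, 1, embedding_dim+1) so that each anchor is compared with all k rows of `v` in one vectorized call. The function and helper names are hypothetical.

```python
import torch


def batched_lorentz_distance(u: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
    """u: (B, 1, D+1) or (B, D+1); v: (B, k, D+1) -> distances (B, k), assuming c = 1."""
    if u.dim() == 2:
        u = u.unsqueeze(1)  # (B, D+1) -> (B, 1, D+1) so it broadcasts against k
    inner = (u[..., 1:] * v[..., 1:]).sum(dim=-1) - u[..., 0] * v[..., 0]  # (B, k)
    return torch.acosh((-inner).clamp_min(1.0))


def lift(x_s: torch.Tensor) -> torch.Tensor:
    # Place spatial coordinates on the c = 1 hyperboloid: x0 = sqrt(1 + ||x_s||^2)
    return torch.cat([torch.sqrt(1.0 + (x_s ** 2).sum(dim=-1, keepdim=True)), x_s], dim=-1)


torch.manual_seed(0)
anchors = lift(torch.randn(4, 3, dtype=torch.float64))        # (4, 4)
candidates = lift(torch.randn(4, 5, 3, dtype=torch.float64))  # (4, 5, 4)
d = batched_lorentz_distance(anchors, candidates)             # (4, 5)
```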

forward(u, v)

Compute Lorentzian distance between two points.

Uses compiled operations when torch.compile is enabled.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| u | Tensor | First point on hyperboloid, shape (batch_size, embedding_dim+1) | required |
| v | Tensor | Second point on hyperboloid, shape (batch_size, embedding_dim+1) | required |

Returns:

| Type | Description |
| --- | --- |
| Tensor | Distances, shape (batch_size,) |

lorentz_dot(u, v)

Compute the Lorentz inner product: ⟨u, v⟩_L = u₁v₁ + ... + uₙvₙ - u₀v₀

Uses compiled operations when torch.compile is enabled.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| u | Tensor | First point on hyperboloid, shape (batch_size, embedding_dim+1) | required |
| v | Tensor | Second point on hyperboloid, shape (batch_size, embedding_dim+1) | required |

Returns:

| Type | Description |
| --- | --- |
| Tensor | Lorentz inner products, shape (batch_size,) |
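The same product can be written as uᵀGv with the Minkowski metric G = diag(-1, 1, ..., 1), which makes the sign convention explicit. A small sketch comparing both forms; the function names are illustrative, and the product itself is defined for any vectors, not only points on the hyperboloid.

```python
import torch


def lorentz_dot(u: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
    # Component form: u1*v1 + ... + un*vn - u0*v0
    return (u[..., 1:] * v[..., 1:]).sum(dim=-1) - u[..., 0] * v[..., 0]


def lorentz_dot_metric(u: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
    # Metric form u^T G v with G = diag(-1, 1, ..., 1), applied row-wise
    g = torch.ones(u.shape[-1], dtype=u.dtype, device=u.device)
    g[0] = -1.0
    return (u * g * v).sum(dim=-1)


u = torch.tensor([[2.0, 1.0, 1.0]])
v = torch.tensor([[3.0, 0.5, -1.0]])
result = lorentz_dot(u, v)  # 1*0.5 + 1*(-1) - 2*3 = -6.5
```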

LorentzOps

Static utility class for Lorentz model operations. Provides functions for mapping between the hyperboloid and the tangent space at the origin, and for computing distances.

When torch.compile is enabled (PyTorch 2.0+), core operations are compiled for better throughput through kernel fusion.

exp_map_zero(x_tan, c=1.0) staticmethod

Exponential map from tangent space at origin to hyperboloid.

Maps a tangent vector at the origin to a point on the Lorentz hyperboloid. The output satisfies the Lorentz constraint: ‖x_spatial‖² - x₀² = -1/c

Uses compiled operations when torch.compile is enabled.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| x_tan | Tensor | Tangent vector, shape (batch_size, embedding_dim+1) | required |
| c | float | Curvature parameter c | 1.0 |

Returns:

| Type | Description |
| --- | --- |
| Tensor | Point on hyperboloid, shape (batch_size, embedding_dim+1) |

log_map_zero(x_hyp, c=1.0) staticmethod

Logarithmic map from hyperboloid to tangent space at origin.

Maps a point on the Lorentz hyperboloid to the tangent space at the origin. Inverse of exp_map_zero.

Uses compiled operations when torch.compile is enabled.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| x_hyp | Tensor | Point on hyperboloid, shape (batch_size, embedding_dim+1). Must satisfy ‖x_spatial‖² - x₀² = -1/c. | required |
| c | float | Curvature parameter c | 1.0 |

Returns:

| Type | Description |
| --- | --- |
| Tensor | Tangent vector, shape (batch_size, embedding_dim+1) |
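A sketch of the standard closed forms for both maps at the origin o = (1/√c, 0, ..., 0), written so that they invert each other. Whether `LorentzOps` uses exactly these expressions and the same `eps` clamp is an assumption. The log map recovers the geodesic length from the spatial norm via arcsinh, which stays well conditioned near the origin, whereas arccosh of √c·x₀ loses precision there.

```python
import torch


def exp_map_zero(x_tan: torch.Tensor, c: float = 1.0, eps: float = 1e-7) -> torch.Tensor:
    sqrt_c = c ** 0.5
    v_s = x_tan[..., 1:]  # spatial part; x_tan[..., 0] is assumed to be 0
    n = v_s.norm(dim=-1, keepdim=True).clamp_min(eps)
    x0 = torch.cosh(sqrt_c * n) / sqrt_c
    return torch.cat([x0, torch.sinh(sqrt_c * n) * v_s / (sqrt_c * n)], dim=-1)


def log_map_zero(x_hyp: torch.Tensor, c: float = 1.0, eps: float = 1e-7) -> torch.Tensor:
    sqrt_c = c ** 0.5
    x_s = x_hyp[..., 1:]
    n = x_s.norm(dim=-1, keepdim=True).clamp_min(eps)
    # Geodesic length from the spatial norm: ||v|| = asinh(sqrt(c) * ||x_s||) / sqrt(c)
    scale = torch.asinh(sqrt_c * n) / (sqrt_c * n)
    zeros = torch.zeros_like(x_hyp[..., :1])  # tangent vectors at the origin have time component 0
    return torch.cat([zeros, scale * x_s], dim=-1)


torch.manual_seed(0)
c = 0.7
v = torch.cat([torch.zeros(5, 1), torch.randn(5, 4)], dim=-1).double()
recovered = log_map_zero(exp_map_zero(v, c), c)  # should match v up to round-off
```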

lorentz_distance(u, v, c=1.0) staticmethod

Compute Lorentzian distance between two points on the hyperboloid.

Uses compiled operations when torch.compile is enabled.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| u | Tensor | First point on hyperboloid, shape (batch_size, embedding_dim+1) | required |
| v | Tensor | Second point on hyperboloid, shape (batch_size, embedding_dim+1) | required |
| c | float | Curvature parameter c | 1.0 |

Returns:

| Type | Description |
| --- | --- |
| Tensor | Distances, shape (batch_size,) |
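A quick sanity check that ties the maps and the distance together: for c = 1, the distance from the origin to exp_map_zero(v) should equal the length of the tangent vector v. The sketch below reimplements both operations with the standard c = 1 formulas as stand-ins for the `LorentzOps` methods; it does not call the library, and the function names are illustrative.

```python
import torch


def exp_map_zero_unit(x_tan: torch.Tensor, eps: float = 1e-7) -> torch.Tensor:
    # c = 1 exponential map at the origin o = (1, 0, ..., 0)
    v_s = x_tan[..., 1:]
    n = v_s.norm(dim=-1, keepdim=True).clamp_min(eps)
    return torch.cat([torch.cosh(n), torch.sinh(n) * v_s / n], dim=-1)


def lorentz_distance_unit(u: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
    # c = 1: d(u, v) = arccosh(-<u, v>_L)
    inner = (u[..., 1:] * v[..., 1:]).sum(dim=-1) - u[..., 0] * v[..., 0]
    return torch.acosh((-inner).clamp_min(1.0))


torch.manual_seed(0)
v = torch.cat([torch.zeros(6, 1), torch.randn(6, 3)], dim=-1).double()
x = exp_map_zero_unit(v)
origin = torch.zeros_like(x)
origin[:, 0] = 1.0
d = lorentz_distance_unit(origin, x)
lengths = v[:, 1:].norm(dim=-1)  # d should equal these tangent-vector lengths
```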

check_lorentz_manifold_validity(embeddings, curvature=1.0, tolerance=0.001)

Check if embeddings satisfy the Lorentz hyperboloid constraint.

For valid points: -x₀² + x₁² + ... + xₙ² = -1/c

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| embeddings | Tensor | Hyperbolic embeddings of shape (batch_size, embedding_dim+1) | required |
| curvature | float | Curvature parameter c | 1.0 |
| tolerance | float | Tolerance for constraint violation | 0.001 |

Returns:

| Type | Description |
| --- | --- |
| Tuple[bool, Tensor, Tensor] | Tuple (is_valid, lorentz_norms, violations): is_valid is a boolean indicating whether all points are valid; lorentz_norms is the Lorentz inner product of each point with itself (should be -1/c); violations is the magnitude of each constraint violation |
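A minimal sketch of such a check, under the assumption that a point counts as valid when the absolute deviation of ⟨x, x⟩_L from -1/c is at most `tolerance`. The function name is hypothetical and the exact comparison used by the library may differ.

```python
import math
from typing import Tuple

import torch


def check_validity_sketch(
    embeddings: torch.Tensor, curvature: float = 1.0, tolerance: float = 1e-3
) -> Tuple[bool, torch.Tensor, torch.Tensor]:
    x0, x_s = embeddings[:, 0], embeddings[:, 1:]
    lorentz_norms = -x0 ** 2 + (x_s ** 2).sum(dim=-1)     # <x, x>_L per point
    violations = (lorentz_norms + 1.0 / curvature).abs()  # deviation from -1/c
    is_valid = bool((violations <= tolerance).all())
    return is_valid, lorentz_norms, violations


good = torch.tensor([[1.0, 0.0, 0.0], [math.sqrt(2.0), 1.0, 0.0]])  # both on the c = 1 hyperboloid
bad = torch.tensor([[1.0, 1.0, 0.0]])                                # <x, x>_L = 0, not -1
```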

compute_hyperbolic_radii(embeddings)

Extract hyperbolic radii (time coordinates) from Lorentz embeddings.

The time coordinate x₀ serves as the hyperbolic radius: it grows monotonically with the distance from the origin (for c = 1, d(o, x) = arccosh(x₀), where o = (1, 0, ..., 0)).

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| embeddings | Tensor | Hyperbolic embeddings of shape (batch_size, embedding_dim+1) | required |

Returns:

| Type | Description |
| --- | --- |
| Tensor | Hyperbolic radii of shape (batch_size,) |
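For c = 1 the origin is (1, 0, ..., 0) and the distance formula reduces to d(o, x) = arccosh(x₀), so the time coordinate can be converted into a geodesic distance when one is needed. A sketch of both quantities; the function names are hypothetical.

```python
import torch


def hyperbolic_radii(embeddings: torch.Tensor) -> torch.Tensor:
    # Time coordinate x0 of each point, shape (batch_size,)
    return embeddings[:, 0]


def distance_from_origin(embeddings: torch.Tensor) -> torch.Tensor:
    # For c = 1 the origin is (1, 0, ..., 0), so d(o, x) = arccosh(x0)
    return torch.acosh(embeddings[:, 0].clamp_min(1.0))


t = torch.tensor([0.0, 0.5, 2.0])
# Points at geodesic distance t from the origin along the first spatial axis
x = torch.stack([torch.cosh(t), torch.sinh(t), torch.zeros_like(t)], dim=-1)
radii = hyperbolic_radii(x)      # equals cosh(t); increases with t
dists = distance_from_origin(x)  # approximately t
```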

log_hyperbolic_diagnostics(embeddings, curvature=1.0, level_labels=None, logger_instance=None)

Log comprehensive diagnostics for hyperbolic embeddings.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| embeddings | Tensor | Hyperbolic embeddings of shape (batch_size, embedding_dim+1) | required |
| curvature | float | Curvature parameter c | 1.0 |
| level_labels | Optional[Tensor] | Optional NAICS hierarchy level labels for grouped statistics | None |
| logger_instance | Optional[Logger] | Optional logger instance (uses module logger if None) | None |

Returns:

| Type | Description |
| --- | --- |
| Dict[str, float] | Dictionary of diagnostic metrics |
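The signature suggests per-batch statistics, optionally grouped by hierarchy level. A rough sketch of that pattern follows; the function name and the metric names (`radius_mean`, `max_violation`, `radius_mean_level_<n>`, and so on) are hypothetical and are not the keys this function actually returns.

```python
import logging
from typing import Dict, Optional

import torch

module_logger = logging.getLogger(__name__)


def hyperbolic_diagnostics_sketch(
    embeddings: torch.Tensor,
    curvature: float = 1.0,
    level_labels: Optional[torch.Tensor] = None,
    logger_instance: Optional[logging.Logger] = None,
) -> Dict[str, float]:
    log = logger_instance if logger_instance is not None else module_logger
    x0 = embeddings[:, 0]
    lorentz_norms = -x0 ** 2 + (embeddings[:, 1:] ** 2).sum(dim=-1)
    metrics: Dict[str, float] = {
        "radius_mean": x0.mean().item(),
        "radius_std": x0.std().item(),
        "radius_max": x0.max().item(),
        "max_violation": (lorentz_norms + 1.0 / curvature).abs().max().item(),
    }
    if level_labels is not None:
        # Grouped statistics: mean time coordinate per hierarchy level
        for level in torch.unique(level_labels).tolist():
            metrics[f"radius_mean_level_{level}"] = x0[level_labels == level].mean().item()
    for name, value in metrics.items():
        log.info("%s = %.4f", name, value)
    return metrics


t = torch.tensor([0.1, 0.2, 1.0, 1.5])
emb = torch.stack([torch.cosh(t), torch.sinh(t), torch.zeros_like(t)], dim=-1)
metrics = hyperbolic_diagnostics_sketch(emb, level_labels=torch.tensor([2, 2, 3, 3]))
```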