Graph Metrics API

Graph-specific metrics and downstream evaluation.

Contains:
- compute_validation_metrics: Triplet-based validation metrics for hyperbolic embeddings
- GraphEmbeddingDataset: Container for hyperbolic graph embeddings
- GraphDownstreamEvaluator: Downstream evaluation (taxonomy, similarity, clustering, classification)
- run_graph_downstream_suite: Convenience helper to run all evaluations

GraphDownstreamEvaluator

Evaluate downstream metrics for graph-refined NAICS embeddings.

classification_benchmark(*, digits=2, test_size=0.2, random_state=42)

Train a linear classifier to predict NAICS prefixes.
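As a rough illustration of what `digits`, `test_size`, and `random_state` control, the sketch below truncates each code to its first `digits` characters to form the label, holds out a seeded `test_size` fraction, and scores a nearest-centroid classifier. Nearest centroid is a simple linear decision rule swapped in here for whatever linear model the evaluator actually trains. The helper names (`naics_prefix`, `split_indices`, `nearest_centroid_accuracy`) are hypothetical, and the inputs are assumed to be plain Euclidean feature vectors; how the evaluator turns hyperbolic coordinates into features is not covered.

```python
import random
from typing import Dict, List, Sequence, Tuple

Vector = Sequence[float]


def naics_prefix(code: str, digits: int = 2) -> str:
    """Truncate a NAICS code to its first `digits` characters (hypothetical helper)."""
    return str(code)[:digits]


def split_indices(n: int, test_size: float = 0.2, random_state: int = 42) -> Tuple[List[int], List[int]]:
    """Shuffle 0..n-1 with a seeded RNG and hold out a `test_size` fraction."""
    idx = list(range(n))
    random.Random(random_state).shuffle(idx)
    n_test = max(1, int(round(n * test_size)))
    return idx[n_test:], idx[:n_test]


def nearest_centroid_accuracy(
    vectors: Sequence[Vector],
    codes: Sequence[str],
    *,
    digits: int = 2,
    test_size: float = 0.2,
    random_state: int = 42,
) -> float:
    """Held-out accuracy of a nearest-centroid classifier on NAICS prefixes."""
    labels = [naics_prefix(c, digits) for c in codes]
    train, test = split_indices(len(vectors), test_size, random_state)

    # Fit: accumulate a per-label sum vector and count, then average into centroids.
    sums: Dict[str, List[float]] = {}
    counts: Dict[str, int] = {}
    for i in train:
        acc = sums.setdefault(labels[i], [0.0] * len(vectors[i]))
        for j, v in enumerate(vectors[i]):
            acc[j] += v
        counts[labels[i]] = counts.get(labels[i], 0) + 1
    centroids = {lab: [s / counts[lab] for s in acc] for lab, acc in sums.items()}

    # Predict: closest centroid by squared Euclidean distance (linear decision boundaries).
    def predict(x: Vector) -> str:
        return min(centroids, key=lambda lab: sum((a - b) ** 2 for a, b in zip(x, centroids[lab])))

    return sum(predict(vectors[i]) == labels[i] for i in test) / len(test)
```

Because the split is driven by a dedicated `random.Random(random_state)`, repeated calls with the same seed score the same held-out codes.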

clustering_quality(*, digits=(2, 3), random_state=42)

Compute the Adjusted Rand Index (ARI) and Normalized Mutual Information (NMI) by clustering the embeddings and comparing the cluster assignments with NAICS prefixes.
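The scoring half of this method can be sketched with the standard contingency-table formulas for ARI and NMI. This pure-Python version is an illustration, not the evaluator's code: the clustering step is omitted (the algorithm the evaluator uses is not specified here), NMI is normalised by the arithmetic mean of the two entropies (the default in scikit-learn's `normalized_mutual_info_score`), and the function names are hypothetical.

```python
import math
from collections import Counter
from typing import Hashable, Sequence


def _pairs(count: int) -> float:
    """Number of unordered pairs, C(count, 2)."""
    return count * (count - 1) / 2.0


def adjusted_rand_index(labels_true: Sequence[Hashable], labels_pred: Sequence[Hashable]) -> float:
    n = len(labels_true)
    if n < 2 or n != len(labels_pred):
        raise ValueError("need two labelings of equal length >= 2")
    cells = Counter(zip(labels_true, labels_pred))  # contingency table cells n_ij
    sum_cells = sum(_pairs(v) for v in cells.values())
    sum_rows = sum(_pairs(v) for v in Counter(labels_true).values())
    sum_cols = sum(_pairs(v) for v in Counter(labels_pred).values())
    expected = sum_rows * sum_cols / _pairs(n)
    max_index = (sum_rows + sum_cols) / 2.0
    if max_index == expected:  # degenerate partitions (e.g. both single-cluster)
        return 1.0
    return (sum_cells - expected) / (max_index - expected)


def normalized_mutual_info(labels_true: Sequence[Hashable], labels_pred: Sequence[Hashable]) -> float:
    n = len(labels_true)
    cells = Counter(zip(labels_true, labels_pred))
    rows, cols = Counter(labels_true), Counter(labels_pred)
    mi = sum((nij / n) * math.log(n * nij / (rows[i] * cols[j])) for (i, j), nij in cells.items())
    h_true = -sum((a / n) * math.log(a / n) for a in rows.values())
    h_pred = -sum((b / n) * math.log(b / n) for b in cols.values())
    denom = (h_true + h_pred) / 2.0  # arithmetic-mean normalisation
    return 1.0 if denom == 0.0 else mi / denom
```

Both scores are invariant to relabelling, so cluster ids such as `[1, 1, 0, 0]` can be compared directly with prefix strings such as `["11", "11", "54", "54"]`.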

industry_similarity(hierarchy, *, k_values=(5, 10))

Measure whether nearest neighbours correspond to semantic siblings.
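One plausible reading of this metric is sibling precision@k: for each code that has a parent, the fraction of its k nearest neighbours that share that parent, averaged over codes. The sketch assumes `hierarchy` can be reduced to a child-to-parent mapping and takes the distance function as an argument (the evaluator presumably uses a hyperbolic distance). `sibling_precision_at_k` and its signature are hypothetical.

```python
from typing import Callable, Dict, List, Mapping, Sequence

Vector = Sequence[float]


def sibling_precision_at_k(
    codes: Sequence[str],
    vectors: Sequence[Vector],
    parent_of: Mapping[str, str],
    distance: Callable[[Vector, Vector], float],
    k_values: Sequence[int] = (5, 10),
) -> Dict[int, float]:
    """Mean fraction of each code's k nearest neighbours that share its parent."""
    scores: Dict[int, List[float]] = {k: [] for k in k_values}
    for i, code in enumerate(codes):
        parent = parent_of.get(code)
        if parent is None:
            continue  # top-level codes have no parent, hence no siblings to score
        # Rank every other code by distance to this one (brute force).
        ranked = sorted(
            (j for j in range(len(codes)) if j != i),
            key=lambda j: distance(vectors[i], vectors[j]),
        )
        for k in k_values:
            hits = sum(parent_of.get(codes[j]) == parent for j in ranked[:k])
            scores[k].append(hits / k)
    return {k: (sum(v) / len(v) if v else 0.0) for k, v in scores.items()}
```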

taxonomy_reconstruction(hierarchy, *, k_values=(1, 3, 5))

Evaluate parent retrieval accuracy from embeddings alone.
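Parent retrieval can be sketched the same way: rank every other code by distance to a child and record whether the true parent appears in the top k. The child-to-parent mapping, the unrestricted candidate pool (the evaluator may limit candidates to the parent's level), and the function name `parent_retrieval_accuracy` are all assumptions of this sketch rather than documented behaviour.

```python
from typing import Callable, Dict, Mapping, Sequence

Vector = Sequence[float]


def parent_retrieval_accuracy(
    codes: Sequence[str],
    vectors: Sequence[Vector],
    parent_of: Mapping[str, str],
    distance: Callable[[Vector, Vector], float],
    k_values: Sequence[int] = (1, 3, 5),
) -> Dict[int, float]:
    """Fraction of children whose true parent is among their k closest codes."""
    index = {c: i for i, c in enumerate(codes)}
    hits = {k: 0 for k in k_values}
    total = 0
    for child, parent in parent_of.items():
        if child not in index or parent not in index:
            continue  # skip pairs without embeddings
        total += 1
        child_vec = vectors[index[child]]
        # Candidate pool: every other code, closest first.
        ranked = sorted(
            (c for c in codes if c != child),
            key=lambda c: distance(child_vec, vectors[index[c]]),
        )
        rank = ranked.index(parent)  # 0 means the parent is the nearest code
        for k in k_values:
            hits[k] += rank < k
    return {k: (hits[k] / total if total else 0.0) for k in k_values}
```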

GraphEmbeddingDataset (dataclass)

Container for a set of hyperbolic graph embeddings.

from_dataframe(frame, *, embedding_prefix='hgcn_e', code_column='code', level_column='level') (classmethod)

Build a dataset from a DataFrame (e.g., one read from a Parquet file).
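The `embedding_prefix` argument suggests that embedding coordinates live in numbered columns such as `hgcn_e0`, `hgcn_e1`, and so on; that naming scheme is an assumption. The sketch below shows how such columns could be selected and ordered numerically (so `hgcn_e10` sorts after `hgcn_e9`), using a list of row dicts in place of a pandas DataFrame. `embedding_columns` and `ToyEmbeddingDataset` are hypothetical stand-ins, not the real classmethod.

```python
import re
from dataclasses import dataclass
from typing import Any, List, Mapping, Sequence


def embedding_columns(columns: Sequence[str], prefix: str = "hgcn_e") -> List[str]:
    """Select columns named <prefix><integer> and order them by that integer."""
    pattern = re.compile(re.escape(prefix) + r"(\d+)$")
    found = []
    for col in columns:
        m = pattern.match(col)
        if m:
            found.append((int(m.group(1)), col))
    return [col for _, col in sorted(found)]


@dataclass
class ToyEmbeddingDataset:
    """Hypothetical stand-in for GraphEmbeddingDataset."""

    codes: List[str]
    levels: List[int]
    embeddings: List[List[float]]

    @classmethod
    def from_records(
        cls,
        rows: Sequence[Mapping[str, Any]],
        *,
        embedding_prefix: str = "hgcn_e",
        code_column: str = "code",
        level_column: str = "level",
    ) -> "ToyEmbeddingDataset":
        # Column order is taken from the first row; all rows are assumed to share it.
        cols = embedding_columns(list(rows[0].keys()), embedding_prefix) if rows else []
        return cls(
            codes=[str(r[code_column]) for r in rows],
            levels=[int(r[level_column]) for r in rows],
            embeddings=[[float(r[c]) for c in cols] for r in rows],
        )
```

With pandas, the same ordered column list could be passed to `frame[cols].to_numpy()` to obtain the embedding matrix.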

compute_validation_metrics(emb, anchors, positives, negatives, c=1.0, top_k=1, *, as_tensors=False)

Compute validation metrics for hyperbolic embeddings.

Parameters:

- emb (Tensor, required): Embeddings tensor of shape (N, embedding_dim+1).
- anchors (Tensor, required): Anchor indices, shape (batch_size,).
- positives (Tensor, required): Positive indices, shape (batch_size,).
- negatives (Tensor, required): Negative indices, shape (batch_size, k_negatives).
- c (float, default 1.0): Curvature parameter.
- top_k (int, default 1): Number of top negatives to consider for auxiliary accuracy.
- as_tensors (bool, default False): Return torch scalars instead of Python floats (for Lightning logging).

Returns:

- Union[Dict[str, float], Dict[str, Tensor]]: Mapping of metric names to values: torch scalars if as_tensors is True, otherwise Python floats.
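The (N, embedding_dim+1) shape is consistent with points stored in the Lorentz (hyperboloid) model, where a time-like coordinate is prepended to the spatial ones. Assuming that model with curvature -c, the sketch below lifts spatial vectors onto the hyperboloid, computes the geodesic distance d(x, y) = arccosh(-c * <x, y>_L) / sqrt(c), and derives a few triplet metrics in pure Python. The metric names, and the reading of `top_k` as "fewer than top_k negatives are closer to the anchor than the positive", are assumptions; the `as_tensors` path and torch batching are left out.

```python
import math
from typing import Dict, List, Sequence

Vector = Sequence[float]


def lift_to_hyperboloid(spatial: Vector, c: float = 1.0) -> List[float]:
    """Prepend the time-like coordinate so that <x, x>_L = -1/c."""
    x0 = math.sqrt(1.0 / c + sum(v * v for v in spatial))
    return [x0, *spatial]


def lorentz_inner(x: Vector, y: Vector) -> float:
    """Minkowski inner product: -x0*y0 + sum_i xi*yi."""
    return -x[0] * y[0] + sum(a * b for a, b in zip(x[1:], y[1:]))


def lorentz_distance(x: Vector, y: Vector, c: float = 1.0) -> float:
    """Geodesic distance on the hyperboloid of curvature -c."""
    arg = max(-c * lorentz_inner(x, y), 1.0)  # clamp: rounding can push it below 1
    return math.acosh(arg) / math.sqrt(c)


def triplet_metrics(
    emb: Sequence[Vector],
    anchors: Sequence[int],
    positives: Sequence[int],
    negatives: Sequence[Sequence[int]],
    c: float = 1.0,
    top_k: int = 1,
) -> Dict[str, float]:
    pos_d: List[float] = []
    neg_d: List[float] = []
    acc = topk_hits = 0
    for a, p, negs in zip(anchors, positives, negatives):
        d_pos = lorentz_distance(emb[a], emb[p], c)
        d_negs = [lorentz_distance(emb[a], emb[n], c) for n in negs]
        pos_d.append(d_pos)
        neg_d.extend(d_negs)
        closer = sum(d < d_pos for d in d_negs)  # negatives ranked ahead of the positive
        acc += int(closer == 0)  # positive beats every negative
        topk_hits += int(closer < top_k)  # positive lands within the top_k
    b = len(pos_d)
    return {
        "mean_pos_distance": sum(pos_d) / b,
        "mean_neg_distance": sum(neg_d) / len(neg_d),
        "triplet_accuracy": acc / b,
        f"top{top_k}_accuracy": topk_hits / b,
    }
```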

run_graph_downstream_suite(dataset, hierarchy, *, curvature=1.0, taxonomy_k=(1, 3, 5), sibling_k=(5, 10), clustering_digits=(2, 3), classification_digits=2, random_state=42)

Convenience helper that runs all four downstream evaluations (taxonomy reconstruction, industry similarity, clustering quality, and classification) sequentially.
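Judging by its keyword names, the suite probably forwards `taxonomy_k`, `sibling_k`, `clustering_digits`, and `classification_digits` to the four evaluator methods documented above. The sketch below shows that routing under that assumption. It takes an already-built evaluator instead of `dataset` and `curvature`, runs the methods in the order the module overview lists them, and both `run_suite` and `EchoEvaluator` (a stand-in that simply echoes its arguments) are hypothetical.

```python
from typing import Any, Dict, Sequence


def run_suite(
    evaluator: Any,
    hierarchy: Any,
    *,
    taxonomy_k: Sequence[int] = (1, 3, 5),
    sibling_k: Sequence[int] = (5, 10),
    clustering_digits: Sequence[int] = (2, 3),
    classification_digits: int = 2,
    random_state: int = 42,
) -> Dict[str, Any]:
    """Call the four evaluator methods in order and collect their results by name."""
    results: Dict[str, Any] = {}
    results["taxonomy"] = evaluator.taxonomy_reconstruction(hierarchy, k_values=taxonomy_k)
    results["similarity"] = evaluator.industry_similarity(hierarchy, k_values=sibling_k)
    results["clustering"] = evaluator.clustering_quality(
        digits=clustering_digits, random_state=random_state
    )
    results["classification"] = evaluator.classification_benchmark(
        digits=classification_digits, random_state=random_state
    )
    return results


class EchoEvaluator:
    """Stand-in with the documented method signatures; returns the arguments it received."""

    def taxonomy_reconstruction(self, hierarchy, *, k_values=(1, 3, 5)):
        return {"k_values": tuple(k_values)}

    def industry_similarity(self, hierarchy, *, k_values=(5, 10)):
        return {"k_values": tuple(k_values)}

    def clustering_quality(self, *, digits=(2, 3), random_state=42):
        return {"digits": tuple(digits), "random_state": random_state}

    def classification_benchmark(self, *, digits=2, test_size=0.2, random_state=42):
        return {"digits": digits, "test_size": test_size, "random_state": random_state}
```

Sharing a single `random_state` between clustering and classification keeps the whole suite reproducible from one seed.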