easygraph.ml_metrics.hypergraphs.hypergraph module#

class easygraph.ml_metrics.hypergraphs.hypergraph.HypergraphVertexClassificationEvaluator(metric_configs: List[str | Dict[str, dict]], validate_index: int = 0)[source]#

Bases: VertexClassificationEvaluator

Return the metric evaluator for the vertex classification task on hypergraph structures. The supported metrics include: accuracy, f1_score, and confusion_matrix.

Parameters:
  • metric_configs (List[Union[str, Dict[str, dict]]]) – The metric configurations. Each entry is either a metric name or a dict mapping the metric name to its parameters.

  • validate_index (int) – The specified metric index used for validation. Defaults to 0.

Examples

>>> import torch
>>> import easygraph.ml_metrics as dm
>>> evaluator = dm.HypergraphVertexClassificationEvaluator(
...     [
...         "accuracy",
...         {"f1_score": {"average": "macro"}},
...     ],
...     0
... )
>>> y_true = torch.tensor([0, 0, 1, 1, 2, 2])
>>> y_pred = torch.tensor([0, 2, 1, 2, 1, 2])
>>> evaluator.validate(y_true, y_pred)
0.5
>>> evaluator.test(y_true, y_pred)
{
    'accuracy': 0.5,
    'f1_score -> average@macro': 0.5222222222222221
}
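The macro f1_score value above can be verified by hand: it is the unweighted mean of the per-class F1 scores. A minimal pure-Python sketch (independent of EasyGraph, shown only to make the reported number concrete):

```python
def macro_f1(y_true, y_pred):
    """Macro-averaged F1: unweighted mean of per-class F1 scores."""
    classes = sorted(set(y_true) | set(y_pred))
    f1s = []
    for c in classes:
        tp = sum(1 for t, p in zip(y_true, y_pred) if t == c and p == c)
        fp = sum(1 for t, p in zip(y_true, y_pred) if t != c and p == c)
        fn = sum(1 for t, p in zip(y_true, y_pred) if t == c and p != c)
        precision = tp / (tp + fp) if tp + fp else 0.0
        recall = tp / (tp + fn) if tp + fn else 0.0
        f1s.append(2 * precision * recall / (precision + recall)
                   if precision + recall else 0.0)
    return sum(f1s) / len(f1s)

y_true = [0, 0, 1, 1, 2, 2]
y_pred = [0, 2, 1, 2, 1, 2]
print(macro_f1(y_true, y_pred))  # 0.5222222222222221
```

Per class: class 0 has F1 = 2/3, class 1 has F1 = 1/2, class 2 has F1 = 2/5, and their mean is 0.5222…, matching the evaluator output.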

Methods

test(y_true, y_pred)

Return results of the evaluation on all the ml_metrics in metric_configs.

test_add_batch(batch_y_true, batch_y_pred)

Add batch data for testing.

test_epoch_res()

For all added batch data, return results of the evaluation on all the ml_metrics in metric_configs.

validate(y_true, y_pred)

Return the result of evaluating the metric at position validate_index in metric_configs.

validate_epoch_res()

For all added batch data, return the result of evaluating the metric at position validate_index in metric_configs.

validate_add_batch(batch_y_true, batch_y_pred)

Add batch data for validation.

test(y_true: LongTensor, y_pred: Tensor)[source]#

Return results of the evaluation on all the ml_metrics in metric_configs.

Parameters:
  • y_true (torch.LongTensor) – The ground truth labels. Size \((N_{samples}, )\).

  • y_pred (torch.Tensor) – The predicted labels. Size \((N_{samples}, N_{class})\) or \((N_{samples}, )\).
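Since y_pred may be either class scores of shape \((N_{samples}, N_{class})\) or a label vector of shape \((N_{samples}, )\), evaluators of this kind conventionally collapse the score form to labels via argmax before comparing with y_true. A hedged sketch of that convention (the helper name is illustrative, not part of the EasyGraph API):

```python
def to_label_vector(y_pred):
    """Collapse (N, C) per-class scores to an (N,) label list via
    row-wise argmax; pass (N,) label vectors through unchanged."""
    if y_pred and isinstance(y_pred[0], (list, tuple)):
        return [max(range(len(row)), key=row.__getitem__) for row in y_pred]
    return list(y_pred)

scores = [[2.0, 0.1, 0.3],   # highest score at class 0
          [0.2, 1.5, 0.9]]   # highest score at class 1
print(to_label_vector(scores))  # [0, 1]
```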

validate(y_true: LongTensor, y_pred: Tensor)[source]#

Return the result of evaluating the metric at position validate_index in metric_configs.

Parameters:
  • y_true (torch.LongTensor) – The ground truth labels. Size \((N_{samples}, )\).

  • y_pred (torch.Tensor) – The predicted labels. Size \((N_{samples}, N_{class})\) or \((N_{samples}, )\).