easygraph.ml_metrics.hypergraphs.hypergraph module
- class easygraph.ml_metrics.hypergraphs.hypergraph.HypergraphVertexClassificationEvaluator(metric_configs: List[str | Dict[str, dict]], validate_index: int = 0)
Bases: VertexClassificationEvaluator
Metric evaluator for the vertex classification task on the hypergraph structure. The supported metrics are accuracy, f1_score, and confusion_matrix.
Parameters:
metric_configs (List[Union[str, Dict[str, dict]]]) – The metric configurations. Each entry is either a metric name, or a dict whose key is the metric name and whose value holds the metric parameters.
validate_index (int) – The index of the metric in metric_configs used for validation. Defaults to 0.
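A minimal sketch of how a metric_configs list might be interpreted: bare strings become metrics with default parameters, single-key dicts carry parameters, and validate_index picks one entry by position. The helpers normalize_configs and result_key are hypothetical, the "name -> param@value" key format is copied from the example output on this page, and joining several parameters with ", " is an assumption.

```python
from typing import Dict, List, Tuple, Union

MetricConfig = Union[str, Dict[str, dict]]


def normalize_configs(metric_configs: List[MetricConfig]) -> List[Tuple[str, dict]]:
    """Hypothetical helper: turn each entry into a (metric_name, params) pair."""
    pairs = []
    for cfg in metric_configs:
        if isinstance(cfg, str):
            # Bare name: the metric uses its default parameters.
            pairs.append((cfg, {}))
        elif isinstance(cfg, dict) and len(cfg) == 1:
            # Single-key dict: {metric_name: {param: value, ...}}.
            name, params = next(iter(cfg.items()))
            pairs.append((name, dict(params)))
        else:
            raise ValueError(f"unsupported metric config: {cfg!r}")
    return pairs


def result_key(name: str, params: dict) -> str:
    """Hypothetical helper: build a key in the 'name -> param@value' style.

    The ', ' separator for multiple parameters is an assumption."""
    if not params:
        return name
    return name + " -> " + ", ".join(f"{k}@{v}" for k, v in params.items())


configs = ["accuracy", {"f1_score": {"average": "macro"}}]
validate_index = 0
pairs = normalize_configs(configs)
keys = [result_key(n, p) for n, p in pairs]
print(keys)                  # ['accuracy', 'f1_score -> average@macro']
print(keys[validate_index])  # accuracy
```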
Examples
>>> import torch
>>> import easygraph.ml_metrics as dm
>>> evaluator = dm.HypergraphVertexClassificationEvaluator(
...     [
...         "accuracy",
...         {"f1_score": {"average": "macro"}},
...     ],
...     0
... )
>>> y_true = torch.tensor([0, 0, 1, 1, 2, 2])
>>> y_pred = torch.tensor([0, 2, 1, 2, 1, 2])
>>> evaluator.validate(y_true, y_pred)
0.5
>>> evaluator.test(y_true, y_pred)
{
    'accuracy': 0.5,
    'f1_score -> average@macro': 0.5222222222222221
}
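The numbers in the example can be checked by hand. The pure-Python sketch below (not the easygraph implementation) recomputes accuracy and macro-averaged F1 with the standard per-class precision and recall definitions, which are assumed to be what average="macro" means here.

```python
y_true = [0, 0, 1, 1, 2, 2]
y_pred = [0, 2, 1, 2, 1, 2]


def accuracy(y_true, y_pred):
    # Fraction of positions where the prediction equals the ground truth.
    return sum(t == p for t, p in zip(y_true, y_pred)) / len(y_true)


def macro_f1(y_true, y_pred):
    classes = sorted(set(y_true) | set(y_pred))
    f1s = []
    for c in classes:
        tp = sum(t == c and p == c for t, p in zip(y_true, y_pred))
        fp = sum(t != c and p == c for t, p in zip(y_true, y_pred))
        fn = sum(t == c and p != c for t, p in zip(y_true, y_pred))
        precision = tp / (tp + fp) if tp + fp else 0.0
        recall = tp / (tp + fn) if tp + fn else 0.0
        f1 = 2 * precision * recall / (precision + recall) if precision + recall else 0.0
        f1s.append(f1)
    # Macro average: unweighted mean of the per-class F1 scores.
    return sum(f1s) / len(f1s)


print(accuracy(y_true, y_pred))  # 0.5
print(macro_f1(y_true, y_pred))  # about 0.5222
```

Class 0 scores F1 = 2/3, class 1 scores 1/2 and class 2 scores 2/5; their unweighted mean is about 0.5222, matching the test output up to floating-point rounding.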
Methods
test(y_true, y_pred) – Return results of the evaluation on all the metrics in metric_configs.
test_add_batch(batch_y_true, batch_y_pred) – Add batch data for testing.
test_epoch_res() – For all added batch data, return results of the evaluation on all the metrics in metric_configs.
validate(y_true, y_pred) – Return the result of the evaluation on the validate_index-th metric.
validate_epoch_res() – For all added batch data, return the result of the evaluation on the validate_index-th metric.
validate_add_batch
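The *_add_batch / *_epoch_res pairs suggest an accumulate-then-evaluate pattern. The hypothetical BatchAccumulator below buffers every batch and computes each metric once on the concatenated data; it is a sketch of the pattern, not easygraph's code, and clearing the buffer after each epoch result is an assumption.

```python
from typing import Callable, Dict, List

Metric = Callable[[List[int], List[int]], float]


class BatchAccumulator:
    """Hypothetical helper illustrating the add-batch / epoch-result pattern."""

    def __init__(self, metrics: Dict[str, Metric]):
        self.metrics = metrics
        self._y_true: List[int] = []
        self._y_pred: List[int] = []

    def add_batch(self, batch_y_true: List[int], batch_y_pred: List[int]) -> None:
        # Only buffer the batch; no metric is computed yet.
        self._y_true.extend(batch_y_true)
        self._y_pred.extend(batch_y_pred)

    def epoch_res(self) -> Dict[str, float]:
        # Evaluate every metric on all buffered samples, then reset (assumption).
        res = {name: fn(self._y_true, self._y_pred) for name, fn in self.metrics.items()}
        self._y_true, self._y_pred = [], []
        return res


def accuracy(y_true: List[int], y_pred: List[int]) -> float:
    return sum(t == p for t, p in zip(y_true, y_pred)) / len(y_true)


acc = BatchAccumulator({"accuracy": accuracy})
acc.add_batch([0, 0, 1], [0, 2, 1])  # 2 of 3 correct
acc.add_batch([1, 2, 2], [2, 1, 2])  # 1 of 3 correct
print(acc.epoch_res())  # {'accuracy': 0.5}
```

Evaluating on the concatenated data, rather than averaging per-batch scores, matters because metrics such as macro F1 (and accuracy with unequal batch sizes) are not means of per-batch values.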
- test(y_true: LongTensor, y_pred: Tensor)
Return results of the evaluation on all the metrics in metric_configs.
Parameters:
y_true (torch.LongTensor) – The ground truth labels. Size \((N_{samples}, )\).
y_pred (torch.Tensor) – The predicted labels. Size \((N_{samples}, N_{class})\) or \((N_{samples}, )\).
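y_pred may hold either one label per sample or one score per class. Presumably the \((N_{samples}, N_{class})\) form is reduced to labels by an argmax over the class dimension (y_pred.argmax(dim=1) in PyTorch) before the metrics are computed; this page does not say so, so treat it as an assumption. A dependency-free sketch of that reduction, with to_labels as a hypothetical helper:

```python
from typing import List, Sequence, Union

Row = Sequence[float]


def to_labels(y_pred: Union[Sequence[int], Sequence[Row]]) -> List[int]:
    """Hypothetical helper: collapse (N, N_class) scores to (N,) labels by per-row argmax."""
    if y_pred and isinstance(y_pred[0], (list, tuple)):
        # 2-D input: index of the largest score in each row (first index wins on ties).
        return [max(range(len(row)), key=row.__getitem__) for row in y_pred]
    # 1-D input: already class indices.
    return [int(v) for v in y_pred]


logits = [[2.0, 0.1, 0.3], [0.2, 0.1, 1.5], [0.0, 3.0, 1.0]]
print(to_labels(logits))     # [0, 2, 1]
print(to_labels([0, 2, 1]))  # [0, 2, 1]
```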
- validate(y_true: LongTensor, y_pred: Tensor)
Return the result of the evaluation on the validate_index-th metric.
Parameters:
y_true (torch.LongTensor) – The ground truth labels. Size \((N_{samples}, )\).
y_pred (torch.Tensor) – The predicted labels. Size \((N_{samples}, N_{class})\) or \((N_{samples}, )\).
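Because validate reduces the evaluation to a single number, it fits a keep-the-best-epoch loop during training. In the sketch below, select_best_epoch is a hypothetical helper, a plain accuracy function stands in for evaluator.validate, and the per-epoch predictions are made-up illustrations rather than results.

```python
from typing import Callable, List, Sequence, Tuple


def select_best_epoch(
    y_true: Sequence[int],
    preds_per_epoch: List[Sequence[int]],
    score_fn: Callable[[Sequence[int], Sequence[int]], float],
) -> Tuple[int, float]:
    """Hypothetical helper: return (epoch, score) with the highest validation score."""
    best_epoch, best_score = -1, float("-inf")
    for epoch, y_pred in enumerate(preds_per_epoch):
        score = score_fn(y_true, y_pred)  # stands in for evaluator.validate(y_true, y_pred)
        if score > best_score:  # strict '>' keeps the earliest epoch on ties
            best_epoch, best_score = epoch, score
    return best_epoch, best_score


def accuracy(y_true: Sequence[int], y_pred: Sequence[int]) -> float:
    return sum(t == p for t, p in zip(y_true, y_pred)) / len(y_true)


y_val = [0, 0, 1, 1, 2, 2]
epochs = [
    [0, 2, 1, 2, 1, 2],  # 3 of 6 correct (illustrative)
    [0, 0, 1, 2, 2, 2],  # 5 of 6 correct (illustrative)
    [0, 0, 1, 1, 1, 2],  # 5 of 6 correct (illustrative)
]
print(select_best_epoch(y_val, epochs, accuracy))  # epoch 1, score 5/6
```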