# Source code for easygraph.ml_metrics.hypergraphs.hypergraph
from typing import Dict
from typing import List
from typing import Union
import torch
from ..classification import VertexClassificationEvaluator
class HypergraphVertexClassificationEvaluator(VertexClassificationEvaluator):
    r"""Return the metric evaluator for the vertex classification task on the hypergraph structure.

    The supported metrics include ``accuracy``, ``f1_score`` and ``confusion_matrix``.
    This class adds no behavior of its own: construction, :meth:`validate` and
    :meth:`test` all delegate unchanged to ``VertexClassificationEvaluator``.

    Args:
        ``metric_configs`` (``List[Union[str, Dict[str, dict]]]``): The metric configurations.
            Each entry is either a metric name (e.g. ``"accuracy"``) or a dict whose key is
            the metric name and whose value is that metric's parameters
            (e.g. ``{"f1_score": {"average": "macro"}}``).
        ``validate_index`` (``int``): Index into ``metric_configs`` of the metric whose value
            :meth:`validate` returns. Defaults to ``0``.

    Examples:
        >>> import torch
        >>> import easygraph.ml_metrics as dm
        >>> evaluator = dm.HypergraphVertexClassificationEvaluator(
        ...     [
        ...         "accuracy",
        ...         {"f1_score": {"average": "macro"}},
        ...     ],
        ...     0,
        ... )
        >>> y_true = torch.tensor([0, 0, 1, 1, 2, 2])
        >>> y_pred = torch.tensor([0, 2, 1, 2, 1, 2])
        >>> evaluator.validate(y_true, y_pred)
        0.5
        >>> evaluator.test(y_true, y_pred)
        {'accuracy': 0.5, 'f1_score -> average@macro': 0.5222222222222221}
    """

    def __init__(
        self, metric_configs: List[Union[str, Dict[str, dict]]], validate_index: int = 0
    ):
        super().__init__(metric_configs, validate_index)

    def validate(self, y_true: torch.LongTensor, y_pred: torch.Tensor):
        r"""Return the value of the metric selected by ``validate_index``.

        Args:
            ``y_true`` (``torch.LongTensor``): The ground-truth vertex labels.
            ``y_pred`` (``torch.Tensor``): The predicted vertex labels. The accepted
                formats are whatever ``VertexClassificationEvaluator.validate`` supports
                (the class example passes a tensor of class indices).
        """
        return super().validate(y_true, y_pred)

    def test(self, y_true: torch.LongTensor, y_pred: torch.Tensor):
        r"""Return the values of all configured metrics.

        As in the class example, the result maps each metric's display name
        (e.g. ``'f1_score -> average@macro'``) to its value.

        Args:
            ``y_true`` (``torch.LongTensor``): The ground-truth vertex labels.
            ``y_pred`` (``torch.Tensor``): The predicted vertex labels. The accepted
                formats are whatever ``VertexClassificationEvaluator.test`` supports.
        """
        return super().test(y_true, y_pred)