# Source code for easygraph.ml_metrics

try:
    from typing import Dict
    from typing import List
    from typing import Union

    from easygraph._global import AUTHOR_EMAIL

    from .base import BaseEvaluator
    from .classification import VertexClassificationEvaluator
    from .classification import available_classification_metrics
    from .hypergraphs import HypergraphVertexClassificationEvaluator
except:
    print(
        "Warning raise in module:ml_metrics. Please install Pytorch before you use"
        " functions related to nueral network"
    )


def build_evaluator(
    task: str,
    metric_configs: List[Union[str, Dict[str, dict]]],
    validate_index: int = 0,
):
    r"""Return the metric evaluator for the given task.

    Only ``hypergraph_vertex_classification`` is currently implemented; any
    other task name is rejected with a ``ValueError``.

    Args:
        ``task`` (``str``): The type of the task. Currently the only supported
            type is ``hypergraph_vertex_classification``.
        ``metric_configs`` (``List[Union[str, Dict[str, dict]]]``): The list of
            metric names (or ``Dict[str, dict]`` metric configs), forwarded
            unchanged to the evaluator's constructor.
        ``validate_index`` (``int``): The specified metric index used for
            validation. Defaults to ``0``.

    Returns:
        A ``HypergraphVertexClassificationEvaluator`` built from
        ``metric_configs`` and ``validate_index``.

    Raises:
        ``ValueError``: If ``task`` is not a supported task type.
        ``ImportError``: If the evaluator classes could not be imported when
            this module was loaded (e.g. PyTorch is not installed).
    """
    if task != "hypergraph_vertex_classification":
        raise ValueError(
            f"{task} is not supported yet. Please email '{AUTHOR_EMAIL}' to add it."
        )
    try:
        evaluator_cls = HypergraphVertexClassificationEvaluator
    except NameError as err:
        # The module-level import of the evaluators is wrapped in a try/except
        # that only prints a warning, so the class may be undefined here.
        # Surface that as an explicit ImportError instead of a bare NameError.
        raise ImportError(
            "HypergraphVertexClassificationEvaluator is unavailable. Please install"
            " Pytorch before you use functions related to neural network."
        ) from err
    return evaluator_cls(metric_configs, validate_index)
# __all__ = [
#     "BaseEvaluator",
#     "build_evaluator",
#     "available_classification_metrics",
#     "VertexClassificationEvaluator",
#     "HypergraphVertexClassificationEvaluator",
# ]