Source code for easygraph.ml_metrics.base

import abc

from collections import defaultdict
from functools import partial
from typing import Dict
from typing import List
from typing import Union

import numpy as np
import torch

from easygraph._global import AUTHOR_EMAIL


def format_metric_configs(task: str, metric_configs: List[Union[str, Dict[str, dict]]]):
    r"""Format metric_configs.

    Args:
        ``task`` (``str``): The type of the task. The supported types include: ``classification``, ``retrieval`` and ``recommender``.
        ``metric_configs`` (``List[Union[str, Dict[str, dict]]]``): The metric configs.
    """
    task = task.lower()
    if task == "classification":
        import easygraph.ml_metrics.classification as module

        available_metrics = module.available_classification_metrics()
    else:
        raise ValueError(
            f"Task {task} is not supported yet. Please email '{AUTHOR_EMAIL}' to"
            " add it."
        )
    metric_list = []
    for metric in metric_configs:
        if isinstance(metric, str):
            marker, func_name = metric, metric
            assert func_name in available_metrics, (
                f"{func_name} is not supported yet. Please email '{AUTHOR_EMAIL}' to"
                " add it."
            )
            func = getattr(module, func_name)
        elif isinstance(metric, dict):
            assert len(metric) == 1
            func_name = list(metric.keys())[0]
            assert func_name in available_metrics, (
                f"{func_name} is not supported yet. Please email '{AUTHOR_EMAIL}' to"
                " add it."
            )
            params = metric[func_name]
            func = getattr(module, func_name)
            func = partial(func, **params)
            marker_list = []
            for k, v in params.items():
                _m = f"{k}@"
                if isinstance(v, str):
                    _m += v
                elif isinstance(v, int):
                    _m += str(v)
                elif isinstance(v, float):
                    _m += f"{v:.4f}"
                elif isinstance(v, (list, tuple, set)):
                    _m += "_".join([str(_v) for _v in v])
                else:
                    _m += str(v)
                marker_list.append(_m)
            marker = f"{func_name} -> {' | '.join(marker_list)}"
        else:
            raise ValueError(
                f"Each metric config must be a str or a dict, but got {type(metric)}."
            )
        metric_list.append({"marker": marker, "func": func, "func_name": func_name})
    return metric_list
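
# A hedged usage sketch (illustrative, not part of this module): assuming
# "accuracy" and "f1_score" are among the available classification metrics,
# a mixed config such as
#
#     format_metric_configs(
#         "classification",
#         ["accuracy", {"f1_score": {"average": "micro"}}],
#     )
#
# would return two entries: one with marker "accuracy" wrapping the plain
# metric function, and one with marker "f1_score -> average@micro" whose
# "func" is a functools.partial binding average="micro".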

class BaseEvaluator:
    r"""The base class for task-specified metric evaluators.

    Args:
        ``task`` (``str``): The type of the task. The supported types include: ``classification``, ``retrieval`` and ``recommender``.
        ``metric_configs`` (``List[Union[str, Dict[str, dict]]]``): The metric configurations. The key is the metric name and the value is the metric parameters.
        ``validate_index`` (``int``): The specified metric index used for validation. Defaults to ``0``.
    """

    def __init__(
        self,
        task: str,
        metric_configs: List[Union[str, Dict[str, dict]]],
        validate_index: int = 0,
    ):
        self.validate_index = validate_index
        metric_configs = format_metric_configs(task, metric_configs)
        assert 0 <= validate_index < len(
            metric_configs
        ), "The specified validate metric index is out of range."
        self.marker_list, self.func_list = [], []
        for metric in metric_configs:
            self.marker_list.append(metric["marker"])
            self.func_list.append(metric["func"])
        # init batch data containers
        self.validate_res = []
        self.test_res_dict = defaultdict(list)
        self.last_validate_res, self.last_test_res = None, {}

    @abc.abstractmethod
    def __repr__(self) -> str:
        r"""Print the Evaluator information."""
    def validate_add_batch(
        self, batch_y_true: torch.Tensor, batch_y_pred: torch.Tensor
    ):
        r"""Add batch data for validation.

        Args:
            ``batch_y_true`` (``torch.Tensor``): The ground truth data. Size :math:`(N_{batch}, -)`.
            ``batch_y_pred`` (``torch.Tensor``): The predicted data. Size :math:`(N_{batch}, -)`.
        """
        batch_res = self.func_list[self.validate_index](
            batch_y_true, batch_y_pred, ret_batch=True
        )
        batch_res = np.array(batch_res)
        if batch_res.ndim == 1:
            batch_res = batch_res[:, np.newaxis]
        self.validate_res.append(batch_res)
    def validate_epoch_res(self):
        r"""For all added batch data, return the result of the evaluation on the
        specified ``validate_index``-th metric.
        """
        if self.validate_res == [] and self.last_validate_res is not None:
            return self.last_validate_res
        assert self.validate_res != [], "No batch data added for validation."
        self.last_validate_res = np.vstack(self.validate_res).mean(0).item()
        # clear batch cache
        self.validate_res = []
        return self.last_validate_res
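    # A sketch of the per-epoch validation flow (``val_loader`` is a
    # hypothetical DataLoader; the validation metric must support
    # ``ret_batch=True``):
    #
    #     for batch_y_true, batch_y_pred in val_loader:
    #         evaluator.validate_add_batch(batch_y_true, batch_y_pred)
    #     epoch_score = evaluator.validate_epoch_res()  # also clears the cache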
    def test_add_batch(self, batch_y_true: torch.Tensor, batch_y_pred: torch.Tensor):
        r"""Add batch data for testing.

        Args:
            ``batch_y_true`` (``torch.Tensor``): The ground truth data. Size :math:`(N_{batch}, -)`.
            ``batch_y_pred`` (``torch.Tensor``): The predicted data. Size :math:`(N_{batch}, -)`.
        """
        for name, func in zip(self.marker_list, self.func_list):
            batch_res = func(batch_y_true, batch_y_pred, ret_batch=True)
            if not isinstance(batch_res, tuple):
                batch_res = np.array(batch_res)
                if batch_res.ndim == 1:
                    batch_res = batch_res[:, np.newaxis]
                self.test_res_dict[name].append(batch_res)
            else:
                # metrics that return a tuple get one sub-result list each
                if self.test_res_dict[name] == []:
                    self.test_res_dict[name] = [list() for _ in range(len(batch_res))]
                for idx, batch_sub_res in enumerate(batch_res):
                    batch_sub_res = np.array(batch_sub_res)
                    if batch_sub_res.ndim == 1:
                        batch_sub_res = batch_sub_res[:, np.newaxis]
                    self.test_res_dict[name][idx].append(batch_sub_res)
    def test_epoch_res(self):
        r"""For all added batch data, return results of the evaluation on all the
        metrics in ``metric_configs``.
        """
        if self.test_res_dict == {} and self.last_test_res is not None:
            return self.last_test_res
        assert self.test_res_dict != {}, "No batch data added for testing."
        for name, res_list in self.test_res_dict.items():
            if not isinstance(res_list[0], list):
                self.last_test_res[name] = (
                    np.vstack(res_list).mean(0).squeeze().tolist()
                )
            else:
                self.last_test_res[name] = [
                    np.vstack(sub_res_list).mean(0).squeeze().tolist()
                    for sub_res_list in res_list
                ]
        # clear batch cache
        self.test_res_dict = defaultdict(list)
        return self.last_test_res
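    # The test flow mirrors validation but accumulates every configured
    # metric (a sketch; ``test_loader`` is hypothetical):
    #
    #     for batch_y_true, batch_y_pred in test_loader:
    #         evaluator.test_add_batch(batch_y_true, batch_y_pred)
    #     res = evaluator.test_epoch_res()  # dict keyed by metric marker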
    def validate(self, y_true: torch.LongTensor, y_pred: torch.Tensor):
        r"""Return the result of the evaluation on the specified ``validate_index``-th metric.

        Args:
            ``y_true`` (``torch.LongTensor``): The ground truth labels. Size :math:`(N_{samples}, -)`.
            ``y_pred`` (``torch.Tensor``): The predicted labels. Size :math:`(N_{samples}, -)`.
        """
        return self.func_list[self.validate_index](y_true, y_pred)
    def test(self, y_true: torch.LongTensor, y_pred: torch.Tensor):
        r"""Return results of the evaluation on all the metrics in ``metric_configs``.

        Args:
            ``y_true`` (``torch.LongTensor``): The ground truth labels. Size :math:`(N_{samples}, -)`.
            ``y_pred`` (``torch.Tensor``): The predicted labels. Size :math:`(N_{samples}, -)`.
        """
        return {
            name: func(y_true, y_pred)
            for name, func in zip(self.marker_list, self.func_list)
        }
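
if __name__ == "__main__":
    # A minimal, illustrative sketch, not part of the library API. It assumes
    # "accuracy" is among available_classification_metrics() and follows the
    # (y_true, y_pred) call convention used above; check
    # easygraph.ml_metrics.classification for what your installation provides.
    # Note: BaseEvaluator does not inherit from abc.ABC, so it is instantiable
    # here, though task-specific subclasses are the intended entry point.
    evaluator = BaseEvaluator("classification", ["accuracy"], validate_index=0)
    y_true = torch.tensor([0, 1, 2, 1])  # ground-truth labels, size (N,)
    y_pred = torch.randn(4, 3)  # predicted scores, size (N, num_classes)
    print(evaluator.validate(y_true, y_pred))  # the validate_index-th metric
    print(evaluator.test(y_true, y_pred))  # dict keyed by metric marker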