import abc
from collections import defaultdict
from functools import partial
from typing import Dict
from typing import List
from typing import Union
import torch
import numpy as np
class BaseEvaluator:
r"""The base class for task-specified metric evaluators.
Args:
``task`` (``str``): The type of the task. The supported types include: ``classification``, ``retrieval`` and ``recommender``.
``metric_configs`` (``List[Union[str, Dict[str, dict]]]``): The metric configurations. The key is the metric name and the value is the metric parameters.
``validate_index`` (``int``): The specified metric index used for validation. Defaults to ``0``.
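
    Example:
        A minimal usage sketch; in practice a task-specific subclass is
        instantiated, and the metric names below are only illustrative::

            evaluator = BaseEvaluator(
                "classification", ["accuracy", {"f1_score": {"average": "micro"}}]
            )
            val_score = evaluator.validate(y_true, y_pred)  # 0-th metric only
            test_res = evaluator.test(y_true, y_pred)  # all configured metrics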
"""
def __init__(
self,
task: str,
metric_configs: List[Union[str, Dict[str, dict]]],
validate_index: int = 0,
):
self.validate_index = validate_index
metric_configs = format_metric_configs(task, metric_configs)
assert validate_index >= 0 and validate_index < len(
metric_configs
), "The specified validate metric index is out of range."
self.marker_list, self.func_list = [], []
for metric in metric_configs:
self.marker_list.append(metric["marker"])
self.func_list.append(metric["func"])
# init batch data containers
self.validate_res = []
self.test_res_dict = defaultdict(list)
self.last_validate_res, self.last_test_res = None, {}
@abc.abstractmethod
def __repr__(self) -> str:
r"""Print the Evaluator information."""
    def validate_add_batch(
self, batch_y_true: torch.Tensor, batch_y_pred: torch.Tensor
):
r"""Add batch data for validation.
Args:
``batch_y_true`` (``torch.Tensor``): The ground truth data. Size :math:`(N_{batch}, -)`.
``batch_y_pred`` (``torch.Tensor``): The predicted data. Size :math:`(N_{batch}, -)`.
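
        Example:
            A sketch of a typical validation loop; ``evaluator`` and
            ``val_loader`` are assumed to be defined elsewhere::

                for batch_y_true, batch_y_pred in val_loader:
                    evaluator.validate_add_batch(batch_y_true, batch_y_pred)
                val_score = evaluator.validate_epoch_res()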
"""
batch_res = self.func_list[self.validate_index](
batch_y_true, batch_y_pred, ret_batch=True
)
batch_res = np.array(batch_res)
if len(batch_res.shape) == 1:
batch_res = batch_res[:, np.newaxis]
self.validate_res.append(batch_res)
    def validate_epoch_res(self):
r"""For all added batch data, return the result of the evaluation on the specified ``validate_index``-th metric.
"""
if self.validate_res == [] and self.last_validate_res is not None:
return self.last_validate_res
assert self.validate_res != [], "No batch data added for validation."
self.last_validate_res = np.vstack(self.validate_res).mean(0).item()
# clear batch cache
self.validate_res = []
return self.last_validate_res
    def test_add_batch(self, batch_y_true: torch.Tensor, batch_y_pred: torch.Tensor):
r"""Add batch data for testing.
Args:
``batch_y_true`` (``torch.Tensor``): The ground truth data. Size :math:`(N_{batch}, -)`.
``batch_y_pred`` (``torch.Tensor``): The predicted data. Size :math:`(N_{batch}, -)`.
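
        Example:
            A sketch of a typical test loop; ``evaluator`` and ``test_loader``
            are assumed to be defined elsewhere::

                for batch_y_true, batch_y_pred in test_loader:
                    evaluator.test_add_batch(batch_y_true, batch_y_pred)
                test_res = evaluator.test_epoch_res()  # dict: metric name -> result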
"""
for name, func in zip(self.marker_list, self.func_list):
batch_res = func(batch_y_true, batch_y_pred, ret_batch=True)
if not isinstance(batch_res, tuple):
batch_res = np.array(batch_res)
if len(batch_res.shape) == 1:
batch_res = batch_res[:, np.newaxis]
self.test_res_dict[name].append(batch_res)
else:
if self.test_res_dict[name] == []:
self.test_res_dict[name] = [list() for _ in range(len(batch_res))]
for idx, batch_sub_res in enumerate(batch_res):
batch_sub_res = np.array(batch_sub_res)
if len(batch_sub_res.shape) == 1:
batch_sub_res = batch_sub_res[:, np.newaxis]
self.test_res_dict[name][idx].append(batch_sub_res)
    def test_epoch_res(self):
r"""For all added batch data, return results of the evaluation on all the ml_metrics in ``metric_configs``.
"""
        if self.test_res_dict == {} and self.last_test_res != {}:
return self.last_test_res
assert self.test_res_dict != {}, "No batch data added for testing."
for name, res_list in self.test_res_dict.items():
if not isinstance(res_list[0], list):
self.last_test_res[name] = (
np.vstack(res_list).mean(0).squeeze().tolist()
)
else:
self.last_test_res[name] = [
np.vstack(sub_res_list).mean(0).squeeze().tolist()
for sub_res_list in res_list
]
# clear batch cache
self.test_res_dict = defaultdict(list)
return self.last_test_res
    def validate(self, y_true: torch.LongTensor, y_pred: torch.Tensor):
r"""Return the result of the evaluation on the specified ``validate_index``-th metric.
Args:
``y_true`` (``torch.LongTensor``): The ground truth labels. Size :math:`(N_{samples}, -)`.
``y_pred`` (``torch.Tensor``): The predicted labels. Size :math:`(N_{samples}, -)`.
"""
return self.func_list[self.validate_index](y_true, y_pred)
    def test(self, y_true: torch.LongTensor, y_pred: torch.Tensor):
r"""Return results of the evaluation on all the ml_metrics in ``metric_configs``.
Args:
``y_true`` (``torch.LongTensor``): The ground truth labels. Size :math:`(N_{samples}, -)`.
``y_pred`` (``torch.Tensor``): The predicted labels. Size :math:`(N_{samples}, -)`.
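
        Example:
            A sketch of the returned value; metric names and numbers are
            illustrative::

                evaluator.test(y_true, y_pred)
                # -> {"accuracy": 0.93, "f1_score": 0.91}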
"""
return {
name: func(y_true, y_pred)
for name, func in zip(self.marker_list, self.func_list)
}