easygraph.ml_metrics.base module#
- class easygraph.ml_metrics.base.BaseEvaluator(task: str, metric_configs: List[str | Dict[str, dict]], validate_index: int = 0)[source]#
Bases: object

The base class for task-specified metric evaluators.
- Parameters:
  - task (str) – The type of the task. The supported types include: classification, retrieval and recommender.
  - metric_configs (List[Union[str, Dict[str, dict]]]) – The metric configurations. Each entry is either a metric name or a dict whose key is the metric name and whose value is the metric parameters.
  - validate_index (int) – The index of the metric used for validation. Defaults to 0.
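The sketch below shows how an evaluator might be constructed. It is a minimal, hypothetical example: the metric names "accuracy" and "f1_score", and the assumption that BaseEvaluator can be instantiated directly rather than through a task-specific subclass, are not guaranteed by this reference.

```python
from easygraph.ml_metrics.base import BaseEvaluator

# Hypothetical sketch; the metric names below are assumptions, not a documented registry.
evaluator = BaseEvaluator(
    task="classification",
    metric_configs=[
        "accuracy",                          # a metric given by name only
        {"f1_score": {"average": "micro"}},  # a metric name mapped to its parameters
    ],
    validate_index=0,  # validate() will report the metric at index 0 ("accuracy")
)
```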
- test(y_true: torch.LongTensor, y_pred: torch.Tensor)[source]#
Return the results of the evaluation on all the metrics in metric_configs.
- Parameters:
  - y_true (torch.LongTensor) – The ground truth labels. Size \((N_{samples}, -)\).
  - y_pred (torch.Tensor) – The predicted labels. Size \((N_{samples}, -)\).
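A sketch of a whole-dataset evaluation, reusing the evaluator constructed above. The toy tensors follow the documented \((N_{samples}, -)\) shapes; whether y_pred should contain raw class scores (as here) or hard labels depends on the task and is an assumption.

```python
import torch

y_true = torch.tensor([0, 1, 1, 0])        # ground-truth labels, one per sample
y_pred = torch.tensor([[0.9, 0.1],         # per-class scores, one row per sample (assumed)
                       [0.2, 0.8],
                       [0.4, 0.6],
                       [0.7, 0.3]])

results = evaluator.test(y_true, y_pred)   # evaluates every metric in metric_configs
print(results)
```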
- test_add_batch(batch_y_true: torch.Tensor, batch_y_pred: torch.Tensor)[source]#
Add batch data for testing; a combined batched-evaluation sketch follows test_epoch_res() below.
- Parameters:
  - batch_y_true (torch.Tensor) – The ground truth data. Size \((N_{batch}, -)\).
  - batch_y_pred (torch.Tensor) – The predicted data. Size \((N_{batch}, -)\).
- test_epoch_res()[source]#
For all added batch data, return the results of the evaluation on all the metrics in metric_configs.
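A sketch of batched testing with the same evaluator: test_add_batch() accumulates each mini-batch, and test_epoch_res() then evaluates all metrics in metric_configs over everything that was added. The hand-written batches stand in for a real data loader.

```python
import torch

# Hypothetical mini-batches; in practice these would come from a DataLoader.
batches = [
    (torch.tensor([0, 1]), torch.tensor([[0.8, 0.2], [0.3, 0.7]])),
    (torch.tensor([1, 0]), torch.tensor([[0.1, 0.9], [0.6, 0.4]])),
]

for batch_y_true, batch_y_pred in batches:
    evaluator.test_add_batch(batch_y_true, batch_y_pred)  # accumulate this batch

epoch_results = evaluator.test_epoch_res()  # results over all added batches
print(epoch_results)
```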
- validate(y_true: torch.LongTensor, y_pred: torch.Tensor)[source]#
Return the result of the evaluation on the specified validate_index-th metric.
- Parameters:
  - y_true (torch.LongTensor) – The ground truth labels. Size \((N_{samples}, -)\).
  - y_pred (torch.Tensor) – The predicted labels. Size \((N_{samples}, -)\).
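Since validate() reports only the validate_index-th metric, it is a natural fit for model selection or early stopping. A brief sketch, reusing y_true, y_pred and the evaluator from above; treating the return value as a single score is an assumption.

```python
val_score = evaluator.validate(y_true, y_pred)  # only the validate_index-th metric
print(val_score)  # assumed to be a single score (here "accuracy", since validate_index=0)
```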
- easygraph.ml_metrics.base.format_metric_configs(task: str, metric_configs: List[str | Dict[str, dict]])[source]#
Format metric_configs.
- Parameters:
  - task (str) – The type of the task. The supported types include: classification, retrieval and recommender.
  - metric_configs (Dict[str, Dict[str, Union[str, int]]]) – The metric configs.
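A call sketch for format_metric_configs(). The input mirrors the metric_configs accepted by BaseEvaluator; the metric names are assumptions, and the exact structure of the returned, formatted configuration is library-internal and not described in this reference.

```python
from easygraph.ml_metrics.base import format_metric_configs

raw_configs = ["accuracy", {"f1_score": {"average": "micro"}}]  # assumed metric names
formatted = format_metric_configs("classification", raw_configs)
# The structure of `formatted` is not specified here; it is the formatted
# configuration consumed by the evaluators.
```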