easygraph.ml_metrics.base module
- class easygraph.ml_metrics.base.BaseEvaluator(task: str, metric_configs: List[str | Dict[str, dict]], validate_index: int = 0)
Bases: object
The base class for task-specific metric evaluators.
- Parameters:
task (str) – The type of the task. The supported types include: classification, retrieval and recommender.
metric_configs (List[Union[str, Dict[str, dict]]]) – The metric configurations. Each entry is either a metric name or a dict whose key is the metric name and whose value holds the metric parameters.
validate_index (int) – The index of the metric in metric_configs that is used for validation. Defaults to 0.
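The metric_configs annotation allows two entry forms, and validate_index picks one of them. The short Python sketch below shows both forms and which entry a given validate_index refers to. The metric names `accuracy` and `f1_score`, the `average` parameter, and the helper `metric_name` are hypothetical placeholders, not names confirmed by this page.

```python
from typing import Dict, List, Union

# Hypothetical metric names and parameters; real names depend on the task.
metric_configs: List[Union[str, Dict[str, dict]]] = [
    "accuracy",                          # bare name: default parameters
    {"f1_score": {"average": "macro"}},  # {metric name: metric parameters}
]


def metric_name(entry: Union[str, Dict[str, dict]]) -> str:
    """Hypothetical helper: return the metric name of one config entry."""
    # Assumption: a dict entry carries exactly one key, the metric name.
    return entry if isinstance(entry, str) else next(iter(entry))


validate_index = 1
# validate_index must point at an existing entry of metric_configs.
assert 0 <= validate_index < len(metric_configs)
print(metric_name(metric_configs[validate_index]))  # f1_score
```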
Methods
test(y_true, y_pred) – Return results of the evaluation on all the metrics in metric_configs.
test_add_batch(batch_y_true, batch_y_pred) – Add batch data for testing.
test_epoch_res() – For all added batch data, return results of the evaluation on all the metrics in metric_configs.
validate(y_true, y_pred) – Return the result of the evaluation on the validate_index-th metric.
For all added batch data, return the result of the evaluation on the validate_index-th metric.
validate_add_batch
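The split between validate and test can be sketched with a toy evaluator. `ToyEvaluator`, `accuracy` and `error_rate` are hypothetical and work on plain Python lists instead of tensors. Following the method descriptions above, the sketch assumes that validate computes only the validate_index-th metric, while test computes every configured metric; returning a dict keyed by metric name is a choice made for this sketch.

```python
from typing import Callable, Dict, List, Sequence

Metric = Callable[[Sequence[int], Sequence[int]], float]


def accuracy(y_true: Sequence[int], y_pred: Sequence[int]) -> float:
    # Fraction of positions where the prediction equals the label.
    return sum(t == p for t, p in zip(y_true, y_pred)) / len(y_true)


def error_rate(y_true: Sequence[int], y_pred: Sequence[int]) -> float:
    return 1.0 - accuracy(y_true, y_pred)


class ToyEvaluator:
    """Hypothetical stand-in for a BaseEvaluator subclass (not the EasyGraph API)."""

    def __init__(self, metrics: Dict[str, Metric], validate_index: int = 0):
        self.names: List[str] = list(metrics)
        self.funcs: List[Metric] = list(metrics.values())
        self.validate_index = validate_index

    def validate(self, y_true, y_pred) -> float:
        # Only the validate_index-th metric is computed, e.g. for model selection.
        return self.funcs[self.validate_index](y_true, y_pred)

    def test(self, y_true, y_pred) -> Dict[str, float]:
        # Every configured metric is computed and reported under its name.
        return {n: f(y_true, y_pred) for n, f in zip(self.names, self.funcs)}


ev = ToyEvaluator({"accuracy": accuracy, "error_rate": error_rate}, validate_index=0)
print(ev.validate([0, 1, 1, 0], [0, 1, 0, 0]))  # 0.75
print(ev.test([0, 1, 1, 0], [0, 1, 0, 0]))      # {'accuracy': 0.75, 'error_rate': 0.25}
```

Keeping validation to a single metric makes it cheap enough to run every epoch, while the full test report is reserved for the final evaluation.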
- test(y_true: LongTensor, y_pred: Tensor)
Return results of the evaluation on all the metrics in metric_configs.
- Parameters:
y_true (torch.LongTensor) – The ground truth labels. Size \((N_{samples}, -)\).
y_pred (torch.Tensor) – The predicted labels. Size \((N_{samples}, -)\).
- test_add_batch(batch_y_true: Tensor, batch_y_pred: Tensor)
Add batch data for testing.
- Parameters:
batch_y_true (torch.Tensor) – The ground truth data. Size \((N_{batch}, -)\).
batch_y_pred (torch.Tensor) – The predicted data. Size \((N_{batch}, -)\).
- test_epoch_res()
For all added batch data, return results of the evaluation on all the metrics in metric_configs.
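test_add_batch and test_epoch_res follow an accumulate-then-evaluate pattern. The sketch below is a hypothetical, list-based illustration: `BatchAccumulator` is not an EasyGraph class. It assumes batches are buffered and the metric is computed once over all of them, and that the buffer is cleared afterwards. Both are choices made for this sketch and may differ from how the library aggregates results internally.

```python
from typing import List


def accuracy(y_true: List[int], y_pred: List[int]) -> float:
    return sum(t == p for t, p in zip(y_true, y_pred)) / len(y_true)


class BatchAccumulator:
    """Hypothetical sketch of the *_add_batch / *_epoch_res pattern."""

    def __init__(self) -> None:
        self.y_true: List[int] = []
        self.y_pred: List[int] = []

    def add_batch(self, batch_y_true: List[int], batch_y_pred: List[int]) -> None:
        # Assumption: raw labels and predictions are buffered, not per-batch scores.
        self.y_true.extend(batch_y_true)
        self.y_pred.extend(batch_y_pred)

    def epoch_res(self) -> float:
        if not self.y_true:
            raise ValueError("no batches were added")
        res = accuracy(self.y_true, self.y_pred)
        self.y_true, self.y_pred = [], []  # sketch choice: reset for the next epoch
        return res


acc = BatchAccumulator()
acc.add_batch([0, 1, 1], [0, 1, 0])  # 2 of 3 correct
acc.add_batch([1], [1])              # 1 of 1 correct
print(acc.epoch_res())               # 0.75
```

The aggregation strategy matters when batches differ in size. Averaging the two per-batch accuracies above would give (2/3 + 1)/2 ≈ 0.833, whereas evaluating the buffered samples gives 3/4 = 0.75.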
- easygraph.ml_metrics.base.format_metric_configs(task: str, metric_configs: List[str | Dict[str, dict]])
Format metric_configs.
- Parameters:
task (str) – The type of the task. The supported types include: classification, retrieval and recommender.
metric_configs (List[Union[str, Dict[str, dict]]]) – The metric configurations, in the same form as accepted by BaseEvaluator.
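One possible shape for format_metric_configs is sketched below. The function checks the task, accepts both entry forms, and rejects metrics that are not available for that task. The registry `SUPPORTED`, its metric names, the function name `normalize_metric_configs`, and the `(name, params)` output format are all hypothetical. This page documents only the task and metric_configs inputs, not the return value.

```python
from typing import Dict, List, Set, Tuple, Union

# Hypothetical registry; the real metric names live in EasyGraph's task modules.
SUPPORTED: Dict[str, Set[str]] = {
    "classification": {"accuracy", "f1_score"},
    "retrieval": {"precision", "recall"},
    "recommender": {"precision", "recall", "ndcg"},
}


def normalize_metric_configs(
    task: str, metric_configs: List[Union[str, Dict[str, dict]]]
) -> List[Tuple[str, dict]]:
    """Hypothetical sketch: turn each entry into a validated (name, params) pair."""
    task = task.lower()
    if task not in SUPPORTED:
        raise ValueError(f"Unsupported task: {task}")
    out: List[Tuple[str, dict]] = []
    for entry in metric_configs:
        if isinstance(entry, str):
            name, params = entry, {}
        elif isinstance(entry, dict) and len(entry) == 1:
            (name, params), = entry.items()  # the single {name: params} pair
        else:
            raise ValueError(f"Invalid metric config: {entry!r}")
        if name not in SUPPORTED[task]:
            raise ValueError(f"{name} is not available for task {task}")
        out.append((name, dict(params)))
    return out


print(normalize_metric_configs("classification", ["accuracy", {"f1_score": {"average": "macro"}}]))
# [('accuracy', {}), ('f1_score', {'average': 'macro'})]
```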