easygraph.ml_metrics.base module

class easygraph.ml_metrics.base.BaseEvaluator(task: str, metric_configs: List[str | Dict[str, dict]], validate_index: int = 0)

Bases: object

The base class for task-specific metric evaluators.

Parameters:
  • task (str) – The type of the task. Supported types are classification, retrieval, and recommender.

  • metric_configs (List[Union[str, Dict[str, dict]]]) – The metric configurations. Each entry is either a metric name (str) or a dict whose key is the metric name and whose value is a dict of that metric's parameters.

  • validate_index (int) – The index, within metric_configs, of the metric used for validation. Defaults to 0.
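To make the two accepted entry forms and the role of validate_index concrete, here is a small pure-Python sketch. The metric names accuracy and f1_score, the average parameter, and the helpers split_entry and validation_metric are assumptions for illustration, not part of easygraph.

```python
from typing import Dict, List, Tuple, Union

# An entry is a bare metric name or a one-key dict {name: params}.
# Metric names and parameters used below are illustrative assumptions.
MetricConfig = Union[str, Dict[str, dict]]


def split_entry(entry: MetricConfig) -> Tuple[str, dict]:
    """Hypothetical helper: return (metric_name, params) for one entry."""
    if isinstance(entry, str):
        return entry, {}
    if isinstance(entry, dict) and len(entry) == 1:
        name, params = next(iter(entry.items()))
        return name, dict(params)
    raise ValueError(f"Unsupported metric config entry: {entry!r}")


def validation_metric(configs: List[MetricConfig], validate_index: int = 0) -> Tuple[str, dict]:
    """Hypothetical helper: the entry that validate_index points at."""
    return split_entry(configs[validate_index])


configs = ["accuracy", {"f1_score": {"average": "macro"}}]
print(validation_metric(configs))     # ('accuracy', {})
print(validation_metric(configs, 1))  # ('f1_score', {'average': 'macro'})
```

With the default validate_index of 0, the first entry drives validation while every entry is still used for testing.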

Methods

test(y_true, y_pred)

Return results of the evaluation on all the ml_metrics in metric_configs.

test_add_batch(batch_y_true, batch_y_pred)

Add batch data for testing.

test_epoch_res()

For all added batch data, return results of the evaluation on all the ml_metrics in metric_configs.

validate(y_true, y_pred)

Return the result of the evaluation on the specified validate_index-th metric.

validate_epoch_res()

For all added batch data, return the result of the evaluation on the specified validate_index-th metric.

validate_add_batch(batch_y_true, batch_y_pred)

Add batch data for validation.

test(y_true: LongTensor, y_pred: Tensor)

Return results of the evaluation on all the ml_metrics in metric_configs.

Parameters:
  • y_true (torch.LongTensor) – The ground truth labels. Size \((N_{samples}, -)\).

  • y_pred (torch.Tensor) – The predicted labels. Size \((N_{samples}, -)\).
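As a rough illustration of returning results on all configured metrics, the pure-Python sketch below evaluates every metric on the same labels and returns a dict keyed by metric name. It uses plain lists instead of torch tensors, and evaluate_all, error_rate, and the METRICS registry are hypothetical; the actual return type of test() may differ.

```python
from typing import Callable, Dict, List, Sequence


def accuracy(y_true: Sequence[int], y_pred: Sequence[int]) -> float:
    """Fraction of positions where the predicted label equals the true label."""
    if len(y_true) != len(y_pred) or not y_true:
        raise ValueError("y_true and y_pred must be non-empty and equally long")
    return sum(t == p for t, p in zip(y_true, y_pred)) / len(y_true)


def error_rate(y_true: Sequence[int], y_pred: Sequence[int]) -> float:
    """Hypothetical second metric: complement of accuracy."""
    return 1.0 - accuracy(y_true, y_pred)


# Hypothetical name -> function registry standing in for the library's metrics.
METRICS: Dict[str, Callable[[Sequence[int], Sequence[int]], float]] = {
    "accuracy": accuracy,
    "error_rate": error_rate,
}


def evaluate_all(metric_names: List[str], y_true, y_pred) -> Dict[str, float]:
    """Evaluate every configured metric on the full set of labels."""
    return {name: METRICS[name](y_true, y_pred) for name in metric_names}


print(evaluate_all(["accuracy", "error_rate"], [0, 1, 2, 1], [0, 1, 1, 1]))
# {'accuracy': 0.75, 'error_rate': 0.25}
```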

test_add_batch(batch_y_true: Tensor, batch_y_pred: Tensor)

Add batch data for testing.

Parameters:
  • batch_y_true (torch.Tensor) – The ground truth data. Size \((N_{batch}, -)\).

  • batch_y_pred (torch.Tensor) – The predicted data. Size \((N_{batch}, -)\).

test_epoch_res()

For all added batch data, return results of the evaluation on all the ml_metrics in metric_configs.
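The add-batch / epoch-result pair can be pictured as buffering data and evaluating once per epoch. The sketch below pools raw labels before computing the metric; this strategy, the reset after each epoch, and the BatchAccumulator class are assumptions, and the library may aggregate differently (for example, by averaging per-sample scores).

```python
from typing import Callable, List, Sequence


def accuracy(y_true: Sequence[int], y_pred: Sequence[int]) -> float:
    """Fraction of matching labels (hypothetical stand-in for a real metric)."""
    return sum(t == p for t, p in zip(y_true, y_pred)) / len(y_true)


class BatchAccumulator:
    """Hypothetical helper: buffer per-batch labels, evaluate once per epoch."""

    def __init__(self, metric: Callable[[Sequence[int], Sequence[int]], float]):
        self.metric = metric
        self._y_true: List[int] = []
        self._y_pred: List[int] = []

    def add_batch(self, batch_y_true: Sequence[int], batch_y_pred: Sequence[int]) -> None:
        """Append one batch; nothing is evaluated yet."""
        if len(batch_y_true) != len(batch_y_pred):
            raise ValueError("batch_y_true and batch_y_pred differ in length")
        self._y_true.extend(batch_y_true)
        self._y_pred.extend(batch_y_pred)

    def epoch_res(self) -> float:
        """Evaluate on everything added since the last call, then reset."""
        if not self._y_true:
            raise RuntimeError("no batch data was added")
        result = self.metric(self._y_true, self._y_pred)
        # Resetting here is an assumption about the intended epoch semantics.
        self._y_true.clear()
        self._y_pred.clear()
        return result


acc = BatchAccumulator(accuracy)
acc.add_batch([0, 1, 1], [0, 1, 1])  # batch accuracy 1.0
acc.add_batch([1], [0])              # batch accuracy 0.0
print(acc.epoch_res())               # 0.75 (3 of 4 samples correct)
```

Averaging the two per-batch accuracies would give 0.5 here; pooling weights every sample equally, which matters whenever batch sizes differ.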

validate(y_true: LongTensor, y_pred: Tensor)

Return the result of the evaluation on the specified validate_index-th metric.

Parameters:
  • y_true (torch.LongTensor) – The ground truth labels. Size \((N_{samples}, -)\).

  • y_pred (torch.Tensor) – The predicted labels. Size \((N_{samples}, -)\).
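Because validate() reduces evaluation to the single metric at validate_index, it fits naturally into model selection. The sketch below assumes each validation result is one float and picks the best epoch from a list of such scores; select_best_epoch and its higher_is_better flag are hypothetical.

```python
from typing import List, Tuple


def select_best_epoch(val_scores: List[float], higher_is_better: bool = True) -> Tuple[int, float]:
    """Return (epoch_index, score) of the best validation score.

    Each score stands for one validate(...) result. Ties keep the earliest epoch.
    """
    if not val_scores:
        raise ValueError("val_scores is empty")
    best_epoch, best_score = 0, val_scores[0]
    for epoch, score in enumerate(val_scores[1:], start=1):
        improved = score > best_score if higher_is_better else score < best_score
        if improved:
            best_epoch, best_score = epoch, score
    return best_epoch, best_score


print(select_best_epoch([0.61, 0.72, 0.70, 0.72]))  # (1, 0.72)
```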

validate_add_batch(batch_y_true: Tensor, batch_y_pred: Tensor)

Add batch data for validation.

validate_epoch_res()

For all added batch data, return the result of the evaluation on the specified validate_index-th metric.

easygraph.ml_metrics.base.format_metric_configs(task: str, metric_configs: List[str | Dict[str, dict]])

Format metric_configs.

Parameters:
  • task (str) – The type of the task. Supported types are classification, retrieval, and recommender.

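  • metric_configs (List[Union[str, Dict[str, dict]]]) – The metric configurations, in the same format as for BaseEvaluator: each entry is either a metric name or a dict mapping a metric name to its parameters.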
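One plausible shape for a function like format_metric_configs is sketched below: dispatch on task, reject unknown tasks or metrics, and bind each entry's parameters with functools.partial. This is an assumption, not the library's implementation; format_configs, REGISTRY, and the metric functions accuracy and precision_at_k are hypothetical stand-ins.

```python
from functools import partial
from typing import Callable, Dict, List, Tuple, Union


def accuracy(y_true, y_pred) -> float:
    """Hypothetical classification metric: fraction of matching labels."""
    return sum(t == p for t, p in zip(y_true, y_pred)) / len(y_true)


def precision_at_k(y_true, y_pred, k: int = 1) -> float:
    """Hypothetical ranking metric: share of the top-k items that are relevant."""
    return sum(item in y_true for item in y_pred[:k]) / k


# Hypothetical per-task registries; real module layout and names may differ.
REGISTRY: Dict[str, Dict[str, Callable]] = {
    "classification": {"accuracy": accuracy},
    "retrieval": {"precision": precision_at_k},
    "recommender": {"precision": precision_at_k},
}


def format_configs(task: str, metric_configs: List[Union[str, Dict[str, dict]]]) -> List[Tuple[str, Callable]]:
    """Turn config entries into (name, callable) pairs with parameters bound."""
    if task not in REGISTRY:
        raise ValueError(f"Unsupported task: {task!r}")
    metrics = REGISTRY[task]
    formatted = []
    for entry in metric_configs:
        if isinstance(entry, str):
            name, params = entry, {}
        elif isinstance(entry, dict) and len(entry) == 1:
            name, params = next(iter(entry.items()))
        else:
            raise ValueError(f"Unsupported metric config entry: {entry!r}")
        if name not in metrics:
            raise ValueError(f"Unknown metric {name!r} for task {task!r}")
        formatted.append((name, partial(metrics[name], **params)))
    return formatted


fmts = format_configs("retrieval", ["precision", {"precision": {"k": 2}}])
print([(name, f({1, 3}, [3, 2, 1])) for name, f in fmts])
# [('precision', 1.0), ('precision', 0.5)]
```

Binding parameters once means each evaluation call only needs y_true and y_pred, whatever the metric.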