easygraph.ml_metrics.classification module#

class easygraph.ml_metrics.classification.VertexClassificationEvaluator(metric_configs: List[str | Dict[str, dict]], validate_index: int = 0)[source]#

Bases: BaseEvaluator

Return the metric evaluator for the vertex classification task. The supported ml_metrics include: accuracy, f1_score, confusion_matrix.

Parameters:
  • metric_configs (List[Union[str, Dict[str, dict]]]) – The metric configurations. Each item is either a metric name or a dictionary whose keys are metric names and whose values are dictionaries of metric parameters.

  • validate_index (int) – The index of the metric in metric_configs used for validation. Defaults to 0.

Methods

  • test(y_true, y_pred) – Return results of the evaluation on all the ml_metrics in metric_configs.

  • test_add_batch(batch_y_true, batch_y_pred) – Add batch data for testing.

  • test_epoch_res() – For all added batch data, return results of the evaluation on all the ml_metrics in metric_configs.

  • validate(y_true, y_pred) – Return the result of the evaluation on the metric specified by validate_index.

  • validate_add_batch(batch_y_true, batch_y_pred) – Add batch data for validation.

  • validate_epoch_res() – For all added batch data, return the result of the evaluation on the metric specified by validate_index.

test(y_true: LongTensor, y_pred: Tensor)[source]#

Return results of the evaluation on all the ml_metrics in metric_configs.

Parameters:
  • y_true (torch.LongTensor) – The ground truth labels. Size \((N_{samples}, )\).

  • y_pred (torch.Tensor) – The predicted labels. Size \((N_{samples}, N_{class})\) or \((N_{samples}, )\).

validate(y_true: LongTensor, y_pred: Tensor)[source]#

Return the result of the evaluation on the metric specified by validate_index.

Parameters:
  • y_true (torch.LongTensor) – The ground truth labels. Size \((N_{samples}, )\).

  • y_pred (torch.Tensor) – The predicted labels. Size \((N_{samples}, N_{class})\) or \((N_{samples}, )\).
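
A minimal usage sketch (not part of the original docstring): it builds an evaluator from a metric_configs list and exercises both the one-shot and the batch-wise evaluation paths. The {"f1_score": {"average": "micro"}} entry is an assumed illustration of the "metric name to metric parameters" form described above, and the tensor values are made up.

>>> import torch
>>> from easygraph.ml_metrics.classification import VertexClassificationEvaluator
>>> evaluator = VertexClassificationEvaluator(
...     metric_configs=["accuracy", {"f1_score": {"average": "micro"}}],
...     validate_index=0,  # validate() reports the 0-th metric, i.e. accuracy
... )
>>> y_true = torch.tensor([3, 2, 4, 0])
>>> y_pred = torch.tensor([
...     [0.2, 0.3, 0.5, 0.4, 0.3],
...     [0.8, 0.2, 0.3, 0.5, 0.4],
...     [0.2, 0.4, 0.5, 0.2, 0.8],
...     [0.8, 0.4, 0.5, 0.2, 0.8],
... ])
>>> val_score = evaluator.validate(y_true, y_pred)       # single value for the 0-th metric
>>> all_scores = evaluator.test(y_true, y_pred)          # results for every configured metric
>>> evaluator.test_add_batch(y_true[:2], y_pred[:2])     # accumulate mini-batches, then
>>> evaluator.test_add_batch(y_true[2:], y_pred[2:])     # query the epoch-level result
>>> epoch_scores = evaluator.test_epoch_res()
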

easygraph.ml_metrics.classification.accuracy(y_true: LongTensor, y_pred: Tensor)[source]#

Calculate the accuracy score for the classification task.

\[\text{Accuracy} = \frac{1}{N} \sum_{i=1}^{N} \mathcal{I}(y_i, \hat{y}_i),\]

where \(\mathcal{I}(\cdot, \cdot)\) is the indicator function, which is 1 if the two inputs are equal, and 0 otherwise. \(y_i\) and \(\hat{y}_i\) are the ground truth and predicted labels for the i-th sample.

Parameters:
  • y_true (torch.LongTensor) – The ground truth labels. Size \((N_{samples}, )\).

  • y_pred (torch.Tensor) – The predicted labels. Size \((N_{samples}, N_{class})\) or \((N_{samples}, )\).

Examples

>>> import torch
>>> import easygraph.ml_metrics as dm
>>> y_true = torch.tensor([3, 2, 4])
>>> y_pred = torch.tensor([
...     [0.2, 0.3, 0.5, 0.4, 0.3],
...     [0.8, 0.2, 0.3, 0.5, 0.4],
...     [0.2, 0.4, 0.5, 0.2, 0.8],
... ])
>>> dm.classification.accuracy(y_true, y_pred)
0.3333333432674408
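
The value above is consistent with reducing the two-dimensional prediction to hard labels by a row-wise argmax before comparing with the ground truth. A minimal sketch of that equivalent computation (illustrative; the argmax reduction is an assumption about how \((N_{samples}, N_{class})\) inputs are handled, consistent with the output shown above):

>>> y_hat = y_pred.argmax(dim=1)   # hard labels from class scores
>>> y_hat
tensor([2, 0, 4])
>>> (y_hat == y_true).float().mean().item()
0.3333333432674408
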
easygraph.ml_metrics.classification.available_classification_metrics()[source]#

Return available ml_metrics for the classification task.

The available ml_metrics are: accuracy, f1_score, confusion_matrix.

easygraph.ml_metrics.classification.confusion_matrix(y_true: LongTensor, y_pred: Tensor)[source]#

Calculate the confusion matrix for the classification task.

Parameters:
  • y_true (torch.LongTensor) – The ground truth labels. Size \((N_{samples}, )\).

  • y_pred (torch.Tensor) – The predicted labels. Size \((N_{samples}, N_{class})\) or \((N_{samples}, )\).

Examples

>>> import torch
>>> import easygraph.ml_metrics as dm
>>> y_true = torch.tensor([3, 2, 4, 0])
>>> y_pred = torch.tensor([
...     [0.2, 0.3, 0.5, 0.4, 0.3],
...     [0.8, 0.2, 0.3, 0.5, 0.4],
...     [0.2, 0.4, 0.5, 0.2, 0.8],
...     [0.8, 0.4, 0.5, 0.2, 0.8],
... ])
>>> dm.classification.confusion_matrix(y_true, y_pred)
array([[1, 0, 0, 0],
       [1, 0, 0, 0],
       [0, 1, 0, 0],
       [0, 0, 0, 1]])
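
The matrix above is consistent with an argmax reduction of the scores followed by a standard confusion-matrix computation over the sorted set of labels that actually occur ({0, 2, 3, 4} here), with rows indexing ground-truth labels and columns indexing predictions. A hedged sketch of an equivalent computation, assuming scikit-learn's conventions:

>>> from sklearn.metrics import confusion_matrix as sk_confusion_matrix
>>> sk_confusion_matrix(y_true.numpy(), y_pred.argmax(dim=1).numpy())
array([[1, 0, 0, 0],
       [1, 0, 0, 0],
       [0, 1, 0, 0],
       [0, 0, 0, 1]])
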
easygraph.ml_metrics.classification.f1_score(y_true: LongTensor, y_pred: Tensor, average: str = 'macro')[source]#

Calculate the F1 score for the classification task.

Parameters:
  • y_true (torch.LongTensor) – The ground truth labels. Size \((N_{samples}, )\).

  • y_pred (torch.Tensor) – The predicted labels. Size \((N_{samples}, N_{class})\) or \((N_{samples}, )\).

  • average (str) – The averaging method. Must be one of “macro”, “micro”, “weighted”. Defaults to “macro”.

Examples

>>> import torch
>>> import easygraph.ml_metrics as dm
>>> y_true = torch.tensor([3, 2, 4, 0])
>>> y_pred = torch.tensor([
...     [0.2, 0.3, 0.5, 0.4, 0.3],
...     [0.8, 0.2, 0.3, 0.5, 0.4],
...     [0.2, 0.4, 0.5, 0.2, 0.8],
...     [0.8, 0.4, 0.5, 0.2, 0.8],
... ])
>>> dm.classification.f1_score(y_true, y_pred, "macro")
0.41666666666666663
>>> dm.classification.f1_score(y_true, y_pred, "micro")
0.5
>>> dm.classification.f1_score(y_true, y_pred, "weighted")
0.41666666666666663
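
As a worked check of these values (not part of the original docstring): the argmax predictions are [2, 0, 4, 0] against ground truth [3, 2, 4, 0], so over the occurring labels {0, 2, 3, 4} the per-class F1 scores are 2/3, 0, 0 and 1. Their unweighted mean gives the macro score 5/12 ≈ 0.4167; because every class has exactly one true sample, the support-weighted (“weighted”) mean coincides with it; and the “micro” score pools all samples, reducing here to the overall accuracy 2/4 = 0.5.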