easygraph.ml_metrics.classification module
- class easygraph.ml_metrics.classification.VertexClassificationEvaluator(metric_configs: List[str | Dict[str, dict]], validate_index: int = 0)
Bases: BaseEvaluator
Metric evaluator for the vertex classification task. The supported metrics are accuracy, f1_score, and confusion_matrix.
- Parameters:
metric_configs (List[Union[str, Dict[str, dict]]]) – The metric configurations. Each entry is either a metric name or a dict whose key is the metric name and whose value holds that metric's parameters.
validate_index (int) – The index of the metric in metric_configs that is used for validation. Defaults to 0.
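The two accepted forms of a metric_configs entry can be illustrated with a plain-Python sketch. The helper normalize_metric_configs is hypothetical and not part of EasyGraph; it assumes that a dict entry has exactly one key (the metric name) and that the accepted names are the three metrics listed above.

```python
from typing import Dict, List, Tuple, Union

# Metric names listed as supported on this page.
SUPPORTED = ("accuracy", "f1_score", "confusion_matrix")

MetricConfig = Union[str, Dict[str, dict]]


def normalize_metric_configs(metric_configs: List[MetricConfig]) -> List[Tuple[str, dict]]:
    """Hypothetical helper: turn each config entry into a (name, params) pair."""
    normalized = []
    for cfg in metric_configs:
        if isinstance(cfg, str):
            # A bare string means "use this metric with its default parameters".
            name, params = cfg, {}
        elif isinstance(cfg, dict) and len(cfg) == 1:
            # Assumption: a one-key dict maps the metric name to its keyword arguments.
            name, params = next(iter(cfg.items()))
        else:
            raise ValueError(f"unrecognized metric config: {cfg!r}")
        if name not in SUPPORTED:
            raise ValueError(f"unsupported metric: {name!r}")
        normalized.append((name, dict(params)))
    return normalized


configs = ["accuracy", {"f1_score": {"average": "micro"}}]
print(normalize_metric_configs(configs))
# [('accuracy', {}), ('f1_score', {'average': 'micro'})]
```

Under these assumptions, with validate_index left at its default of 0, the first entry ("accuracy") would be the metric used for validation.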
Methods
test(y_true, y_pred) – Return results of the evaluation on all the metrics in metric_configs.
test_add_batch(batch_y_true, batch_y_pred) – Add batch data for testing.
test_epoch_res() – For all added batch data, return results of the evaluation on all the metrics in metric_configs.
validate(y_true, y_pred) – Return the result of the evaluation on the validate_index-th metric.
validate_epoch_res() – For all added batch data, return the result of the evaluation on the validate_index-th metric.
validate_add_batch
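The add_batch / epoch_res method pairs suggest an accumulate-then-evaluate pattern. The sketch below shows that pattern with a hypothetical BatchAccumulator class on plain label lists. Whether EasyGraph clears its buffers after returning an epoch result is an assumption here; this page does not say.

```python
from typing import Callable, List, Sequence


class BatchAccumulator:
    """Hypothetical sketch of the *_add_batch / *_epoch_res pattern.

    Labels from each batch are buffered, and the metric runs once over the
    concatenation when the epoch result is requested.
    """

    def __init__(self, metric: Callable[[Sequence[int], Sequence[int]], float]):
        self.metric = metric
        self._y_true: List[int] = []
        self._y_pred: List[int] = []

    def add_batch(self, batch_y_true: Sequence[int], batch_y_pred: Sequence[int]) -> None:
        # Store the batch instead of scoring it immediately.
        self._y_true.extend(batch_y_true)
        self._y_pred.extend(batch_y_pred)

    def epoch_res(self) -> float:
        # Score everything seen so far, then (assumption) clear the buffers.
        res = self.metric(self._y_true, self._y_pred)
        self._y_true.clear()
        self._y_pred.clear()
        return res


def accuracy(y_true: Sequence[int], y_pred: Sequence[int]) -> float:
    return sum(t == p for t, p in zip(y_true, y_pred)) / len(y_true)


acc = BatchAccumulator(accuracy)
acc.add_batch([0, 1], [0, 1])        # 2 of 2 correct
acc.add_batch([1, 1, 1], [0, 0, 1])  # 1 of 3 correct
print(acc.epoch_res())  # 0.6
```

Scoring the concatenated labels once gives 3/5. Averaging the two per-batch accuracies would instead give 2/3, which is one reason to accumulate batches rather than average batch scores.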
- test(y_true: LongTensor, y_pred: Tensor)
Return results of the evaluation on all the metrics in metric_configs.
- Parameters:
y_true (torch.LongTensor) – The ground truth labels. Size \((N_{samples}, )\).
y_pred (torch.Tensor) – The predictions: either per-class scores of size \((N_{samples}, N_{class})\) or predicted labels of size \((N_{samples}, )\).
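Both accepted y_pred shapes can be reduced to one label per sample. The hypothetical helper to_labels below is a plain-Python sketch that assumes a 2-D y_pred holds per-class scores reduced by argmax. The accuracy example on this page is consistent with that assumption. With tensors, the corresponding reduction would presumably be y_pred.argmax(dim=1).

```python
from typing import List, Sequence, Union

Row = Sequence[float]


def to_labels(y_pred: Union[Sequence[int], Sequence[Row]]) -> List[int]:
    """Hypothetical helper: reduce predictions to one label per sample.

    A 2-D input of shape (N_samples, N_class) is treated as per-class scores and
    reduced with argmax; a 1-D input of shape (N_samples,) is taken as labels.
    """
    if len(y_pred) > 0 and isinstance(y_pred[0], (list, tuple)):
        # Argmax per row; Python's max() keeps the first index on ties.
        return [max(range(len(row)), key=row.__getitem__) for row in y_pred]
    return [int(v) for v in y_pred]


scores = [
    [0.2, 0.3, 0.5, 0.4, 0.3],
    [0.8, 0.2, 0.3, 0.5, 0.4],
    [0.2, 0.4, 0.5, 0.2, 0.8],
]
print(to_labels(scores))     # [2, 0, 4]
print(to_labels([2, 0, 4]))  # [2, 0, 4]
```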
- validate(y_true: LongTensor, y_pred: Tensor)
Return the result of the evaluation on the validate_index-th metric in metric_configs.
- Parameters:
y_true (torch.LongTensor) – The ground truth labels. Size \((N_{samples}, )\).
y_pred (torch.Tensor) – The predictions: either per-class scores of size \((N_{samples}, N_{class})\) or predicted labels of size \((N_{samples}, )\).
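The difference between test and validate can be sketched in plain Python. test scores every configured metric, while validate returns a single score from the metric at validate_index, which is convenient for model selection. The functions run_test, run_validate and error_rate are hypothetical. The real methods' return types are not documented on this page.

```python
from typing import Callable, List, Sequence

# A metric maps (y_true, y_pred) label lists to a score.
Metric = Callable[[Sequence[int], Sequence[int]], float]


def accuracy(y_true: Sequence[int], y_pred: Sequence[int]) -> float:
    return sum(t == p for t, p in zip(y_true, y_pred)) / len(y_true)


def error_rate(y_true: Sequence[int], y_pred: Sequence[int]) -> float:
    # Hypothetical second metric, only here so the list has more than one entry.
    return 1.0 - accuracy(y_true, y_pred)


def run_test(metrics: List[Metric], y_true, y_pred) -> List[float]:
    # Mirrors test(): evaluate every configured metric.
    return [metric(y_true, y_pred) for metric in metrics]


def run_validate(metrics: List[Metric], validate_index: int, y_true, y_pred) -> float:
    # Mirrors validate(): evaluate only the metric at validate_index.
    return metrics[validate_index](y_true, y_pred)


metrics = [accuracy, error_rate]
y_true, y_pred = [3, 2, 4], [2, 0, 4]
print([round(s, 4) for s in run_test(metrics, y_true, y_pred)])  # [0.3333, 0.6667]
print(round(run_validate(metrics, 0, y_true, y_pred), 4))        # 0.3333
```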
- easygraph.ml_metrics.classification.accuracy(y_true: LongTensor, y_pred: Tensor)
Calculate the accuracy score for the classification task.
\[\text{Accuracy} = \frac{1}{N} \sum_{i=1}^{N} \mathcal{I}(y_i, \hat{y}_i),\]
where \(N\) is the number of samples, \(\mathcal{I}(\cdot, \cdot)\) is the indicator function, which is 1 if its two inputs are equal and 0 otherwise, and \(y_i\) and \(\hat{y}_i\) are the ground truth and predicted labels of the \(i\)-th sample.
- Parameters:
y_true (torch.LongTensor) – The ground truth labels. Size \((N_{samples}, )\).
y_pred (torch.Tensor) – The predictions: either per-class scores of size \((N_{samples}, N_{class})\) or predicted labels of size \((N_{samples}, )\).
Examples
>>> import torch
>>> import easygraph.ml_metrics as dm
>>> y_true = torch.tensor([3, 2, 4])
>>> y_pred = torch.tensor([
...     [0.2, 0.3, 0.5, 0.4, 0.3],
...     [0.8, 0.2, 0.3, 0.5, 0.4],
...     [0.2, 0.4, 0.5, 0.2, 0.8],
... ])
>>> dm.classification.accuracy(y_true, y_pred)
0.3333333432674408
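The accuracy formula can be checked by hand on the example's data. This plain-Python sketch is an illustration, not the library's implementation. It assumes the 2-D scores are first reduced by argmax, which gives labels [2, 0, 4]. Only the third sample then matches, so the accuracy is 1/3. The library's 0.3333333432674408 is 1/3 rounded to float32, whereas plain Python computes it in float64.

```python
from typing import Sequence


def accuracy(y_true: Sequence[int], y_pred_labels: Sequence[int]) -> float:
    """Accuracy = (1/N) * sum_i I(y_i, y_hat_i), computed on plain label lists."""
    if len(y_true) != len(y_pred_labels):
        raise ValueError("y_true and y_pred_labels must have the same length")
    n = len(y_true)
    # The indicator I(y_i, y_hat_i) is 1 when the labels match and 0 otherwise.
    return sum(1 if y == y_hat else 0 for y, y_hat in zip(y_true, y_pred_labels)) / n


scores = [
    [0.2, 0.3, 0.5, 0.4, 0.3],
    [0.8, 0.2, 0.3, 0.5, 0.4],
    [0.2, 0.4, 0.5, 0.2, 0.8],
]
# Assumption: reduce each score row to its argmax label, giving [2, 0, 4].
labels = [max(range(len(row)), key=row.__getitem__) for row in scores]
print(accuracy([3, 2, 4], labels))  # 0.3333333333333333
```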
- easygraph.ml_metrics.classification.available_classification_metrics()
Return the available metrics for the classification task.
The available metrics are accuracy, f1_score, and confusion_matrix.
- easygraph.ml_metrics.classification.confusion_matrix(y_true: LongTensor, y_pred: Tensor)
Calculate the confusion matrix for the classification task.
- Parameters:
y_true (torch.LongTensor) – The ground truth labels. Size \((N_{samples}, )\).
y_pred (torch.Tensor) – The predictions: either per-class scores of size \((N_{samples}, N_{class})\) or predicted labels of size \((N_{samples}, )\).
Examples
>>> import torch
>>> import easygraph.ml_metrics as dm
>>> y_true = torch.tensor([3, 2, 4, 0])
>>> y_pred = torch.tensor([
...     [0.2, 0.3, 0.5, 0.4, 0.3],
...     [0.8, 0.2, 0.3, 0.5, 0.4],
...     [0.2, 0.4, 0.5, 0.2, 0.8],
...     [0.8, 0.4, 0.5, 0.2, 0.8],
... ])
>>> dm.classification.confusion_matrix(y_true, y_pred)
array([[1, 0, 0, 0],
       [1, 0, 0, 0],
       [0, 1, 0, 0],
       [0, 0, 0, 1]])
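The confusion-matrix example can be reproduced with the plain-Python sketch below. It is an illustration, not EasyGraph's implementation. Two details are inferred from the example output rather than stated on this page.

First, rows are indexed by ground-truth labels and columns by predicted labels, over the sorted set of labels that occur in either input. Here those labels are 0, 2, 3 and 4. Label 1 never appears, so the matrix is 4 × 4. This is also scikit-learn's default convention.

Second, the tie between the two 0.8 scores in the last row resolves to class 0, the first maximal index.

```python
from typing import List, Sequence


def confusion_matrix(y_true: Sequence[int], y_pred_labels: Sequence[int]) -> List[List[int]]:
    # Index rows and columns by the sorted union of labels that actually occur.
    labels = sorted(set(y_true) | set(y_pred_labels))
    index = {label: i for i, label in enumerate(labels)}
    matrix = [[0] * len(labels) for _ in labels]
    for t, p in zip(y_true, y_pred_labels):
        # Row = ground-truth label, column = predicted label.
        matrix[index[t]][index[p]] += 1
    return matrix


scores = [
    [0.2, 0.3, 0.5, 0.4, 0.3],
    [0.8, 0.2, 0.3, 0.5, 0.4],
    [0.2, 0.4, 0.5, 0.2, 0.8],
    [0.8, 0.4, 0.5, 0.2, 0.8],
]
# Argmax with first-index tie-breaking gives [2, 0, 4, 0].
labels = [max(range(len(row)), key=row.__getitem__) for row in scores]
for row in confusion_matrix([3, 2, 4, 0], labels):
    print(row)
# [1, 0, 0, 0]
# [1, 0, 0, 0]
# [0, 1, 0, 0]
# [0, 0, 0, 1]
```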
- easygraph.ml_metrics.classification.f1_score(y_true: LongTensor, y_pred: Tensor, average: str = 'macro')
Calculate the F1 score for the classification task.
- Parameters:
y_true (torch.LongTensor) – The ground truth labels. Size \((N_{samples}, )\).
y_pred (torch.Tensor) – The predictions: either per-class scores of size \((N_{samples}, N_{class})\) or predicted labels of size \((N_{samples}, )\).
average (str) – The averaging method. Must be one of “macro”, “micro”, or “weighted”. Defaults to “macro”.
Examples
>>> import torch
>>> import easygraph.ml_metrics as dm
>>> y_true = torch.tensor([3, 2, 4, 0])
>>> y_pred = torch.tensor([
...     [0.2, 0.3, 0.5, 0.4, 0.3],
...     [0.8, 0.2, 0.3, 0.5, 0.4],
...     [0.2, 0.4, 0.5, 0.2, 0.8],
...     [0.8, 0.4, 0.5, 0.2, 0.8],
... ])
>>> dm.classification.f1_score(y_true, y_pred, "macro")
0.41666666666666663
>>> dm.classification.f1_score(y_true, y_pred, "micro")
0.5
>>> dm.classification.f1_score(y_true, y_pred, "weighted")
0.41666666666666663
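The three averaging modes can be illustrated with a plain-Python sketch that reproduces the example's numbers. It assumes the same argmax reduction, which gives predicted labels [2, 0, 4, 0]. This is an illustration, not the library's implementation.

Per-class F1 is computed as 2TP / (2TP + FP + FN). Class 0 scores 2/3, class 4 scores 1, and classes 2 and 3 score 0, so the macro average is 5/12. Micro averaging pools the counts and, for single-label predictions, equals accuracy (2/4). Weighted averaging weights each class by its support in y_true. Every class here has support 1, so the weighted average coincides with the macro average.

```python
from typing import Sequence


def f1_score(y_true: Sequence[int], y_pred: Sequence[int], average: str = "macro") -> float:
    labels = sorted(set(y_true) | set(y_pred))
    tp = {c: 0 for c in labels}
    fp = {c: 0 for c in labels}
    fn = {c: 0 for c in labels}
    for t, p in zip(y_true, y_pred):
        if t == p:
            tp[t] += 1
        else:
            fp[p] += 1  # predicted p, but it was wrong
            fn[t] += 1  # missed the true class t

    def f1(tp_c: int, fp_c: int, fn_c: int) -> float:
        # F1 = 2TP / (2TP + FP + FN), the harmonic mean of precision and recall.
        denom = 2 * tp_c + fp_c + fn_c
        return 2 * tp_c / denom if denom else 0.0

    if average == "micro":
        # Pool counts over all classes before computing F1.
        return f1(sum(tp.values()), sum(fp.values()), sum(fn.values()))
    per_class = {c: f1(tp[c], fp[c], fn[c]) for c in labels}
    if average == "macro":
        # Unweighted mean of per-class F1.
        return sum(per_class.values()) / len(labels)
    if average == "weighted":
        # Mean of per-class F1 weighted by each class's support in y_true.
        support = {c: sum(1 for t in y_true if t == c) for c in labels}
        return sum(per_class[c] * support[c] for c in labels) / len(y_true)
    raise ValueError('average must be one of "macro", "micro", "weighted"')


scores = [
    [0.2, 0.3, 0.5, 0.4, 0.3],
    [0.8, 0.2, 0.3, 0.5, 0.4],
    [0.2, 0.4, 0.5, 0.2, 0.8],
    [0.8, 0.4, 0.5, 0.2, 0.8],
]
y_pred = [max(range(len(row)), key=row.__getitem__) for row in scores]  # [2, 0, 4, 0]
for avg in ("macro", "micro", "weighted"):
    print(avg, round(f1_score([3, 2, 4, 0], y_pred, avg), 4))
# macro 0.4167
# micro 0.5
# weighted 0.4167
```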