Source code for easygraph.experiments.vertex_classification

from pathlib import Path
from typing import Callable
from typing import Optional
from typing import Union

import optuna
import torch
import torch.nn as nn

from easygraph.ml_metrics import BaseEvaluator

from .base import BaseTask


class VertexClassificationTask(BaseTask):
    r"""The auto-experiment class for the vertex classification task.

    Args:
        ``work_root`` (``Optional[Union[str, Path]]``): User's work root to store all studies.
        ``data`` (``dict``): The dictionary holding the input data used in the experiment.
        ``model_builder`` (``Callable``): The function to build a model with a fixed parameter ``trial``.
        ``train_builder`` (``Callable``): The function to build a training configuration with two fixed parameters ``trial`` and ``model``.
        ``evaluator`` (``eg.ml_metrics.BaseEvaluator``): The evaluator object used to evaluate the performance of the model in the experiment.
        ``device`` (``torch.device``): The target device to run the experiment.
        ``structure_builder`` (``Optional[Callable]``): The function to build a structure with a fixed parameter ``trial``. The structure can be ``eg.Hypergraph``.
        ``study_name`` (``Optional[str]``): The name of this study. If set to ``None``, the study name will be generated automatically according to current time. Defaults to ``None``.
        ``overwrite`` (``bool``): The flag that whether to overwrite the existing study. Different studies are identified by the ``study_name``. Defaults to ``True``.
    """

    def __init__(
        self,
        work_root: Optional[Union[str, Path]],
        data: dict,
        model_builder: Callable,
        train_builder: Callable,
        evaluator: BaseEvaluator,
        device: torch.device,
        structure_builder: Optional[Callable] = None,
        study_name: Optional[str] = None,
        overwrite: bool = True,
    ):
        # All bookkeeping is delegated to the base task; the first six
        # arguments are forwarded positionally, the rest by keyword.
        super().__init__(
            work_root,
            data,
            model_builder,
            train_builder,
            evaluator,
            device,
            structure_builder=structure_builder,
            study_name=study_name,
            overwrite=overwrite,
        )
        # Push the known data entries onto the device right away.
        # NOTE(review): relies on the base initializer having set
        # ``self.device`` and ``self.data`` -- confirm against ``BaseTask``.
        self.to(self.device)

    def to(self, device: torch.device):
        r"""Move the input data to the target device.

        Only the entries of ``self.data`` whose key is listed in ``vars_for_DL`` are moved;
        names that are absent from ``self.data`` are skipped.

        Args:
            ``device`` (``torch.device``): The specified target device to store the input data.

        Returns:
            The task itself, so that calls can be chained.
        """
        self.device = device
        for key in self.vars_for_DL:
            if key not in self.data:
                continue
            self.data[key] = self.data[key].to(device)
        return self

    @property
    def vars_for_DL(self):
        r"""Return the names of the variables available for deep learning in the vertex classification task.

        The names are ``features``, ``structure``, ``labels``, ``train_mask``, ``val_mask``, and ``test_mask``.
        """
        return ("features", "structure", "labels", "train_mask", "val_mask", "test_mask")

    def experiment(self, trial: optuna.Trial):
        r"""Run the experiment for a given trial.

        This simply forwards to the implementation of the base task.

        Args:
            ``trial`` (``optuna.Trial``): The ``optuna.Trial`` object.

        Returns:
            Whatever the base-class ``experiment`` returns for this trial.
        """
        return super().experiment(trial)

    def run(self, max_epoch: int, num_trials: int = 1, direction: str = "maximize"):
        r"""Run experiments with automatic hyper-parameter tuning.

        This simply forwards to the implementation of the base task.

        Args:
            ``max_epoch`` (``int``): The maximum number of epochs to train for each experiment.
            ``num_trials`` (``int``): The number of trials to run. Defaults to ``1``.
            ``direction`` (``str``): The direction to optimize. Defaults to ``"maximize"``.

        Returns:
            Whatever the base-class ``run`` returns.
        """
        return super().run(max_epoch, num_trials, direction)

    def train(
        self,
        data: dict,
        model: nn.Module,
        optimizer: torch.optim.Optimizer,
        criterion: nn.Module,
    ):
        r"""Train model for one epoch.

        One epoch is a single optimization step: the model is evaluated on all
        ``features``, and the loss is computed only on the entries selected by ``train_mask``.

        Args:
            ``data`` (``dict``): The input data. The keys ``features``, ``structure``, ``train_mask`` and ``labels`` are read.
            ``model`` (``nn.Module``): The model.
            ``optimizer`` (``torch.optim.Optimizer``): The model optimizer.
            ``criterion`` (``nn.Module``): The loss function.
        """
        features = data["features"]
        structure = data["structure"]
        train_mask = data["train_mask"]
        labels = data["labels"]

        model.train()
        optimizer.zero_grad()
        predictions = model(features, structure)
        # Only the training entries contribute to the loss.
        loss = criterion(predictions[train_mask], labels[train_mask])
        loss.backward()
        optimizer.step()

    @torch.no_grad()
    def validate(self, data: dict, model: nn.Module):
        r"""Validate the model.

        Args:
            ``data`` (``dict``): The input data. The keys ``features``, ``structure``, ``val_mask`` and ``labels`` are read.
            ``model`` (``nn.Module``): The model.

        Returns:
            The result of ``self.evaluator.validate`` on the labels and outputs selected by ``val_mask``.
        """
        features = data["features"]
        structure = data["structure"]
        val_mask = data["val_mask"]
        labels = data["labels"]

        model.eval()
        predictions = model(features, structure)
        return self.evaluator.validate(labels[val_mask], predictions[val_mask])

    @torch.no_grad()
    def test(self, data: Optional[dict] = None, model: Optional[nn.Module] = None):
        r"""Test the model.

        Args:
            ``data`` (``dict``, optional): The input data. If set to ``None``, the ``data`` specified in the initialization of the experiment is used together with ``self.best_structure``. Defaults to ``None``.
            ``model`` (``nn.Module``, optional): The model. If set to ``None``, ``self.best_model`` is used. Defaults to ``None``.

        Returns:
            The result of ``self.evaluator.test`` on the labels and outputs selected by ``test_mask``.
        """
        if data is None:
            # Stored data is used as-is (no device transfer here), paired
            # with the structure kept in ``self.best_structure``.
            features = self.data["features"]
            structure = self.best_structure
            test_mask = self.data["test_mask"]
            labels = self.data["labels"]
        else:
            # Caller-supplied data is moved onto the task's device first.
            features = data["features"].to(self.device)
            structure = data["structure"].to(self.device)
            test_mask = data["test_mask"].to(self.device)
            labels = data["labels"].to(self.device)

        if model is None:
            model = self.best_model
        model = model.to(self.device)
        model.eval()

        predictions = model(features, structure)
        return self.evaluator.test(labels[test_mask], predictions[test_mask])