easygraph.experiments.base module

class easygraph.experiments.base.BaseTask(work_root: str | Path | None, data: dict, model_builder: Callable, train_builder: Callable, evaluator: BaseEvaluator, device: device, structure_builder: Callable | None = None, study_name: str | None = None, overwrite: bool = True)

Bases: object

The base class for auto-experiments in EasyGraph, which train models with automatic hyperparameter tuning.

Parameters:
  • work_root (Optional[Union[str, Path]]) – The user’s working root directory, in which all studies are stored.

  • data (dict) – A dictionary holding the input data used in the experiment.

  • model_builder (Callable) – A function that builds a model. It takes a single parameter, trial.

  • train_builder (Callable) – A function that builds the training configuration. It takes two parameters, trial and model.

  • evaluator (eg.ml_metrics.BaseEvaluator) – The EasyGraph evaluator used to measure the model's performance in the experiment.

  • device (torch.device) – The device on which to run the experiment.

  • structure_builder (Optional[Callable]) – A function that builds a structure. It takes a single parameter, trial. The structure can be an eg.Graph, eg.DiGraph, eg.BiGraph, or eg.Hypergraph. Defaults to None.

  • study_name (Optional[str]) – The name of this study. If set to None, a name is generated automatically from the current time. Defaults to None.

  • overwrite (bool) – Whether to overwrite an existing study with the same name. Studies are identified by study_name. Defaults to True.
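The three builder callables are described above only by their parameters. The sketch below shows the expected call shapes under stated assumptions: StubTrial is a hypothetical stand-in for optuna.Trial that implements only suggest_categorical and suggest_float (both real optuna.Trial methods), so the example runs without Optuna, PyTorch, or EasyGraph installed. Plain dicts stand in for the nn.Module, the graph structure, and the training configuration, and the "optimizer"/"criterion" keys returned by train_builder are an assumption chosen to match the optimizer and criterion arguments of train.

```python
import math
import random
from typing import Any, Dict, Sequence


class StubTrial:
    """Hypothetical stand-in for optuna.Trial; records every suggested value."""

    def __init__(self, seed: int = 0) -> None:
        self._rng = random.Random(seed)
        self.params: Dict[str, Any] = {}

    def suggest_categorical(self, name: str, choices: Sequence[Any]) -> Any:
        value = self._rng.choice(list(choices))
        self.params[name] = value
        return value

    def suggest_float(self, name: str, low: float, high: float, *, log: bool = False) -> float:
        if log:
            # Sample uniformly in log space, as log=True does in Optuna.
            value = math.exp(self._rng.uniform(math.log(low), math.log(high)))
        else:
            value = self._rng.uniform(low, high)
        self.params[name] = value
        return value


def structure_builder(trial: StubTrial) -> Dict[str, Any]:
    # Real use: return an eg.Graph, eg.DiGraph, eg.BiGraph or eg.Hypergraph.
    # Hypothetical toy structure: a small edge list with a tunable self-loop flag.
    add_self_loops = trial.suggest_categorical("self_loops", [True, False])
    edges = [(0, 1), (1, 2)]
    if add_self_loops:
        edges += [(v, v) for v in range(3)]
    return {"num_nodes": 3, "edges": edges}


def model_builder(trial: StubTrial) -> Dict[str, Any]:
    # Real use: return a torch.nn.Module; a dict of hyperparameters stands in here.
    return {
        "hid_channels": trial.suggest_categorical("hid_channels", [32, 64, 128]),
        "drop_rate": trial.suggest_float("drop_rate", 0.0, 0.5),
    }


def train_builder(trial: StubTrial, model: Dict[str, Any]) -> Dict[str, Any]:
    # Assumed contract: return what train() needs, i.e. an optimizer and a criterion.
    lr = trial.suggest_float("lr", 1e-4, 1e-2, log=True)
    return {"optimizer": {"name": "Adam", "lr": lr}, "criterion": "cross_entropy"}


trial = StubTrial(seed=42)
structure = structure_builder(trial)
model = model_builder(trial)
train_cfg = train_builder(trial, model)
print(sorted(trial.params))  # ['drop_rate', 'hid_channels', 'lr', 'self_loops']
```

Every hyperparameter a builder draws is recorded on the trial object, which is what lets a tuner associate each trial's score with the configuration that produced it.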

Methods

  • experiment(trial) – Run the experiment for a given trial.

  • run(max_epoch[, num_trials, direction]) – Run experiments with automatic hyperparameter tuning.

  • test([data, model]) – Test the model.

  • train(data, model, optimizer, criterion) – Train the model for one epoch.

  • validate(data, model) – Validate the model.

experiment(trial: Trial)

Run the experiment for a given trial.

Parameters:

trial (optuna.Trial) – The optuna.Trial object.
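The description above does not spell out what a single trial does. A common pattern, and a plausible reading given the builders passed to the constructor, is to build a model for the trial, train it for up to max_epoch epochs, validate after each epoch, and report the best validation score back to the tuner. The stdlib-only sketch below illustrates that loop; experiment_sketch and its callback arguments are hypothetical, and it assumes a higher validation score is better.

```python
import copy
from typing import Callable, Dict, Tuple

Model = Dict[str, float]


def experiment_sketch(
    build_model: Callable[[], Model],
    train_one_epoch: Callable[[Model], None],
    validate: Callable[[Model], float],
    max_epoch: int,
) -> Tuple[float, Model]:
    """Hypothetical per-trial loop: train, validate every epoch, keep the best state."""
    model = build_model()
    best_score = float("-inf")
    best_state = copy.deepcopy(model)
    for _ in range(max_epoch):
        train_one_epoch(model)
        score = validate(model)
        if score > best_score:  # assumes a higher score is better
            best_score, best_state = score, copy.deepcopy(model)
    return best_score, best_state


# Toy usage: "training" nudges a weight toward 1.0 and overshoots after a few epochs.
def nudge(model: Model) -> None:
    model["w"] += model["lr"]


def neg_sq_error(model: Model) -> float:
    return -((model["w"] - 1.0) ** 2)


score, state = experiment_sketch(lambda: {"w": 0.0, "lr": 0.3}, nudge, neg_sq_error, max_epoch=5)
print(round(state["w"], 2), round(score, 2))  # 0.9 -0.01
```

Keeping a copy of the best state, rather than the final one, matters here: the weight overshoots after epoch 3, so the last epoch's model is worse than the best one.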

run(max_epoch: int, num_trials: int = 1, direction: str = 'maximize')

Run experiments with automatic hyperparameter tuning.

Parameters:
  • max_epoch (int) – The maximum number of epochs to train for each experiment.

  • num_trials (int) – The number of trials to run. Defaults to 1.

  • direction (str) – The optimization direction, either "maximize" or "minimize". Defaults to "maximize".
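As documented, run repeats the experiment num_trials times and keeps the trial whose score is best in the chosen direction. The sketch below mirrors that control flow with plain random search in place of Optuna's sampler, so it is a simplification and not EasyGraph's implementation; run_sketch and the lr search space are hypothetical. In a real Optuna workflow, optuna.create_study(direction=...) and Study.optimize(objective, n_trials=...) play this role.

```python
import random
from typing import Callable, Dict, List, Tuple

Params = Dict[str, float]


def run_sketch(
    objective: Callable[[Params], float],
    num_trials: int = 1,
    direction: str = "maximize",
    seed: int = 0,
) -> Tuple[float, Params]:
    """Hypothetical random-search tuner returning the best (score, params) pair."""
    if direction not in ("maximize", "minimize"):
        raise ValueError(f"direction must be 'maximize' or 'minimize', got {direction!r}")
    rng = random.Random(seed)
    results: List[Tuple[float, Params]] = []
    for _ in range(num_trials):
        # Sample a learning rate log-uniformly in [1e-4, 1e-2].
        params = {"lr": 10 ** rng.uniform(-4, -2)}
        results.append((objective(params), params))
    pick = max if direction == "maximize" else min
    return pick(results, key=lambda r: r[0])


# Toy objective: the score peaks at lr = 1e-3.
def objective(params: Params) -> float:
    return -abs(params["lr"] - 1e-3)


best_score, best_params = run_sketch(objective, num_trials=20, direction="maximize")
print(best_params)
```

Use "maximize" for accuracy-like scores and "minimize" for loss-like ones; with the same seed, both directions see the same sampled trials and differ only in which one they keep.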

abstract test(data: dict | None = None, model: Module | None = None)

Test the model.

Parameters:
  • data (dict, optional) – The input data. If set to None, the data specified when the experiment was initialized is used. Defaults to None.

  • model (nn.Module, optional) – The model. If set to None, the best model found during training is used. Defaults to None.

abstract train(data: dict, model: Module, optimizer: Optimizer, criterion: Module)

Train the model for one epoch.

Parameters:
  • data (dict) – The input data.

  • model (nn.Module) – The model.

  • optimizer (torch.optim.Optimizer) – The model optimizer.

  • criterion (nn.Module) – The loss function.

abstract validate(data: dict, model: Module)

Validate the model.

Parameters:
  • data (dict) – The input data.

  • model (nn.Module) – The model.
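Because test, train and validate are abstract, a concrete task must subclass BaseTask and implement all three. The sketch below shows that pattern without PyTorch or EasyGraph installed. AbstractTaskStub is a hypothetical stand-in that declares the same three abstract methods (it is not the real base class), ToyTask fits a one-parameter linear model with plain gradient descent, and the "x"/"y" data keys and the "lr" optimizer field are assumptions. In a real subclass, model would be an nn.Module, optimizer a torch.optim.Optimizer, and criterion a loss module.

```python
from abc import ABC, abstractmethod
from typing import Dict, List, Optional


class AbstractTaskStub(ABC):
    """Hypothetical stand-in declaring the same abstract methods as BaseTask."""

    def __init__(self, data: Dict[str, List[float]]) -> None:
        self.data = data
        self.best_model: Optional[Dict[str, float]] = None

    @abstractmethod
    def train(self, data, model, optimizer, criterion): ...

    @abstractmethod
    def validate(self, data, model): ...

    @abstractmethod
    def test(self, data=None, model=None): ...


def mse(pred: List[float], target: List[float]) -> float:
    return sum((p - t) ** 2 for p, t in zip(pred, target)) / len(target)


class ToyTask(AbstractTaskStub):
    def train(self, data, model, optimizer, criterion) -> float:
        # One epoch = one full-batch gradient step on the model y = w * x.
        xs, ys = data["x"], data["y"]
        grad = sum(2 * (model["w"] * x - y) * x for x, y in zip(xs, ys)) / len(xs)
        model["w"] -= optimizer["lr"] * grad
        return criterion([model["w"] * x for x in xs], ys)

    def validate(self, data, model) -> float:
        # Evaluation only: the model is not modified here.
        return -mse([model["w"] * x for x in data["x"]], data["y"])

    def test(self, data=None, model=None) -> float:
        # Fall back to the stored data and the best model, as documented for test().
        data = self.data if data is None else data
        model = self.best_model if model is None else model
        return self.validate(data, model)


task = ToyTask({"x": [1.0, 2.0, 3.0], "y": [2.0, 4.0, 6.0]})
model = {"w": 0.0}
for _ in range(50):
    task.train(task.data, model, {"lr": 0.05}, mse)
task.best_model = model
print(round(model["w"], 3))  # converges toward 2.0
```

Calling task.test() with no arguments exercises the documented defaults: the data given at construction time and the stored best model.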