easygraph.experiments.base module
- class easygraph.experiments.base.BaseTask(work_root: str | Path | None, data: dict, model_builder: Callable, train_builder: Callable, evaluator: BaseEvaluator, device: torch.device, structure_builder: Callable | None = None, study_name: str | None = None, overwrite: bool = True)
Bases: object

The base class for auto-experiments in EasyGraph.
- Parameters:
  - work_root (Optional[Union[str, Path]]) – The user's work root, in which all studies are stored.
  - data (dict) – The dictionary that stores the input data used in the experiment.
  - model_builder (Callable) – The function that builds a model. It takes a single fixed parameter, trial.
  - train_builder (Callable) – The function that builds a training configuration. It takes two fixed parameters, trial and model.
  - evaluator (eg.ml_metrics.BaseEvaluator) – The EasyGraph evaluator object used to evaluate the model's performance in the experiment.
  - device (torch.device) – The target device on which to run the experiment.
  - structure_builder (Optional[Callable]) – The function that builds a structure. It takes a single fixed parameter, trial. The structure can be an eg.Graph, eg.DiGraph, eg.BiGraph, or eg.Hypergraph. Defaults to None.
  - study_name (Optional[str]) – The name of this study. If set to None, the study name is generated automatically from the current time. Defaults to None.
  - overwrite (bool) – Whether to overwrite an existing study. Studies are identified by study_name. Defaults to True.
- experiment(trial: optuna.Trial)
Run the experiment for a given trial.
- Parameters:
  - trial (optuna.Trial) – The optuna.Trial object.
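The work done for a single trial can be pictured as follows. This is a minimal sketch inferred from the constructor parameters and the train and validate hooks on this page, not from EasyGraph's source: build the optional structure, the model and the training configuration from the trial, then alternate train and validate for max_epoch epochs and keep the best validation score. The helper run_one_trial and the "optimizer"/"criterion" keys are hypothetical, and the real method may additionally report progress to Optuna, move objects to the device, or save checkpoints.

```python
from typing import Any, Callable, Dict, Optional


def run_one_trial(
    trial: Any,
    data: Dict[str, Any],
    model_builder: Callable[[Any], Any],
    train_builder: Callable[[Any, Any], Dict[str, Any]],
    train: Callable[[Dict[str, Any], Any, Any, Any], None],
    validate: Callable[[Dict[str, Any], Any], float],
    max_epoch: int,
    structure_builder: Optional[Callable[[Any], Any]] = None,
) -> float:
    """Hypothetical per-trial loop; returns the best validation score seen."""
    data = dict(data)  # shallow copy so the caller's dict is left untouched
    if structure_builder is not None:
        # Assumption: the built structure is stored alongside the input data.
        data["structure"] = structure_builder(trial)
    model = model_builder(trial)
    config = train_builder(trial, model)  # assumed keys: "optimizer", "criterion"
    best = float("-inf")  # assumes higher validation scores are better
    for _ in range(max_epoch):
        train(data, model, config["optimizer"], config["criterion"])
        best = max(best, validate(data, model))
    return best


# Stand-in callables: validation scores are replayed from a fixed list.
scores = iter([0.5, 0.7, 0.6])
train_calls = []
best = run_one_trial(
    trial=None,
    data={"features": [1.0, 2.0]},
    model_builder=lambda trial: "model",
    train_builder=lambda trial, model: {"optimizer": "opt", "criterion": "loss"},
    train=lambda data, model, opt, crit: train_calls.append((model, opt, crit)),
    validate=lambda data, model: next(scores),
    max_epoch=3,
)
print(best)  # 0.7
```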
- run(max_epoch: int, num_trials: int = 1, direction: str = 'maximize')
Run experiments with automatic hyper-parameter tuning.
- Parameters:
  - max_epoch (int) – The maximum number of epochs to train in each experiment.
  - num_trials (int) – The number of trials to run. Defaults to 1.
  - direction (str) – The direction in which to optimize. Defaults to "maximize".
- test(data: dict | None = None, model: torch.nn.Module | None = None)
Test the model.
- Parameters:
  - data (dict, optional) – The input data. If set to None, the data specified when the experiment was initialized is used. Defaults to None.
  - model (nn.Module, optional) – The model. If set to None, the best trained model is used. Defaults to None.
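The defaulting rule for test amounts to an explicit None check on each argument. The sketch below restates it with a hypothetical resolve_test_inputs helper; default_data and best_model stand for the data passed at construction and the best model found by run, and are not attribute names confirmed by this page.

```python
from typing import Any, Dict, Optional, Tuple


def resolve_test_inputs(
    data: Optional[Dict[str, Any]],
    model: Optional[Any],
    default_data: Dict[str, Any],
    best_model: Any,
) -> Tuple[Dict[str, Any], Any]:
    """Hypothetical helper: fall back to the stored data and the best model."""
    # Compare with None explicitly so an empty dict is still honoured.
    resolved_data = default_data if data is None else data
    resolved_model = best_model if model is None else model
    return resolved_data, resolved_model


print(resolve_test_inputs(None, None, {"split": "init"}, "best"))
print(resolve_test_inputs({"split": "new"}, "mine", {"split": "init"}, "best"))
```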
- abstract train(data: dict, model: torch.nn.Module, optimizer: torch.optim.Optimizer, criterion: torch.nn.Module)
Train the model for one epoch.
- Parameters:
  - data (dict) – The input data.
  - model (nn.Module) – The model.
  - optimizer (torch.optim.Optimizer) – The model's optimizer.
  - criterion (nn.Module) – The loss function.
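Because train is abstract, a concrete task must supply it. Below is a sketch of a typical full-batch epoch for node classification; the features, labels and train_mask keys of data are hypothetical, since this page does not say what the data dictionary contains, and returning the loss is a convenience of this sketch rather than documented behaviour. In a real subclass the body would go in train(self, data, model, optimizer, criterion).

```python
import torch
import torch.nn as nn


def train_one_epoch(data: dict, model: nn.Module,
                    optimizer: torch.optim.Optimizer, criterion: nn.Module) -> float:
    """One full-batch epoch; the keys of `data` are assumed, not prescribed."""
    model.train()
    optimizer.zero_grad()
    logits = model(data["features"])
    mask = data["train_mask"]
    # Compute the loss only on training nodes, then take one optimizer step.
    loss = criterion(logits[mask], data["labels"][mask])
    loss.backward()
    optimizer.step()
    return loss.item()


torch.manual_seed(0)
data = {
    "features": torch.randn(10, 4),
    "labels": torch.randint(0, 3, (10,)),
    "train_mask": torch.arange(10) < 6,  # first 6 nodes are training nodes
}
model = nn.Linear(4, 3)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
criterion = nn.CrossEntropyLoss()
losses = [train_one_epoch(data, model, optimizer, criterion) for _ in range(20)]
print(losses[0], losses[-1])
```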
- validate(data: dict, model: torch.nn.Module)
Validate the model.
- Parameters:
  - data (dict) – The input data.
  - model (nn.Module) – The model.
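A matching sketch for validate, again assuming hypothetical features, labels and val_mask keys and returning accuracy as a plain float. The real method may instead delegate scoring to the evaluator passed to the constructor; this page does not document that detail.

```python
import torch
import torch.nn as nn


@torch.no_grad()
def validate(data: dict, model: nn.Module) -> float:
    """Accuracy on the validation nodes; the keys of `data` are assumed."""
    model.eval()
    logits = model(data["features"])
    mask = data["val_mask"]
    preds = logits[mask].argmax(dim=-1)
    return (preds == data["labels"][mask]).float().mean().item()


# A fixed identity model makes the expected accuracy easy to check by hand.
model = nn.Linear(2, 2, bias=False)
with torch.no_grad():
    model.weight.copy_(torch.eye(2))  # logits equal the input features
data = {
    "features": torch.tensor([[1.0, 0.0], [0.0, 1.0], [1.0, 0.0], [0.0, 1.0]]),
    "labels": torch.tensor([0, 1, 1, 1]),
    "val_mask": torch.tensor([True, True, True, False]),
}
print(validate(data, model))  # about 0.667: 2 of the 3 validation nodes are correct
```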