easygraph.nn.regularization module

class easygraph.nn.regularization.EmbeddingRegularization(*args: Any, **kwargs: Any)

Bases: Module

Regularization function for embeddings.

Parameters:
  • p (int) – The power to use in the regularization. Defaults to 2.

  • weight_decay (float) – The weight of the regularization. Defaults to 1e-4.

forward(*embs: List[torch.Tensor])

The forward function. Computes the regularization term for the given embeddings.

Parameters:
  • embs (List[torch.Tensor]) – The input embeddings. The signature accepts any number of them as positional arguments.
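
The page does not give the formula behind the regularization term, so the following is only a minimal sketch of what a module with these parameters might compute. The class name SketchEmbeddingRegularization and the formula (row-wise sum of p-th powers, averaged over rows, scaled by 1/p and by weight_decay) are assumptions made for illustration, not EasyGraph's confirmed implementation. Only the parameter names p and weight_decay and the variadic forward(*embs) signature come from the documentation above.

```python
import torch
from torch import nn


class SketchEmbeddingRegularization(nn.Module):
    """Hypothetical stand-in for illustration; not EasyGraph's own code."""

    def __init__(self, p: int = 2, weight_decay: float = 1e-4):
        super().__init__()
        self.p = p
        self.weight_decay = weight_decay

    def forward(self, *embs: torch.Tensor) -> torch.Tensor:
        # Assumed formula: for each embedding matrix, sum the p-th powers of
        # the entries in every row, average over rows, and scale by 1/p.
        loss = torch.zeros(())
        for emb in embs:
            loss = loss + emb.pow(self.p).sum(dim=1).mean() / self.p
        # Scale the accumulated penalty by the regularization weight.
        return self.weight_decay * loss


# Two toy embedding tables, passed as separate positional arguments.
user_emb = torch.ones(4, 3)
item_emb = torch.full((5, 3), 2.0)

reg = SketchEmbeddingRegularization(p=2, weight_decay=1e-4)
penalty = reg(user_emb, item_emb)
```

In a training loop, a term like penalty would typically be added to the task loss before calling backward(), so that large embedding values are discouraged in proportion to weight_decay.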