Source code for easygraph.nn.regularization

from typing import List

import torch
import torch.nn as nn


__all__ = ["EmbeddingRegularization"]


[docs] class EmbeddingRegularization(nn.Module): r"""Regularization function for embeddings. Parameters: ``p`` (``int``): The power to use in the regularization. Defaults to ``2``. ``weight_decay`` (``float``): The weight of the regularization. Defaults to ``1e-4``. """ def __init__(self, p: int = 2, weight_decay: float = 1e-4): super().__init__() self.p = p self.weight_decay = weight_decay
[docs] def forward(self, *embs: List[torch.Tensor]): r"""The forward function. Parameters: ``embs`` (``List[torch.Tensor]``): The input embeddings. """ loss = 0 for emb in embs: loss += 1 / self.p * emb.pow(self.p).sum(dim=1).mean() return self.weight_decay * loss