class EmbeddingRegularization(nn.Module):
    r"""Regularization function for embeddings.

    For every embedding tensor given to :meth:`forward` the penalty is
    ``1 / p * mean(sum(|emb| ** p, dim=1))``. The penalties of all tensors
    are added together and the total is scaled by ``weight_decay``.

    Parameters:
        ``p`` (``int``): The power to use in the regularization.
            Defaults to ``2``.
        ``weight_decay`` (``float``): The weight of the regularization.
            Defaults to ``1e-4``.
    """

    def __init__(self, p: int = 2, weight_decay: float = 1e-4):
        super().__init__()
        self.p = p
        self.weight_decay = weight_decay

    def forward(self, *embs: torch.Tensor):
        r"""The forward function.

        Parameters:
            ``embs`` (``torch.Tensor``): The input embeddings, passed as
                separate positional arguments. Each tensor is reduced with
                ``sum(dim=1)``, so it needs at least two dimensions.

        Returns:
            ``weight_decay`` times the summed penalty of all inputs. When
            no tensor is passed, the result is the plain number
            ``weight_decay * 0``.
        """
        loss = 0
        for emb in embs:
            # Take the absolute value first so the penalty is non-negative
            # for odd ``p`` as well; for even ``p`` the result is unchanged.
            penalty = emb.abs().pow(self.p).sum(dim=1).mean()
            # Accumulate out of place to avoid mutating a tensor that is
            # already part of the autograd graph.
            loss = loss + 1 / self.p * penalty
        return self.weight_decay * loss