Source code for easygraph.nn.convs.common

import torch
import torch.nn as nn
import torch.nn.functional as F


class MultiHeadWrapper(nn.Module):
    r"""A wrapper to apply multiple heads to a given layer.

    Parameters:
        ``num_heads`` (``int``): The number of heads.
        ``readout`` (``str``): The readout method. Can be ``"mean"``, ``"max"``, ``"sum"``, or ``"concat"``.
        ``layer`` (``nn.Module``): The layer to apply multiple heads. It is called as
            ``layer(**kwargs)`` once per head, so it must be a layer class (or factory)
            rather than an already constructed instance.
        ``**kwargs``: The keyword arguments for the layer.
    """

    def __init__(
        self, num_heads: int, readout: str, layer: nn.Module, **kwargs
    ) -> None:
        super().__init__()
        # One independently constructed copy of the layer per head.
        self.layers = nn.ModuleList([layer(**kwargs) for _ in range(num_heads)])
        self.num_heads = num_heads
        self.readout = readout

    def forward(self, **kwargs) -> torch.Tensor:
        r"""The forward function.

        Every head is called with the same keyword arguments and the per-head
        outputs are combined according to ``readout``: ``"concat"`` joins them
        along the last dimension, while ``"mean"``, ``"max"`` and ``"sum"``
        reduce over the heads.

        .. note::
            You must explicitly pass the keyword arguments to the layer. For example, if the layer is ``GATConv``, you must pass ``X=X`` and ``g=g``.

        Raises:
            ``ValueError``: If ``readout`` is not one of the supported methods.
        """
        # Reject an unsupported readout before running any head, so no work is
        # wasted and the error cannot be masked by a failure while stacking.
        if self.readout not in ("concat", "mean", "max", "sum"):
            raise ValueError("Unknown readout type")

        head_outputs = [layer(**kwargs) for layer in self.layers]
        if self.readout == "concat":
            return torch.cat(head_outputs, dim=-1)

        # The remaining readouts reduce over a new leading "head" dimension.
        outs = torch.stack(head_outputs)
        if self.readout == "mean":
            return outs.mean(dim=0)
        if self.readout == "max":
            # ``max`` over a dim returns (values, indices); keep the values.
            return outs.max(dim=0)[0]
        return outs.sum(dim=0)
class MLP(nn.Module):
    """A multi-layer perceptron with optional normalization and dropout.

    adapted from https://github.com/CUAI/CorrectAndSmooth/blob/master/gen_models.py

    Parameters:
        ``in_channels`` (``int``): Size of each input sample.
        ``hidden_channels`` (``int``): Size of the hidden representations. Unused when ``num_layers == 1``.
        ``out_channels`` (``int``): Size of each output sample.
        ``num_layers`` (``int``): Number of linear layers. ``1`` yields a single linear layer
            (i.e. logistic regression); any other value builds an input layer,
            ``num_layers - 2`` hidden layers and an output layer.
        ``dropout`` (``float``): Dropout probability applied after every non-final layer. Defaults to ``0.5``.
        ``normalization`` (``str``): ``"bn"`` (batch norm), ``"ln"`` (layer norm) or ``"None"`` (no normalization). Defaults to ``"bn"``.
        ``InputNorm`` (``bool``): Whether to normalize the raw input with the selected
            normalization. Has no effect when ``normalization`` is ``"None"``. Defaults to ``False``.
    """

    def __init__(
        self,
        in_channels,
        hidden_channels,
        out_channels,
        num_layers,
        dropout=0.5,
        normalization="bn",
        InputNorm=False,
    ):
        super().__init__()
        self.lins = nn.ModuleList()
        self.normalizations = nn.ModuleList()
        self.InputNorm = InputNorm

        assert normalization in ["bn", "ln", "None"]

        # normalizations[0] acts on the raw input; it is a no-op unless input
        # normalization was requested (and a real normalization is selected).
        if InputNorm:
            self.normalizations.append(self._make_norm(normalization, in_channels))
        else:
            self.normalizations.append(nn.Identity())

        if num_layers == 1:
            # just linear layer i.e. logistic regression
            self.lins.append(nn.Linear(in_channels, out_channels))
        else:
            # normalizations[i + 1] follows lins[i] for every non-final layer.
            self.lins.append(nn.Linear(in_channels, hidden_channels))
            self.normalizations.append(
                self._make_norm(normalization, hidden_channels)
            )
            for _ in range(num_layers - 2):
                self.lins.append(nn.Linear(hidden_channels, hidden_channels))
                self.normalizations.append(
                    self._make_norm(normalization, hidden_channels)
                )
            self.lins.append(nn.Linear(hidden_channels, out_channels))

        self.dropout = dropout

    @staticmethod
    def _make_norm(kind, num_features):
        """Build the ``kind`` normalization module for ``num_features`` features."""
        if kind == "bn":
            return nn.BatchNorm1d(num_features)
        if kind == "ln":
            return nn.LayerNorm(num_features)
        return nn.Identity()

    def reset_parameters(self):
        """Re-initialize every linear layer and every normalization layer."""
        for lin in self.lins:
            lin.reset_parameters()
        for normalization in self.normalizations:
            # ``nn.Identity`` placeholders have no parameters to reset.
            if not isinstance(normalization, nn.Identity):
                normalization.reset_parameters()

    def forward(self, x):
        """Apply the MLP to ``x`` and return the output of the last linear layer.

        Every non-final layer is followed by ReLU, its normalization and
        dropout (active only in training mode); the final layer is purely linear.
        """
        x = self.normalizations[0](x)
        for i, lin in enumerate(self.lins[:-1]):
            x = lin(x)
            x = F.relu(x, inplace=True)
            x = self.normalizations[i + 1](x)
            x = F.dropout(x, p=self.dropout, training=self.training)
        x = self.lins[-1](x)
        return x