import torch
import torch.nn as nn


class MultiHeadWrapper(nn.Module):
    r"""A wrapper to apply multiple heads to a given layer.

    Parameters:
        ``num_heads`` (``int``): The number of heads.
        ``readout`` (``str``): The readout method. Can be ``"mean"``, ``"max"``, ``"sum"``, or ``"concat"``.
        ``layer`` (``nn.Module``): The layer to apply multiple heads to.
        ``**kwargs``: The keyword arguments for the layer.
    """

    def __init__(self, num_heads: int, readout: str, layer: nn.Module, **kwargs) -> None:
        super().__init__()
        self.layers = nn.ModuleList()
        for _ in range(num_heads):
            self.layers.append(layer(**kwargs))
        self.num_heads = num_heads
        self.readout = readout

    def forward(self, **kwargs) -> torch.Tensor:
        r"""The forward function.

        .. note::
            You must explicitly pass the keyword arguments to the layer.
            For example, if the layer is ``GATConv``, you must pass ``X=X`` and ``g=g``.
        """
        if self.readout == "concat":
            return torch.cat([layer(**kwargs) for layer in self.layers], dim=-1)
        else:
            outs = torch.stack([layer(**kwargs) for layer in self.layers])
            if self.readout == "mean":
                return outs.mean(dim=0)
            elif self.readout == "max":
                return outs.max(dim=0)[0]
            elif self.readout == "sum":
                return outs.sum(dim=0)
            else:
                raise ValueError("Unknown readout type")
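

# --- Illustrative usage sketch (not part of the original module) ---
# A minimal, hedged example of MultiHeadWrapper. It wraps ``nn.Linear`` purely
# for illustration; any ``nn.Module`` whose forward accepts keyword arguments
# (e.g. a graph convolution called with ``X=X, g=g``) works the same way.
# The helper name ``_demo_multi_head`` is hypothetical.
def _demo_multi_head() -> torch.Tensor:
    multi_head = MultiHeadWrapper(
        num_heads=4,
        readout="concat",
        layer=nn.Linear,
        in_features=16,
        out_features=8,
    )
    X = torch.rand(32, 16)
    # Keyword arguments are forwarded unchanged to every head; ``nn.Linear``
    # names its forward argument ``input``, so it must be passed by keyword.
    out = multi_head(input=X)  # shape: (32, 32), i.e. 4 heads * 8 channels concatenated
    return out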


class MLP(nn.Module):
    """adapted from https://github.com/CUAI/CorrectAndSmooth/blob/master/gen_models.py"""

    def __init__(
        self,
        in_channels,
        hidden_channels,
        out_channels,
        num_layers,
        dropout=0.5,
        normalization="bn",
        InputNorm=False,
    ):
        super(MLP, self).__init__()
        self.lins = nn.ModuleList()
        self.normalizations = nn.ModuleList()
        self.InputNorm = InputNorm

        assert normalization in ["bn", "ln", "None"]
        if normalization == "bn":
            if num_layers == 1:
                # just linear layer i.e. logistic regression
                if InputNorm:
                    self.normalizations.append(nn.BatchNorm1d(in_channels))
                else:
                    self.normalizations.append(nn.Identity())
                self.lins.append(nn.Linear(in_channels, out_channels))
            else:
                if InputNorm:
                    self.normalizations.append(nn.BatchNorm1d(in_channels))
                else:
                    self.normalizations.append(nn.Identity())
                self.lins.append(nn.Linear(in_channels, hidden_channels))
                self.normalizations.append(nn.BatchNorm1d(hidden_channels))
                for _ in range(num_layers - 2):
                    self.lins.append(nn.Linear(hidden_channels, hidden_channels))
                    self.normalizations.append(nn.BatchNorm1d(hidden_channels))
                self.lins.append(nn.Linear(hidden_channels, out_channels))
        elif normalization == "ln":
            if num_layers == 1:
                # just linear layer i.e. logistic regression
                if InputNorm:
                    self.normalizations.append(nn.LayerNorm(in_channels))
                else:
                    self.normalizations.append(nn.Identity())
                self.lins.append(nn.Linear(in_channels, out_channels))
            else:
                if InputNorm:
                    self.normalizations.append(nn.LayerNorm(in_channels))
                else:
                    self.normalizations.append(nn.Identity())
                self.lins.append(nn.Linear(in_channels, hidden_channels))
                self.normalizations.append(nn.LayerNorm(hidden_channels))
                for _ in range(num_layers - 2):
                    self.lins.append(nn.Linear(hidden_channels, hidden_channels))
                    self.normalizations.append(nn.LayerNorm(hidden_channels))
                self.lins.append(nn.Linear(hidden_channels, out_channels))
        else:
            if num_layers == 1:
                # just linear layer i.e. logistic regression
                self.normalizations.append(nn.Identity())
                self.lins.append(nn.Linear(in_channels, out_channels))
            else:
                self.normalizations.append(nn.Identity())
                self.lins.append(nn.Linear(in_channels, hidden_channels))
                self.normalizations.append(nn.Identity())
                for _ in range(num_layers - 2):
                    self.lins.append(nn.Linear(hidden_channels, hidden_channels))
                    self.normalizations.append(nn.Identity())
                self.lins.append(nn.Linear(hidden_channels, out_channels))

        self.dropout = dropout
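

# --- Illustrative construction sketch (not part of the original module) ---
# A hedged example of how the ``__init__`` above assembles the layer stacks.
# The forward pass is not shown in this excerpt and is not exercised here;
# only the structure built by the constructor is inspected. The helper name
# ``_demo_mlp_structure`` is hypothetical.
def _demo_mlp_structure() -> None:
    mlp = MLP(
        in_channels=64,
        hidden_channels=128,
        out_channels=10,
        num_layers=3,
        dropout=0.5,
        normalization="bn",
        InputNorm=True,
    )
    # Three linear layers: 64 -> 128, 128 -> 128, 128 -> 10.
    assert len(mlp.lins) == 3
    # Three normalizations: BatchNorm1d(64) on the input (since InputNorm=True),
    # followed by BatchNorm1d(128) paired with each subsequent linear layer.
    assert len(mlp.normalizations) == 3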