Source code for easygraph.nn.convs.hypergraphs.halfnlh_conv

import torch
import torch.nn as nn
import torch.nn.functional as F

from easygraph.nn.convs.common import MLP
from easygraph.nn.convs.pma import PMA
from torch_geometric.nn.conv import MessagePassing
from torch_scatter import scatter


class HalfNLHconv(MessagePassing):
    r"""The HalfNLHconv model proposed in the `YOU ARE ALLSET: A MULTISET
    LEARNING FRAMEWORK FOR HYPERGRAPH NEURAL NETWORKS
    <https://openreview.net/pdf?id=hpBTIv2uy_E>`_ paper (ICLR 2022).

    Parameters:
        ``in_dim`` (``int``): The dimension of the input features.
        ``hid_dim`` (``int``): The dimension of the hidden features.
        ``out_dim`` (``int``): The dimension of the output features.
        ``num_layers`` (``int``): The number of layers.
        ``dropout`` (``float``): The dropout ratio.
        ``normalization`` (``str``): The normalization method. Defaults to ``"bn"``.
        ``InputNorm`` (``bool``): Whether to normalize the input features. Defaults to ``False``.
        ``heads`` (``int``): The number of attention heads. Defaults to ``1``.
        ``attention`` (``bool``): Whether to propagate with PMA attention instead of MLPs. Defaults to ``True``.
    """

    def __init__(
        self,
        in_dim,
        hid_dim,
        out_dim,
        num_layers,
        dropout,
        normalization="bn",
        InputNorm=False,
        heads=1,
        attention=True,
    ):
        super().__init__()
        self.attention = attention
        self.dropout = dropout

        if self.attention:
            # Attention-based propagation (Pooling by Multi-head Attention).
            self.prop = PMA(in_dim, hid_dim, out_dim, num_layers, heads=heads)
        else:
            if num_layers > 0:
                # Encoder MLP applied before propagation, decoder MLP after it.
                self.f_enc = MLP(
                    in_dim,
                    hid_dim,
                    hid_dim,
                    num_layers,
                    dropout,
                    normalization,
                    InputNorm,
                )
                self.f_dec = MLP(
                    hid_dim,
                    hid_dim,
                    out_dim,
                    num_layers,
                    dropout,
                    normalization,
                    InputNorm,
                )
            else:
                self.f_enc = nn.Identity()
                self.f_dec = nn.Identity()
    def reset_parameters(self):
        if self.attention:
            self.prop.reset_parameters()
        else:
            # ``nn.Identity`` has no parameters to reset.
            if not isinstance(self.f_enc, nn.Identity):
                self.f_enc.reset_parameters()
            if not isinstance(self.f_dec, nn.Identity):
                self.f_dec.reset_parameters()
    def forward(self, x, edge_index, norm, aggr="add"):
        """input -> MLP -> Prop"""
        if self.attention:
            x = self.prop(x, edge_index)
        else:
            x = F.relu(self.f_enc(x))
            x = F.dropout(x, p=self.dropout, training=self.training)
            x = self.propagate(edge_index, x=x, norm=norm, aggr=aggr)
            x = F.relu(self.f_dec(x))
        return x
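    # Shape sketch for the non-attention branch (an assumption inferred from
    # AllSet-style usage, not checked by this class): ``x`` is
    # ``[num_src, in_dim]``, ``edge_index`` is a ``[2, num_incidences]``
    # source-to-target incidence list, and ``norm`` holds one weight per
    # incidence; the output has one row per target index (inferred by the
    # scatter reduction as ``index.max() + 1``).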
    def message(self, x_j, norm):
        # Scale each gathered source feature by its per-incidence weight.
        return norm.view(-1, 1) * x_j
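    # A hypothetical toy case: ``norm.view(-1, 1)`` turns one scalar weight
    # per message into a column that broadcasts over the feature dimension:
    #
    #     norm = torch.tensor([0.5, 2.0])   # one weight per incidence
    #     x_j = torch.ones(2, 4)            # gathered source features
    #     norm.view(-1, 1) * x_j            # rows scaled by 0.5 and 2.0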
    def aggregate(self, inputs, index, dim_size=None, aggr="sum"):
        r"""Aggregates messages from neighbors as
        :math:`\square_{j \in \mathcal{N}(i)}`.

        Takes in the output of message computation as the first argument and
        any argument which was initially passed to :meth:`propagate`. This
        function delegates to :func:`torch_scatter.scatter`, which supports the
        ``"sum"``/``"add"``, ``"mean"``, and ``"max"`` reductions selected via
        the ``aggr`` argument.
        """
        return scatter(inputs, index, dim=self.node_dim, reduce=aggr)
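

# A minimal usage sketch (illustrative, not part of the module): the toy
# shapes, the incidence list, and the names ``x``/``edge_index``/``norm``
# below are assumptions; ``num_layers=0`` makes both MLPs ``nn.Identity``,
# so the layer reduces to a weighted scatter over the incidence list.
if __name__ == "__main__":
    # 4 source nodes with 8-dim features, 2 target hyperedges.
    x = torch.randn(4, 8)
    # Row 0: source node indices; row 1: target hyperedge indices.
    edge_index = torch.tensor([[0, 1, 2, 2, 3], [0, 0, 0, 1, 1]])
    norm = torch.ones(edge_index.size(1))  # one weight per incidence

    conv = HalfNLHconv(
        in_dim=8,
        hid_dim=8,
        out_dim=8,
        num_layers=0,
        dropout=0.5,
        attention=False,
    )
    out = conv(x, edge_index, norm, aggr="add")
    print(out.shape)  # torch.Size([2, 8]): one row per hyperedge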