Source code for easygraph.model.hypergraphs.unignn

import torch
import torch.nn as nn

from easygraph.nn import MultiHeadWrapper
from easygraph.nn import UniGATConv
from easygraph.nn import UniGCNConv
from easygraph.nn import UniGINConv
from easygraph.nn import UniSAGEConv


# Public model classes exported by this module.
__all__ = ["UniGCN", "UniGAT", "UniSAGE", "UniGIN"]


class UniGCN(nn.Module):
    r"""The UniGCN model from the paper `UniGNN: a Unified Framework for Graph and
    Hypergraph Neural Networks <https://arxiv.org/pdf/2105.00956.pdf>`_ (IJCAI 2021).

    The network is a stack of two :class:`UniGCNConv` layers. The first maps
    ``in_channels -> hid_channels``. The second maps ``hid_channels -> num_classes``
    and is built with ``is_last=True``.

    Args:
        ``in_channels`` (``int``): :math:`C_{in}` is the number of input channels.
        ``hid_channels`` (``int``): :math:`C_{hid}` is the number of hidden channels.
        ``num_classes`` (``int``): The number of classes of the classification task.
        ``use_bn`` (``bool``): If set to ``True``, use batch normalization in both layers. Defaults to ``False``.
        ``drop_rate`` (``float``, optional): Dropout ratio, passed to the first layer only. Defaults to ``0.5``.
    """

    def __init__(
        self,
        in_channels: int,
        hid_channels: int,
        num_classes: int,
        use_bn: bool = False,
        drop_rate: float = 0.5,
    ) -> None:
        super().__init__()
        # Construct the convs in input-to-output order so they register as
        # ``layers.0`` and ``layers.1``.
        hidden_conv = UniGCNConv(
            in_channels, hid_channels, use_bn=use_bn, drop_rate=drop_rate
        )
        output_conv = UniGCNConv(
            hid_channels, num_classes, use_bn=use_bn, is_last=True
        )
        self.layers = nn.ModuleList([hidden_conv, output_conv])

    def forward(self, X: torch.Tensor, hg: "eg.Hypergraph") -> torch.Tensor:
        r"""Feed the vertex features through each layer in turn.

        Args:
            ``X`` (``torch.Tensor``): Input vertex feature matrix. Size :math:`(N, C_{in})`.
            ``hg`` (``eg.Hypergraph``): The hypergraph structure that contains :math:`N` vertices.

        Returns:
            ``torch.Tensor``: The output of the final layer.
        """
        out = X
        for conv in self.layers:
            out = conv(out, hg)
        return out
class UniGAT(nn.Module):
    r"""The UniGAT model from the paper `UniGNN: a Unified Framework for Graph and
    Hypergraph Neural Networks <https://arxiv.org/pdf/2105.00956.pdf>`_ (IJCAI 2021).

    The first stage is a :class:`MultiHeadWrapper` running ``num_heads``
    :class:`UniGATConv` heads in ``"concat"`` mode. The output layer accordingly
    expects ``hid_channels * num_heads`` input channels. A single dropout module
    is applied to the input of each stage.

    Args:
        ``in_channels`` (``int``): :math:`C_{in}` is the number of input channels.
        ``hid_channels`` (``int``): :math:`C_{hid}` is the number of hidden channels per head.
        ``num_classes`` (``int``): The number of classes of the classification task.
        ``num_heads`` (``int``): The number of attention heads in the first stage.
        ``use_bn`` (``bool``): If set to ``True``, use batch normalization. Defaults to ``False``.
        ``drop_rate`` (``float``): The dropout probability, used for the stage inputs and passed to every conv. Defaults to ``0.5``.
        ``atten_neg_slope`` (``float``): Negative slope of the ``LeakyReLU`` used in the edge attention. Defaults to ``0.2``.
    """

    def __init__(
        self,
        in_channels: int,
        hid_channels: int,
        num_classes: int,
        num_heads: int,
        use_bn: bool = False,
        drop_rate: float = 0.5,
        atten_neg_slope: float = 0.2,
    ) -> None:
        super().__init__()
        # Hyper-parameters shared by the attention heads and the output layer.
        conv_opts = {
            "use_bn": use_bn,
            "drop_rate": drop_rate,
            "atten_neg_slope": atten_neg_slope,
        }
        self.drop_layer = nn.Dropout(drop_rate)
        self.multi_head_layer = MultiHeadWrapper(
            num_heads,
            "concat",
            UniGATConv,
            in_channels=in_channels,
            out_channels=hid_channels,
            **conv_opts,
        )
        # The reference UniGNN implementation also applies the activation after
        # the final layer, so ``is_last`` is deliberately left ``False`` here.
        self.out_layer = UniGATConv(
            hid_channels * num_heads,
            num_classes,
            is_last=False,
            **conv_opts,
        )

    def forward(self, X: torch.Tensor, hg: "eg.Hypergraph") -> torch.Tensor:
        r"""Apply dropout, then the multi-head stage, then dropout, then the output layer.

        Args:
            ``X`` (``torch.Tensor``): Input vertex feature matrix. Size :math:`(N, C_{in})`.
            ``hg`` (``eg.Hypergraph``): The hypergraph structure that contains :math:`N` vertices.

        Returns:
            ``torch.Tensor``: The output of ``out_layer``.
        """
        # The wrapper is invoked with keyword arguments, as in the original call.
        heads_out = self.multi_head_layer(X=self.drop_layer(X), hg=hg)
        return self.out_layer(self.drop_layer(heads_out), hg)
class UniSAGE(nn.Module):
    r"""The UniSAGE model from the paper `UniGNN: a Unified Framework for Graph and
    Hypergraph Neural Networks <https://arxiv.org/pdf/2105.00956.pdf>`_ (IJCAI 2021).

    The network is two stacked :class:`UniSAGEConv` layers
    (``in_channels -> hid_channels -> num_classes``). The second layer is built
    with ``is_last=True``.

    Args:
        ``in_channels`` (``int``): :math:`C_{in}` is the number of input channels.
        ``hid_channels`` (``int``): :math:`C_{hid}` is the number of hidden channels.
        ``num_classes`` (``int``): The number of classes of the classification task.
        ``use_bn`` (``bool``): If set to ``True``, use batch normalization in both layers. Defaults to ``False``.
        ``drop_rate`` (``float``, optional): Dropout ratio, passed to the first layer only. Defaults to ``0.5``.
    """

    def __init__(
        self,
        in_channels: int,
        hid_channels: int,
        num_classes: int,
        use_bn: bool = False,
        drop_rate: float = 0.5,
    ) -> None:
        super().__init__()
        # List elements are evaluated left to right, so construction order
        # (and the ``layers.0`` / ``layers.1`` keys) is input-to-output.
        self.layers = nn.ModuleList(
            [
                UniSAGEConv(
                    in_channels, hid_channels, use_bn=use_bn, drop_rate=drop_rate
                ),
                UniSAGEConv(hid_channels, num_classes, use_bn=use_bn, is_last=True),
            ]
        )

    def forward(self, X: torch.Tensor, hg: "eg.Hypergraph") -> torch.Tensor:
        r"""Pass the vertex features sequentially through both layers.

        Args:
            ``X`` (``torch.Tensor``): Input vertex feature matrix. Size :math:`(N, C_{in})`.
            ``hg`` (``eg.Hypergraph``): The hypergraph structure that contains :math:`N` vertices.

        Returns:
            ``torch.Tensor``: The output of the final layer.
        """
        feats = X
        for sage_layer in self.layers:
            feats = sage_layer(feats, hg)
        return feats
class UniGIN(nn.Module):
    r"""The UniGIN model from the paper `UniGNN: a Unified Framework for Graph and
    Hypergraph Neural Networks <https://arxiv.org/pdf/2105.00956.pdf>`_ (IJCAI 2021).

    The network is two stacked :class:`UniGINConv` layers
    (``in_channels -> hid_channels -> num_classes``) that share the ``eps``,
    ``train_eps`` and ``use_bn`` settings. The second layer is built with
    ``is_last=True``.

    Args:
        ``in_channels`` (``int``): :math:`C_{in}` is the number of input channels.
        ``hid_channels`` (``int``): :math:`C_{hid}` is the number of hidden channels.
        ``num_classes`` (``int``): The number of classes of the classification task.
        ``eps`` (``float``): The epsilon value passed to both layers. Defaults to ``0.0``.
        ``train_eps`` (``bool``): If set to ``True``, the epsilon value will be trainable. Defaults to ``False``.
        ``use_bn`` (``bool``): If set to ``True``, use batch normalization. Defaults to ``False``.
        ``drop_rate`` (``float``, optional): Dropout ratio, passed to the first layer only. Defaults to ``0.5``.
    """

    def __init__(
        self,
        in_channels: int,
        hid_channels: int,
        num_classes: int,
        eps: float = 0.0,
        train_eps: bool = False,
        use_bn: bool = False,
        drop_rate: float = 0.5,
    ) -> None:
        super().__init__()
        # Settings common to both GIN layers.
        gin_opts = {"eps": eps, "train_eps": train_eps, "use_bn": use_bn}
        first = UniGINConv(in_channels, hid_channels, drop_rate=drop_rate, **gin_opts)
        last = UniGINConv(hid_channels, num_classes, is_last=True, **gin_opts)
        self.layers = nn.ModuleList([first, last])

    def forward(self, X: torch.Tensor, hg: "eg.Hypergraph") -> torch.Tensor:
        r"""Run the vertex features through the layers from first to last.

        Args:
            ``X`` (``torch.Tensor``): Input vertex feature matrix. Size :math:`(N, C_{in})`.
            ``hg`` (``eg.Hypergraph``): The hypergraph structure that contains :math:`N` vertices.

        Returns:
            ``torch.Tensor``: The output of the final layer.
        """
        h = X
        for gin_layer in self.layers:
            h = gin_layer(h, hg)
        return h