from typing import TYPE_CHECKING

import torch
import torch.nn as nn

from easygraph.nn import MultiHeadWrapper
from easygraph.nn import UniGATConv
from easygraph.nn import UniGCNConv
from easygraph.nn import UniGINConv
from easygraph.nn import UniSAGEConv

if TYPE_CHECKING:
    # Imported only to resolve the ``eg.Hypergraph`` string annotations below;
    # guarded to avoid a circular import at runtime.
    import easygraph as eg
__all__ = [
    "UniGCN",
    "UniGAT",
    "UniSAGE",
    "UniGIN",
]


class UniGCN(nn.Module):
    r"""The UniGCN model proposed in the `UniGNN: a Unified Framework for Graph and Hypergraph Neural Networks <https://arxiv.org/pdf/2105.00956.pdf>`_ paper (IJCAI 2021).

    Args:
        ``in_channels`` (``int``): :math:`C_{in}` is the number of input channels.
        ``hid_channels`` (``int``): :math:`C_{hid}` is the number of hidden channels.
        ``num_classes`` (``int``): The number of classes for the classification task.
        ``use_bn`` (``bool``): If set to ``True``, use batch normalization. Defaults to ``False``.
        ``drop_rate`` (``float``, optional): Dropout ratio. Defaults to ``0.5``.
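
    Example:
        A minimal usage sketch. The ``eg.Hypergraph(num_v, e_list)`` construction
        below is an assumption based on EasyGraph's hypergraph API; the vertex
        count and feature sizes are illustrative only::

            import easygraph as eg

            hg = eg.Hypergraph(5, [[0, 1, 2], [2, 3], [1, 3, 4]])
            model = UniGCN(in_channels=16, hid_channels=8, num_classes=3)
            X = torch.rand(5, 16)
            logits = model(X, hg)  # expected shape: (5, 3)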
"""

    def __init__(
        self,
        in_channels: int,
        hid_channels: int,
        num_classes: int,
        use_bn: bool = False,
        drop_rate: float = 0.5,
    ) -> None:
        super().__init__()
        self.layers = nn.ModuleList()
        self.layers.append(
            UniGCNConv(in_channels, hid_channels, use_bn=use_bn, drop_rate=drop_rate)
        )
        self.layers.append(
            UniGCNConv(hid_channels, num_classes, use_bn=use_bn, is_last=True)
        )

    def forward(self, X: torch.Tensor, hg: "eg.Hypergraph") -> torch.Tensor:
        r"""The forward function.

        Args:
            ``X`` (``torch.Tensor``): Input vertex feature matrix. Size :math:`(N, C_{in})`.
            ``hg`` (``eg.Hypergraph``): The hypergraph structure that contains :math:`N` vertices.
        """
        for layer in self.layers:
            X = layer(X, hg)
        return X


class UniGAT(nn.Module):
    r"""The UniGAT model proposed in the `UniGNN: a Unified Framework for Graph and Hypergraph Neural Networks <https://arxiv.org/pdf/2105.00956.pdf>`_ paper (IJCAI 2021).

    Args:
        ``in_channels`` (``int``): :math:`C_{in}` is the number of input channels.
        ``hid_channels`` (``int``): :math:`C_{hid}` is the number of hidden channels.
        ``num_classes`` (``int``): The number of classes for the classification task.
        ``num_heads`` (``int``): The number of attention heads in each layer.
        ``use_bn`` (``bool``): If set to ``True``, use batch normalization. Defaults to ``False``.
        ``drop_rate`` (``float``): The dropout probability. Defaults to ``0.5``.
        ``atten_neg_slope`` (``float``): Hyperparameter of the ``LeakyReLU`` activation used in edge attention. Defaults to ``0.2``.
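
    Example:
        A minimal usage sketch. The ``eg.Hypergraph(num_v, e_list)`` construction
        below is an assumption based on EasyGraph's hypergraph API; the vertex
        count, head count, and feature sizes are illustrative only::

            import easygraph as eg

            hg = eg.Hypergraph(5, [[0, 1, 2], [2, 3], [1, 3, 4]])
            model = UniGAT(in_channels=16, hid_channels=8, num_classes=3, num_heads=4)
            X = torch.rand(5, 16)
            logits = model(X, hg)  # expected shape: (5, 3)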
"""

    def __init__(
        self,
        in_channels: int,
        hid_channels: int,
        num_classes: int,
        num_heads: int,
        use_bn: bool = False,
        drop_rate: float = 0.5,
        atten_neg_slope: float = 0.2,
    ) -> None:
        super().__init__()
        self.drop_layer = nn.Dropout(drop_rate)
        self.multi_head_layer = MultiHeadWrapper(
            num_heads,
            "concat",
            UniGATConv,
            in_channels=in_channels,
            out_channels=hid_channels,
            use_bn=use_bn,
            drop_rate=drop_rate,
            atten_neg_slope=atten_neg_slope,
        )
        # The original implementation applies an activation after the final
        # layer, so we do not set ``is_last`` to ``True``.
        self.out_layer = UniGATConv(
            hid_channels * num_heads,
            num_classes,
            use_bn=use_bn,
            drop_rate=drop_rate,
            atten_neg_slope=atten_neg_slope,
            is_last=False,
        )

    def forward(self, X: torch.Tensor, hg: "eg.Hypergraph") -> torch.Tensor:
        r"""The forward function.

        Args:
            ``X`` (``torch.Tensor``): Input vertex feature matrix. Size :math:`(N, C_{in})`.
            ``hg`` (``eg.Hypergraph``): The hypergraph structure that contains :math:`N` vertices.
        """
        X = self.drop_layer(X)
        X = self.multi_head_layer(X=X, hg=hg)
        X = self.drop_layer(X)
        X = self.out_layer(X, hg)
        return X


class UniSAGE(nn.Module):
    r"""The UniSAGE model proposed in the `UniGNN: a Unified Framework for Graph and Hypergraph Neural Networks <https://arxiv.org/pdf/2105.00956.pdf>`_ paper (IJCAI 2021).

    Args:
        ``in_channels`` (``int``): :math:`C_{in}` is the number of input channels.
        ``hid_channels`` (``int``): :math:`C_{hid}` is the number of hidden channels.
        ``num_classes`` (``int``): The number of classes for the classification task.
        ``use_bn`` (``bool``): If set to ``True``, use batch normalization. Defaults to ``False``.
        ``drop_rate`` (``float``, optional): Dropout ratio. Defaults to ``0.5``.
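
    Example:
        A minimal usage sketch. The ``eg.Hypergraph(num_v, e_list)`` construction
        below is an assumption based on EasyGraph's hypergraph API; the vertex
        count and feature sizes are illustrative only::

            import easygraph as eg

            hg = eg.Hypergraph(5, [[0, 1, 2], [2, 3], [1, 3, 4]])
            model = UniSAGE(in_channels=16, hid_channels=8, num_classes=3)
            X = torch.rand(5, 16)
            logits = model(X, hg)  # expected shape: (5, 3)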
"""

    def __init__(
        self,
        in_channels: int,
        hid_channels: int,
        num_classes: int,
        use_bn: bool = False,
        drop_rate: float = 0.5,
    ) -> None:
        super().__init__()
        self.layers = nn.ModuleList()
        self.layers.append(
            UniSAGEConv(in_channels, hid_channels, use_bn=use_bn, drop_rate=drop_rate)
        )
        self.layers.append(
            UniSAGEConv(hid_channels, num_classes, use_bn=use_bn, is_last=True)
        )

    def forward(self, X: torch.Tensor, hg: "eg.Hypergraph") -> torch.Tensor:
        r"""The forward function.

        Args:
            ``X`` (``torch.Tensor``): Input vertex feature matrix. Size :math:`(N, C_{in})`.
            ``hg`` (``eg.Hypergraph``): The hypergraph structure that contains :math:`N` vertices.
        """
        for layer in self.layers:
            X = layer(X, hg)
        return X


class UniGIN(nn.Module):
    r"""The UniGIN model proposed in the `UniGNN: a Unified Framework for Graph and Hypergraph Neural Networks <https://arxiv.org/pdf/2105.00956.pdf>`_ paper (IJCAI 2021).

    Args:
        ``in_channels`` (``int``): :math:`C_{in}` is the number of input channels.
        ``hid_channels`` (``int``): :math:`C_{hid}` is the number of hidden channels.
        ``num_classes`` (``int``): The number of classes for the classification task.
        ``eps`` (``float``): The epsilon value. Defaults to ``0.0``.
        ``train_eps`` (``bool``): If set to ``True``, the epsilon value will be trainable. Defaults to ``False``.
        ``use_bn`` (``bool``): If set to ``True``, use batch normalization. Defaults to ``False``.
        ``drop_rate`` (``float``, optional): Dropout ratio. Defaults to ``0.5``.
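
    Example:
        A minimal usage sketch. The ``eg.Hypergraph(num_v, e_list)`` construction
        below is an assumption based on EasyGraph's hypergraph API; the vertex
        count and feature sizes are illustrative only::

            import easygraph as eg

            hg = eg.Hypergraph(5, [[0, 1, 2], [2, 3], [1, 3, 4]])
            # ``train_eps=True`` makes the GIN epsilon a learnable parameter.
            model = UniGIN(in_channels=16, hid_channels=8, num_classes=3, train_eps=True)
            X = torch.rand(5, 16)
            logits = model(X, hg)  # expected shape: (5, 3)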
"""

    def __init__(
        self,
        in_channels: int,
        hid_channels: int,
        num_classes: int,
        eps: float = 0.0,
        train_eps: bool = False,
        use_bn: bool = False,
        drop_rate: float = 0.5,
    ) -> None:
        super().__init__()
        self.layers = nn.ModuleList()
        self.layers.append(
            UniGINConv(
                in_channels,
                hid_channels,
                eps=eps,
                train_eps=train_eps,
                use_bn=use_bn,
                drop_rate=drop_rate,
            )
        )
        self.layers.append(
            UniGINConv(
                hid_channels,
                num_classes,
                eps=eps,
                train_eps=train_eps,
                use_bn=use_bn,
                is_last=True,
            )
        )

    def forward(self, X: torch.Tensor, hg: "eg.Hypergraph") -> torch.Tensor:
        r"""The forward function.

        Args:
            ``X`` (``torch.Tensor``): Input vertex feature matrix. Size :math:`(N, C_{in})`.
            ``hg`` (``eg.Hypergraph``): The hypergraph structure that contains :math:`N` vertices.
        """
        for layer in self.layers:
            X = layer(X, hg)
        return X