Source code for easygraph.model.hypergraphs.hwnn
import torch
import torch.nn as nn
import torch.nn.functional as F
from easygraph.nn import HWNNConv
class HWNN(nn.Module):
    r"""The HWNN model proposed in `Heterogeneous Hypergraph Embedding for Graph Classification <https://arxiv.org/abs/2010.10728>`_ paper (WSDM 2021).

    The model stacks two ``HWNNConv`` layers, applies them independently to
    every hypergraph snapshot, and combines the per-snapshot log-softmax
    outputs with one learnable scalar weight per snapshot.

    Parameters:
        ``in_channels`` (``int``): Number of input feature channels. :math:`C_{in}` is the dimension of input features.
        ``num_classes`` (``int``): Number of target classes for classification.
        ``ncount`` (``int``): Total number of nodes in the hypergraph.
        ``hyper_snapshot_num`` (``int``, optional): Number of semantic snapshots for the given heterogeneous hypergraph. Defaults to 1.
        ``hid_channels`` (``int``, optional): Number of hidden units. :math:`C_{hid}` is the dimension of hidden representations. Defaults to 128.
        ``drop_rate`` (``float``, optional): Dropout probability for regularization. Defaults to 0.01.
    """

    def __init__(
        self,
        in_channels: int,
        num_classes: int,
        ncount: int,
        hyper_snapshot_num: int = 1,
        hid_channels: int = 128,
        drop_rate: float = 0.01,
    ) -> None:
        super().__init__()
        self.drop_rate = drop_rate
        # Both layers are shared by all snapshots; only the mixing weights
        # in ``self.par`` are snapshot-specific.
        self.convolution_1 = HWNNConv(
            in_channels, hid_channels, ncount, K1=3, K2=3, approx=True
        )
        self.convolution_2 = HWNNConv(
            hid_channels, num_classes, ncount, K1=3, K2=3, approx=True
        )
        # One learnable mixing weight per snapshot. ``torch.empty`` replaces
        # the legacy ``torch.Tensor(n)`` constructor; the values are
        # overwritten by the uniform initialisation right below.
        self.par = nn.Parameter(torch.empty(hyper_snapshot_num))
        nn.init.uniform_(self.par, 0, 0.99)

    def forward(self, X: torch.Tensor, hgs: list) -> torch.Tensor:
        r"""The forward function.

        Parameters:
            ``X`` (``torch.Tensor``): Input vertex feature matrix. Size :math:`(N, C_{in})`.
            ``hgs`` (``list`` of ``Hypergraph``): A list of hypergraph structures which stands for snapshots. It must be non-empty and hold at most ``hyper_snapshot_num`` entries; when it holds fewer, only the leading snapshot weights are used.

        Returns:
            ``torch.Tensor``: The sum over snapshots of the log-softmax output of the second convolution, each scaled by its snapshot weight. It has the shape produced by the second convolution (presumably :math:`(N, \text{num\_classes})`).

        Raises:
            ``IndexError``: If ``hgs`` is empty or contains more snapshots than the model has snapshot weights.
        """
        snapshot_count = len(hgs)
        weight_count = self.par.shape[0]
        # Fail before any convolution runs. ``IndexError`` is the exception
        # the unchecked indexing used to raise in these situations, so it is
        # kept for backward compatibility.
        if snapshot_count == 0:
            raise IndexError("hgs must contain at least one hypergraph snapshot")
        if snapshot_count > weight_count:
            raise IndexError(
                f"hgs contains {snapshot_count} snapshots but the model only "
                f"has {weight_count} snapshot weight(s) (hyper_snapshot_num)"
            )

        weighted_sum = None
        # ``zip`` pairs every snapshot with its weight and stops after the
        # last snapshot when fewer snapshots than weights are supplied.
        for weight, hg in zip(self.par, hgs):
            hidden = F.relu(self.convolution_1(X, hg))
            # ``F.dropout`` defaults to ``training=True``; pass the module
            # flag explicitly so dropout is switched off by ``model.eval()``.
            hidden = F.dropout(hidden, self.drop_rate, training=self.training)
            log_probs = F.log_softmax(self.convolution_2(hidden, hg), dim=1)
            contribution = weight * log_probs
            if weighted_sum is None:
                weighted_sum = contribution
            else:
                weighted_sum = weighted_sum + contribution
        return weighted_sum