Source code for easygraph.nn.convs.hypergraphs.hwnn_conv
import torch
import torch.nn as nn
from easygraph.classes import Hypergraph
[docs]class HWNNConv(nn.Module):
r"""The HWNNConv model proposed in `Heterogeneous Hypergraph Embedding for Graph Classification <https://arxiv.org/pdf/2010.10728>`_ paper (WSDM 2021).
Parameters:
``in_channels`` (``int``): :math:`C_{in}` is the number of input channels.
``out_channels`` (``int``): :math:`C_{out}` is the number of output channels.
``ncount`` (``int``): The Number of node in the hypergraph.
``K1`` (``int``): Polynomial calculation times.
``K2`` (``int``): Polynomial calculation times.
``approx`` (``bool``): Whether to use polynomial fitting
"""
def __init__(
self,
in_channels: int,
out_channels: int,
ncount: int,
K1: int = 2,
K2: int = 2,
approx: bool = False,
):
super().__init__()
self.in_channels = in_channels
self.out_channels = out_channels
self.K1 = K1
self.K2 = K2
self.ncount = ncount
self.approx = approx
self.W = torch.nn.Parameter(torch.Tensor(self.in_channels, self.out_channels))
self.W_d = torch.nn.Parameter(torch.Tensor(self.ncount))
self.par = torch.nn.Parameter(torch.Tensor(self.K1 + self.K2))
self.init_parameters()
[docs] def init_parameters(self):
torch.nn.init.xavier_uniform_(self.W)
torch.nn.init.uniform_(self.W_d, 0.99, 1.01)
torch.nn.init.uniform_(self.par, 0, 0.99)
[docs] def forward(self, X: torch.Tensor, hg: Hypergraph) -> torch.Tensor:
r"""The forward function.
Parameters:
X (``torch.Tensor``): Input vertex feature matrix. Size :math:`(N, C_{in})`.
hg (``eg.Hypergraph``): The hypergraph structure that contains :math:`N` vertices.
"""
if self.approx == True:
X = hg.smoothing_with_HWNN_approx(
X, self.par, self.W_d, self.K1, self.K2, self.W
)
else:
X = hg.smoothing_with_HWNN_wavelet(X, self.W_d, self.W)
return X