Source code for easygraph.model.hypergraphs.dhne

import torch
import torch.nn as nn


class DHNE(nn.Module):
    r"""The DHNE model proposed in `Structural Deep Embedding for Hyper-Networks
    <https://arxiv.org/abs/1711.10146>`_ paper (AAAI 2018).

    The model holds three independent single-layer autoencoders (one per
    input) and a two-layer classifier that scores the concatenation of the
    three embeddings.

    Parameters:
        ``dim_feature`` (``list[int]``): Input feature dimension of each of
            the three inputs (len = 3).
        ``embedding_size`` (``list[int]``): Embedding dimension of each of
            the three inputs (len = 3).
        ``hidden_size`` (``int``): Size of the hidden fully connected layer.
    """

    def __init__(self, dim_feature, embedding_size, hidden_size):
        super().__init__()
        self.dim_feature = dim_feature
        self.embedding_size = embedding_size
        self.hidden_size = hidden_size

        # NOTE: attribute names, the ``nn.Sequential`` wrappers and the
        # construction order are kept exactly as they were so that existing
        # ``state_dict`` keys and seeded parameter initialisation still match.
        self.encode0 = nn.Sequential(
            nn.Linear(
                in_features=self.dim_feature[0],
                out_features=self.embedding_size[0],
            )
        )
        self.encode1 = nn.Sequential(
            nn.Linear(
                in_features=self.dim_feature[1],
                out_features=self.embedding_size[1],
            )
        )
        self.encode2 = nn.Sequential(
            nn.Linear(
                in_features=self.dim_feature[2],
                out_features=self.embedding_size[2],
            )
        )

        self.decode_layer0 = nn.Linear(
            in_features=self.embedding_size[0],
            out_features=self.dim_feature[0],
        )
        self.decode_layer1 = nn.Linear(
            in_features=self.embedding_size[1],
            out_features=self.dim_feature[1],
        )
        self.decode_layer2 = nn.Linear(
            in_features=self.embedding_size[2],
            out_features=self.dim_feature[2],
        )

        self.hidden_layer = nn.Linear(
            in_features=sum(self.embedding_size),
            out_features=self.hidden_size,
        )
        # The misspelled name ("ouput") is part of the checkpoint key names;
        # renaming it would break loading of previously saved weights.
        self.ouput_layer = nn.Linear(in_features=self.hidden_size, out_features=1)

    def forward(self, input0, input1, input2):
        r"""Encode the three inputs, reconstruct them and score the tuple.

        Parameters:
            ``input0``, ``input1``, ``input2`` (``torch.Tensor``): Inputs whose
                last dimension is ``dim_feature[0]``, ``dim_feature[1]`` and
                ``dim_feature[2]`` respectively. They are concatenated along
                ``dim=1``, so they must be at least 2-D with matching leading
                dimensions.

        Returns:
            ``list[torch.Tensor]``: ``[decode0, decode1, decode2, score]`` where
            ``decodeK`` is the sigmoid reconstruction of ``inputK`` and
            ``score`` is the sigmoid output of the classifier (last dimension
            of size 1).
        """
        embed0 = torch.tanh(self.encode0(input0))
        embed1 = torch.tanh(self.encode1(input1))
        embed2 = torch.tanh(self.encode2(input2))

        decode0 = torch.sigmoid(self.decode_layer0(embed0))
        decode1 = torch.sigmoid(self.decode_layer1(embed1))
        decode2 = torch.sigmoid(self.decode_layer2(embed2))

        # The embeddings are already tanh-activated, so they are concatenated
        # as-is. The non-linearity belongs on the hidden layer output: without
        # it the hidden and output layers collapse into one affine map.
        merged = torch.cat((embed0, embed1, embed2), dim=1)
        hidden = torch.tanh(self.hidden_layer(merged))
        score = torch.sigmoid(self.ouput_layer(hidden))

        return [decode0, decode1, decode2, score]