from collections import Counter
import torch
import torch.nn as nn
import torch.nn.functional as F
from easygraph.nn.convs.common import MLP
from easygraph.nn.convs.hypergraphs.halfnlh_conv import HalfNLHconv
from torch.nn import Linear
__all__ = ["SetGNN"]


class SetGNN(nn.Module):
r"""The SetGNN model proposed in `YOU ARE ALLSET: A MULTISET LEARNING FRAMEWORK FOR HYPERGRAPH NEURAL NETWORKS <https://openreview.net/pdf?id=hpBTIv2uy_E>`_ paper (ICLR 2022).
Parameters:
``num_features`` (``int``): : The dimension of node features.
``num_classes`` (``int``): The Number of class of the classification task.
``Classifier_hidden`` (``int``): Decoder hidden units.
``Classifier_num_layers`` (``int``): Layers of decoder.
``MLP_hidden`` (``int``): Encoder hidden units.
``MLP_num_layers`` (``int``): Layers of encoder.
``dropout`` (``float``, optional): Dropout ratio. Defaults to 0.5.
``aggregate`` (``str``): The aggregation method. Defaults to ``add``
``normalization`` (``str``): The normalization method. Defaults to ``ln``
``deepset_input_norm`` (``bool``): Defaults to True.
``heads`` (``int``): Defaults to 1
`PMA`` (``bool``): Defaults to True
`GPR`` (``bool``): Defaults to False
`LearnMask`` (``bool``): Defaults to False
`norm`` (``Tensor``): The weight for edges in bipartite graphs, correspond to data.edge_index
"""
def __init__(
self,
num_features,
num_classes,
Classifier_hidden=64,
Classifier_num_layers=2,
MLP_hidden=64,
MLP_num_layers=2,
All_num_layers=2,
dropout=0.5,
aggregate="mean",
normalization="ln",
deepset_input_norm=True,
heads=1,
PMA=True,
GPR=False,
LearnMask=False,
norm=None,
self_loop=True,
):
super(SetGNN, self).__init__()
"""
args should contain the following:
V_in_dim, V_enc_hid_dim, V_dec_hid_dim, V_out_dim, V_enc_num_layers, V_dec_num_layers
E_in_dim, E_enc_hid_dim, E_dec_hid_dim, E_out_dim, E_enc_num_layers, E_dec_num_layers
All_num_layers,dropout
!!! V_in_dim should be the dimension of node features
!!! E_out_dim should be the number of classes (for classification)
"""
        # All dropout rates are currently set to the same value, though they could differ per layer.
self.All_num_layers = All_num_layers
self.dropout = dropout
self.aggr = aggregate
self.NormLayer = normalization
self.InputNorm = deepset_input_norm
self.GPR = GPR
self.LearnMask = LearnMask
        # Define V2EConvs[i] and E2VConvs[i] for the i-th layer.
        # We currently assume there are no hyperedge features, which means V_out_dim = E_in_dim.
        # If hyperedge features exist, concatenate them with the V-part decoder output: [V_feat || E_feat].
self.V2EConvs = nn.ModuleList()
self.E2VConvs = nn.ModuleList()
self.bnV2Es = nn.ModuleList()
self.bnE2Vs = nn.ModuleList()
self.edge_index = None
self.self_loop = self_loop
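        # When LearnMask is enabled, ``Importance`` is a learnable per-edge
        # mask (initialized to ones, see ``reset_parameters``) that multiplies
        # the edge weights ``norm`` in ``forward``.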
if self.LearnMask:
self.Importance = nn.Parameter(torch.ones(norm.size()))
if self.All_num_layers == 0:
self.classifier = MLP(
in_channels=num_features,
hidden_channels=Classifier_hidden,
out_channels=num_classes,
num_layers=Classifier_num_layers,
dropout=self.dropout,
normalization=self.NormLayer,
InputNorm=False,
)
else:
self.V2EConvs.append(
HalfNLHconv(
in_dim=num_features,
hid_dim=MLP_hidden,
out_dim=MLP_hidden,
num_layers=MLP_num_layers,
dropout=self.dropout,
normalization=self.NormLayer,
InputNorm=self.InputNorm,
heads=heads,
attention=PMA,
)
)
self.bnV2Es.append(nn.BatchNorm1d(MLP_hidden))
self.E2VConvs.append(
HalfNLHconv(
in_dim=MLP_hidden,
hid_dim=MLP_hidden,
out_dim=MLP_hidden,
num_layers=MLP_num_layers,
dropout=self.dropout,
normalization=self.NormLayer,
InputNorm=self.InputNorm,
heads=heads,
attention=PMA,
)
)
self.bnE2Vs.append(nn.BatchNorm1d(MLP_hidden))
for _ in range(self.All_num_layers - 1):
self.V2EConvs.append(
HalfNLHconv(
in_dim=MLP_hidden,
hid_dim=MLP_hidden,
out_dim=MLP_hidden,
num_layers=MLP_num_layers,
dropout=self.dropout,
normalization=self.NormLayer,
InputNorm=self.InputNorm,
heads=heads,
attention=PMA,
)
)
self.bnV2Es.append(nn.BatchNorm1d(MLP_hidden))
self.E2VConvs.append(
HalfNLHconv(
in_dim=MLP_hidden,
hid_dim=MLP_hidden,
out_dim=MLP_hidden,
num_layers=MLP_num_layers,
dropout=self.dropout,
normalization=self.NormLayer,
InputNorm=self.InputNorm,
heads=heads,
attention=PMA,
)
)
self.bnE2Vs.append(nn.BatchNorm1d(MLP_hidden))
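            # GPR mode keeps the output of every stage (the initial MLP plus
            # each AllSet layer) and combines the resulting All_num_layers + 1
            # stacked representations via a learned linear weighting
            # (``GPRweights``), in the style of GPR-GNN.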
if self.GPR:
self.MLP = MLP(
in_channels=num_features,
hidden_channels=MLP_hidden,
out_channels=MLP_hidden,
num_layers=MLP_num_layers,
dropout=self.dropout,
normalization=self.NormLayer,
InputNorm=False,
)
self.GPRweights = Linear(self.All_num_layers + 1, 1, bias=False)
self.classifier = MLP(
in_channels=MLP_hidden,
hidden_channels=Classifier_hidden,
out_channels=num_classes,
num_layers=Classifier_num_layers,
dropout=self.dropout,
normalization=self.NormLayer,
InputNorm=False,
)
else:
self.classifier = MLP(
in_channels=MLP_hidden,
hidden_channels=Classifier_hidden,
out_channels=num_classes,
num_layers=Classifier_num_layers,
dropout=self.dropout,
normalization=self.NormLayer,
InputNorm=False,
)

    def generate_edge_index(self, dataset, self_loop=False):
edge_list = dataset["edge_list"]
e_ind = 0
edge_index = [[], []]
for e in edge_list:
for n in e:
edge_index[0].append(n)
edge_index[1].append(e_ind)
e_ind += 1
        edge_index = torch.tensor(edge_index, dtype=torch.long)
        if self_loop:
            # Nodes already covered by a singleton hyperedge (a hyperedge that
            # contains exactly one node) do not need an extra self-loop.
            hyperedge_appear_fre = Counter(edge_index[1].numpy())
            skip_node_lst = []
            for edge in hyperedge_appear_fre:
                if hyperedge_appear_fre[edge] == 1:
                    skip_node = edge_index[0][
                        torch.where(edge_index[1] == edge)[0]
                    ].item()
                    skip_node_lst.append(skip_node)
            # New hyperedge ids continue right after the existing ones, so ids
            # stay contiguous and 0-based (as ``forward``'s docstring requires).
            num_nodes = dataset["num_vertices"]
            new_edge_idx = int(edge_index[1].max()) + 1
            new_edges = torch.zeros(
                (2, num_nodes - len(skip_node_lst)), dtype=edge_index.dtype
            )
            tmp_count = 0
            for i in range(num_nodes):
                if i not in skip_node_lst:
                    new_edges[0][tmp_count] = i
                    new_edges[1][tmp_count] = new_edge_idx
                    new_edge_idx += 1
                    tmp_count += 1
            edge_index = torch.cat((edge_index, new_edges), dim=1)
        # Sort the bipartite index by node id.
        _, sorted_idx = torch.sort(edge_index[0])
        edge_index = edge_index[:, sorted_idx]
        return edge_index
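    # For example (toy data, not from the original source): with
    # dataset = {"edge_list": [[0, 1], [1, 2]], "num_vertices": 3} and
    # self_loop=False, each node is paired with the id of every hyperedge
    # containing it:
    #     edge_index = [[0, 1, 1, 2],
    #                   [0, 0, 1, 1]]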

    def reset_parameters(self):
for layer in self.V2EConvs:
layer.reset_parameters()
for layer in self.E2VConvs:
layer.reset_parameters()
for layer in self.bnV2Es:
layer.reset_parameters()
for layer in self.bnE2Vs:
layer.reset_parameters()
self.classifier.reset_parameters()
if self.GPR:
self.MLP.reset_parameters()
self.GPRweights.reset_parameters()
if self.LearnMask:
nn.init.ones_(self.Importance)

    def forward(self, data):
        """
        ``data`` should be a dict containing the following keys:
        ``data["features"]``: node features.
        ``data["edge_list"]``: the hyperedge list, from which (together with
        ``data["num_vertices"]``) the bipartite index of size ``(2, |E|)``
        is built; row 0 holds node ids and row 1 holds hyperedge ids.
        ``data["weight"]``: the weight for edges in the bipartite graph,
        corresponding to the bipartite index; ``None`` means unit weights.
        !!! Note that each self-loop is assigned a new (hyper)edge id.
        !!! Also note that (hyper)edge ids start at 0 (akin to node ids).
        !!! We output the final node representation; the loss should be
        defined outside this module.
        """
if self.edge_index is None:
self.edge_index = self.generate_edge_index(data, self.self_loop)
# print("generate_edge_index:", self.edge_index.shape)
x, edge_index = data["features"], self.edge_index
if data["weight"] == None:
norm = torch.ones(edge_index.size()[1])
else:
norm = data["weight"]
if self.LearnMask:
norm = self.Importance * norm
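        # Swapping the two rows of the bipartite index reverses the message
        # direction: the same index then routes hyperedge features back to
        # their member nodes (E2V).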
reversed_edge_index = torch.stack([edge_index[1], edge_index[0]], dim=0)
if self.GPR:
xs = []
xs.append(F.relu(self.MLP(x)))
for i, _ in enumerate(self.V2EConvs):
x = F.relu(self.V2EConvs[i](x, edge_index, norm, self.aggr))
# x = self.bnV2Es[i](x)
x = F.dropout(x, p=self.dropout, training=self.training)
x = self.E2VConvs[i](x, reversed_edge_index, norm, self.aggr)
x = F.relu(x)
xs.append(x)
# x = self.bnE2Vs[i](x)
x = F.dropout(x, p=self.dropout, training=self.training)
x = torch.stack(xs, dim=-1)
x = self.GPRweights(x).squeeze()
x = self.classifier(x)
else:
x = F.dropout(x, p=0.2, training=self.training) # Input dropout
for i, _ in enumerate(self.V2EConvs):
x = F.relu(self.V2EConvs[i](x, edge_index, norm, self.aggr))
# x = self.bnV2Es[i](x)
x = F.dropout(x, p=self.dropout, training=self.training)
x = F.relu(self.E2VConvs[i](x, reversed_edge_index, norm, self.aggr))
# x = self.bnE2Vs[i](x)
x = F.dropout(x, p=self.dropout, training=self.training)
x = self.classifier(x)
return x
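

if __name__ == "__main__":
    # Minimal smoke-test sketch (assumed toy shapes; the dict layout follows
    # what ``forward`` and ``generate_edge_index`` read, but the names and
    # sizes below are illustrative, not from the original source).
    num_vertices, num_features, num_classes = 5, 8, 3
    data = {
        "edge_list": [[0, 1, 2], [2, 3], [3, 4]],  # hyperedges as node lists
        "num_vertices": num_vertices,
        "features": torch.randn(num_vertices, num_features),
        "weight": None,  # None -> unit weights over all bipartite incidences
    }
    model = SetGNN(num_features=num_features, num_classes=num_classes)
    out = model(data)  # node representations, shape (num_vertices, num_classes)
    print(out.shape)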