from collections import Counter
import torch
import torch.nn as nn
import torch.nn.functional as F
from easygraph.nn.convs.common import MLP
from easygraph.nn.convs.hypergraphs.halfnlh_conv import HalfNLHconv
from torch.nn import Linear
# Names exported by ``from <this module> import *``.
__all__ = ["SetGNN"]
class SetGNN(nn.Module):
    r"""The SetGNN model proposed in `YOU ARE ALLSET: A MULTISET LEARNING FRAMEWORK FOR HYPERGRAPH NEURAL NETWORKS <https://openreview.net/pdf?id=hpBTIv2uy_E>`_ paper (ICLR 2022).

    Parameters:
        ``num_features`` (``int``): The dimension of node features.
        ``num_classes`` (``int``): The number of classes of the classification task.
        ``Classifier_hidden`` (``int``): Decoder hidden units. Defaults to 64.
        ``Classifier_num_layers`` (``int``): Layers of decoder. Defaults to 2.
        ``MLP_hidden`` (``int``): Encoder hidden units. Defaults to 64.
        ``MLP_num_layers`` (``int``): Layers of encoder. Defaults to 2.
        ``All_num_layers`` (``int``): Number of node->hyperedge->node propagation
            rounds. ``0`` makes the model a plain MLP classifier on the input
            features. Must be non-negative. Defaults to 2.
        ``dropout`` (``float``, optional): Dropout ratio. Defaults to 0.5.
        ``aggregate`` (``str``): The aggregation method passed to every
            convolution. Defaults to ``mean``.
        ``normalization`` (``str``): The normalization method. Defaults to ``ln``.
        ``deepset_input_norm`` (``bool``): Forwarded to each convolution as
            ``InputNorm``. Defaults to True.
        ``heads`` (``int``): Forwarded to each convolution as ``heads``. Defaults to 1.
        ``PMA`` (``bool``): Forwarded to each convolution as ``attention``.
            Defaults to True.
        ``GPR`` (``bool``): Combine an MLP embedding of the input and the output
            of every propagation round with learned linear weights before the
            classifier. Requires ``All_num_layers >= 1``. Defaults to False.
        ``LearnMask`` (``bool``): Multiply the incidence weights by a learnable
            mask. Requires ``norm``. Defaults to False.
        ``norm`` (``Tensor``, optional): Only its shape is used, to size the
            learnable mask when ``LearnMask`` is True. Defaults to None.
        ``self_loop`` (``bool``): Add a singleton hyperedge for every node that is
            not already the only member of some hyperedge. Defaults to True.
    """

    def __init__(
        self,
        num_features,
        num_classes,
        Classifier_hidden=64,
        Classifier_num_layers=2,
        MLP_hidden=64,
        MLP_num_layers=2,
        All_num_layers=2,
        dropout=0.5,
        aggregate="mean",
        normalization="ln",
        deepset_input_norm=True,
        heads=1,
        PMA=True,
        GPR=False,
        LearnMask=False,
        norm=None,
        self_loop=True,
    ):
        super().__init__()
        # Fail fast on configurations that would otherwise only crash later in
        # forward() / reset_parameters().
        if All_num_layers < 0:
            raise ValueError(
                f"All_num_layers must be non-negative, got {All_num_layers}"
            )
        if GPR and All_num_layers == 0:
            raise ValueError("GPR=True requires All_num_layers >= 1")
        if LearnMask and norm is None:
            raise ValueError("LearnMask=True requires `norm` to size the mask")

        # All dropout layers currently share a single rate.
        self.All_num_layers = All_num_layers
        self.dropout = dropout
        self.aggr = aggregate
        self.NormLayer = normalization
        self.InputNorm = deepset_input_norm
        self.GPR = GPR
        self.LearnMask = LearnMask
        # V2EConvs[i] / E2VConvs[i] are the node->hyperedge and hyperedge->node
        # halves of round i. No hyperedge input features are assumed.
        self.V2EConvs = nn.ModuleList()
        self.E2VConvs = nn.ModuleList()
        # These batch norms are built and reset but not applied in forward();
        # they are kept because removing them would change the state_dict keys.
        self.bnV2Es = nn.ModuleList()
        self.bnE2Vs = nn.ModuleList()
        # Incidence index; built lazily from the first input given to forward().
        self.edge_index = None
        self.self_loop = self_loop
        if self.LearnMask:
            self.Importance = nn.Parameter(torch.ones(norm.size()))

        def make_conv(in_dim):
            # Every half-layer shares the same hyper-parameters; only the input
            # width of the very first node->hyperedge conv differs.
            return HalfNLHconv(
                in_dim=in_dim,
                hid_dim=MLP_hidden,
                out_dim=MLP_hidden,
                num_layers=MLP_num_layers,
                dropout=self.dropout,
                normalization=self.NormLayer,
                InputNorm=self.InputNorm,
                heads=heads,
                attention=PMA,
            )

        # Modules are created in a fixed order (V2E, bn, E2V, bn per round, then
        # the GPR parts, then the classifier) so seeded initialization is stable.
        for layer in range(All_num_layers):
            self.V2EConvs.append(
                make_conv(num_features if layer == 0 else MLP_hidden)
            )
            self.bnV2Es.append(nn.BatchNorm1d(MLP_hidden))
            self.E2VConvs.append(make_conv(MLP_hidden))
            self.bnE2Vs.append(nn.BatchNorm1d(MLP_hidden))

        if self.GPR:
            self.MLP = MLP(
                in_channels=num_features,
                hidden_channels=MLP_hidden,
                out_channels=MLP_hidden,
                num_layers=MLP_num_layers,
                dropout=self.dropout,
                normalization=self.NormLayer,
                InputNorm=False,
            )
            # One mixing weight per stacked representation (MLP + each round).
            self.GPRweights = Linear(self.All_num_layers + 1, 1, bias=False)

        self.classifier = MLP(
            in_channels=num_features if All_num_layers == 0 else MLP_hidden,
            hidden_channels=Classifier_hidden,
            out_channels=num_classes,
            num_layers=Classifier_num_layers,
            dropout=self.dropout,
            normalization=self.NormLayer,
            InputNorm=False,
        )

    def generate_edge_index(self, dataset, self_loop=False):
        """Build the node/hyperedge incidence index of ``dataset``.

        Hyperedge ``k`` of ``dataset["edge_list"]`` gets id ``k``. Row 0 of the
        result holds node ids and row 1 the hyperedge id of each incidence.

        Parameters:
            ``dataset``: Mapping providing ``"edge_list"`` (an iterable of
                node-id iterables) and, when ``self_loop`` is True,
                ``"num_vertices"``.
            ``self_loop`` (``bool``): If True, every node that is not already the
                only member of some hyperedge gets a new singleton hyperedge.
                The new ids continue contiguously after the last input
                hyperedge. The incidences are then stably sorted by node id.

        Returns:
            ``torch.LongTensor`` of shape ``(2, num_incidences)``.
        """
        node_ids, edge_ids = [], []
        num_hyperedges = 0
        for e_id, members in enumerate(dataset["edge_list"]):
            num_hyperedges = e_id + 1
            for v in members:
                node_ids.append(v)
                edge_ids.append(e_id)
        edge_index = torch.tensor([node_ids, edge_ids]).long()
        if not self_loop:
            return edge_index

        nodes, edges = edge_index.tolist()
        sizes = Counter(edges)
        # A node that is the sole member of some hyperedge already has a
        # self loop; a set keeps duplicates out and makes lookups O(1).
        covered = {v for v, e in zip(nodes, edges) if sizes[e] == 1}
        loop_nodes = [
            v for v in range(dataset["num_vertices"]) if v not in covered
        ]
        loop_ids = list(range(num_hyperedges, num_hyperedges + len(loop_nodes)))
        loops = torch.tensor([loop_nodes, loop_ids], dtype=torch.long)
        edge_index = torch.cat((edge_index, loops), dim=1)
        # Stable sort so incidences of the same node keep a deterministic order.
        order = torch.sort(edge_index[0], stable=True).indices
        return edge_index[:, order]

    def reset_parameters(self):
        """Re-initialize every sub-module and, if enabled, the learnable mask."""
        for layers in (self.V2EConvs, self.E2VConvs, self.bnV2Es, self.bnE2Vs):
            for layer in layers:
                layer.reset_parameters()
        self.classifier.reset_parameters()
        if self.GPR:
            self.MLP.reset_parameters()
            self.GPRweights.reset_parameters()
        if self.LearnMask:
            nn.init.ones_(self.Importance)

    def forward(self, data):
        """Run the model and return the classifier output for every node.

        ``data`` must provide:
            ``data["features"]``: node features, one row per node (assumed
                ``(num_nodes, num_features)`` -- confirm with callers).
            ``data["weight"]``: per-incidence weights aligned with the generated
                incidence index, or ``None`` for all-ones weights.
            ``data["edge_list"]`` (and ``data["num_vertices"]`` when self loops
                are enabled): read only on the first call to build the incidence
                index, which is cached and reused by every later call.

        The returned tensor is the raw classifier output; any loss must be
        computed by the caller.
        """
        x = data["features"]
        if self.edge_index is None:
            self.edge_index = self.generate_edge_index(data, self.self_loop)
        # The cached index is a plain attribute (not a buffer), so Module.to()
        # does not move it; keep it on the same device as the features.
        self.edge_index = self.edge_index.to(x.device)
        edge_index = self.edge_index

        weight = data["weight"]
        if weight is None:
            norm = torch.ones(edge_index.size(1), device=x.device)
        else:
            norm = weight
        if self.LearnMask:
            # NOTE(review): Importance was sized from `norm` at construction;
            # it must match the number of incidences here.
            norm = self.Importance * norm
        # Swap the rows so the E2V convs aggregate hyperedges into nodes.
        reversed_edge_index = edge_index.flip(0)

        if self.GPR:
            xs = [F.relu(self.MLP(x))]
            for v2e, e2v in zip(self.V2EConvs, self.E2VConvs):
                x = F.relu(v2e(x, edge_index, norm, self.aggr))
                x = F.dropout(x, p=self.dropout, training=self.training)
                x = F.relu(e2v(x, reversed_edge_index, norm, self.aggr))
                xs.append(x)
                x = F.dropout(x, p=self.dropout, training=self.training)
            # Stack the representations on a new last axis, mix them with the
            # learned weights, and drop only that axis so a single-node input
            # keeps its node dimension.
            x = self.GPRweights(torch.stack(xs, dim=-1)).squeeze(-1)
            return self.classifier(x)

        x = F.dropout(x, p=0.2, training=self.training)  # Input dropout
        for v2e, e2v in zip(self.V2EConvs, self.E2VConvs):
            x = F.relu(v2e(x, edge_index, norm, self.aggr))
            x = F.dropout(x, p=self.dropout, training=self.training)
            x = F.relu(e2v(x, reversed_edge_index, norm, self.aggr))
            x = F.dropout(x, p=self.dropout, training=self.training)
        return self.classifier(x)