Source code for easygraph.model.hypergraphs.dhcf

from typing import Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F

from easygraph.classes import Hypergraph


[docs] class DHCF(nn.Module): r"""The DHCF model proposed in `Dual Channel Hypergraph Collaborative Filtering <https://dl.acm.org/doi/10.1145/3394486.3403253>`_ paper (KDD 2020). .. note:: The user and item embeddings and trainable parameters are initialized with xavier_uniform distribution. Parameters: ``num_users`` (``int``): The Number of users. ``num_items`` (``int``): The Number of items. ``emb_dim`` (``int``): Embedding dimension. ``num_layers`` (``int``): The Number of layers. Defaults to ``3``. ``drop_rate`` (``float``): The dropout probability. Defaults to ``0.5``. """ def __init__( self, num_users: int, num_items: int, emb_dim: int, num_layers: int = 3, drop_rate: float = 0.5, ) -> None: super().__init__() self.num_users, self.num_items = num_users, num_items self.num_layers = num_layers self.drop_rate = drop_rate self.u_embedding = nn.Embedding(num_users, emb_dim) self.i_embedding = nn.Embedding(num_items, emb_dim) self.W_gc, self.W_bi = nn.ModuleList(), nn.ModuleList() for _ in range(self.num_layers): self.W_gc.append(nn.Linear(emb_dim, emb_dim)) self.W_bi.append(nn.Linear(emb_dim, emb_dim)) self.reset_parameters()
[docs] def reset_parameters(self): r"""Initialize learnable parameters.""" nn.init.xavier_uniform_(self.u_embedding.weight) nn.init.xavier_uniform_(self.i_embedding.weight) for W_gc, W_bi in zip(self.W_gc, self.W_bi): nn.init.xavier_uniform_(W_gc.weight) nn.init.xavier_uniform_(W_bi.weight) nn.init.constant_(W_gc.bias, 0) nn.init.constant_(W_bi.bias, 0)
[docs] def forward( self, hg_ui: Hypergraph, hg_iu: Hypergraph ) -> Tuple[torch.Tensor, torch.Tensor]: r"""The forward function. Parameters: ``hg_ui`` (``eg.Hypergraph``): The hypergraph structure that users as vertices. ``hg_iu`` (``eg.Hypergraph``): The hypergraph structure that items as vertices. """ u_embs = self.u_embedding.weight i_embs = self.i_embedding.weight all_embs = torch.cat([u_embs, i_embs], dim=0) embs_list = [all_embs] for _idx in range(self.num_layers): u_embs, i_embs = torch.split( all_embs, [self.num_users, self.num_items], dim=0 ) # ========================================================================================== # Two JHConv Layers for users and items, respectively. u_embs = hg_ui.smoothing_with_HGNN(u_embs) i_embs = hg_iu.smoothing_with_HGNN(i_embs) g_embs = torch.cat([u_embs, i_embs], dim=0) sum_embs = F.leaky_relu( self.W_gc[_idx](g_embs) + g_embs, negative_slope=0.2 ) # ========================================================================================== bi_embs = all_embs * g_embs bi_embs = F.leaky_relu(self.W_bi[_idx](bi_embs), negative_slope=0.2) all_embs = sum_embs + bi_embs all_embs = F.dropout(all_embs, p=self.drop_rate, training=self.training) all_embs = F.normalize(all_embs, p=2, dim=1) embs_list.append(all_embs) embs = torch.stack(embs_list, dim=1) embs = torch.mean(embs, dim=1) u_embs, i_embs = torch.split(embs, [self.num_users, self.num_items], dim=0) return u_embs, i_embs