easygraph.nn.convs.pma module
- class easygraph.nn.convs.pma.PMA(*args: Any, **kwargs: Any)
Bases: MessagePassing
PMA part: In the original PMA, we need to compute the inner product of the seed and the neighbor nodes, i.e. e_ij = a(Wh_i, Wh_j), where a should be the inner product, h_i is the seed, and h_j are the neighbor nodes. In GAT, a(x, y) = a^T[x||y]. We use the same logic as GAT.
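The two scoring functions mentioned above can be contrasted with a small, library-free sketch. It is only an illustration, not the code used by this class: the helper names `inner_product_score` and `concat_score` are hypothetical, the vectors are plain Python lists, and the features are assumed to be already projected by W.

```python
def inner_product_score(x, y):
    """Inner-product form: a(x, y) = <x, y>."""
    return sum(xi * yi for xi, yi in zip(x, y))


def concat_score(a, x, y):
    """GAT-style form: a(x, y) = a^T [x || y].

    `a` is a weight vector whose length equals len(x) + len(y).
    """
    concatenated = list(x) + list(y)  # [x || y]
    if len(a) != len(concatenated):
        raise ValueError("a must have length len(x) + len(y)")
    return sum(w * v for w, v in zip(a, concatenated))


# Hypothetical, already-projected features: one seed and two neighbors.
seed = [1.0, 0.0]
neighbors = [[0.5, 0.5], [0.0, 1.0]]
a = [0.1, 0.2, 0.3, 0.4]  # illustrative attention vector

inner_scores = [inner_product_score(seed, h_j) for h_j in neighbors]
concat_scores = [concat_score(a, seed, h_j) for h_j in neighbors]
```

In the concatenated form, the score is a learned linear function of both vectors rather than a fixed similarity, which is why the vector `a` must cover the seed and neighbor features together.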
- aggregate(inputs, index, dim_size=None, aggr='add')
Aggregates messages from neighbors as \(\square_{j \in \mathcal{N}(i)}\).
Takes in the output of message computation as first argument and any argument which was initially passed to propagate(). By default, this function will delegate its call to scatter functions that support “add”, “mean” and “max” operations as specified in __init__() by the aggr argument.
- forward(x, edge_index: torch_geometric.typing.Adj, size: torch_geometric.typing.Size | None = None, return_attention_weights=None)
- Parameters:
return_attention_weights (bool, optional) – If set to True, will additionally return the tuple (edge_index, attention_weights), holding the computed attention weights for each edge. (default: None)
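The scatter-style aggregation described for aggregate() can be sketched in plain Python to show what “add”, “mean” and “max” do to per-edge messages. This is a minimal illustration, not the library's implementation: `scatter_aggregate` is a hypothetical helper, messages are scalars for brevity, and filling nodes that receive no message with 0.0 is an assumption made here.

```python
def scatter_aggregate(inputs, index, dim_size=None, aggr="add"):
    """Reduce per-edge messages into per-node values.

    inputs:   one message per edge (scalars in this sketch)
    index:    target node of each message, same length as inputs
    dim_size: number of target nodes; inferred from index when None
    aggr:     "add", "mean" or "max"
    """
    if aggr not in ("add", "mean", "max"):
        raise ValueError(f"unsupported aggr: {aggr!r}")
    if len(inputs) != len(index):
        raise ValueError("inputs and index must have the same length")
    if dim_size is None:
        dim_size = max(index) + 1 if index else 0

    # Group messages by the node they are sent to.
    buckets = [[] for _ in range(dim_size)]
    for value, target in zip(inputs, index):
        buckets[target].append(value)

    out = []
    for bucket in buckets:
        if not bucket:
            out.append(0.0)  # assumption: nodes with no messages get 0.0
        elif aggr == "add":
            out.append(sum(bucket))
        elif aggr == "mean":
            out.append(sum(bucket) / len(bucket))
        else:  # "max"
            out.append(max(bucket))
    return out


# Three messages: two go to node 0, one goes to node 2; node 1 gets none.
messages = [1.0, 2.0, 4.0]
targets = [0, 0, 2]

added = scatter_aggregate(messages, targets, dim_size=3, aggr="add")
averaged = scatter_aggregate(messages, targets, dim_size=3, aggr="mean")
maxed = scatter_aggregate(messages, targets, dim_size=3, aggr="max")
```

Passing dim_size explicitly matters when trailing nodes receive no messages, since the size inferred from index alone would leave them out of the result.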