easygraph.nn.convs.pma module

class easygraph.nn.convs.pma.PMA(*args: Any, **kwargs: Any)

Bases: MessagePassing

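PMA (Pooling by Multihead Attention) layer. In the original PMA, the attention score between the seed and each neighbor node is an inner product, i.e. \(e_{ij} = a(W h_i, W h_j)\), where \(a\) is the inner product, \(h_i\) is the seed and the \(h_j\) are the neighbor nodes. In GAT, the scoring function is instead \(a(x, y) = \mathbf{a}^\top [x \,\|\, y]\), where \(\|\) denotes concatenation; this module uses the same GAT-style logic.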
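To make the difference between the two scoring functions concrete, the following sketch scores one seed/neighbor pair both ways using plain Python lists. It only illustrates the two formulas; `W`, `a`, the vectors and the helper names are hypothetical values chosen for illustration, not part of easygraph.

```python
# Sketch: inner-product (original PMA) vs. GAT-style scoring of one pair.
# All names and numbers here are hypothetical illustration values.

def matvec(W, h):
    """Compute W @ h for a list-of-rows matrix W and a vector h."""
    return [sum(w * x for w, x in zip(row, h)) for row in W]

def dot(x, y):
    """Inner product <x, y>."""
    return sum(p * q for p, q in zip(x, y))

def inner_product_score(W, h_i, h_j):
    """Original PMA scoring: e_ij = <W h_i, W h_j>."""
    return dot(matvec(W, h_i), matvec(W, h_j))

def gat_score(W, a, h_i, h_j):
    """GAT-style scoring: e_ij = a^T [W h_i || W h_j]."""
    concat = matvec(W, h_i) + matvec(W, h_j)  # list concatenation plays the role of '||'
    return dot(a, concat)

W = [[1.0, 0.0], [0.0, 2.0]]
a = [0.5, 0.5, 1.0, -1.0]   # length 2 * hidden size, one half per side of '||'
seed = [1.0, 1.0]           # h_i
neighbor = [2.0, 0.5]       # h_j

print(inner_product_score(W, seed, neighbor))  # 4.0
print(gat_score(W, a, seed, neighbor))         # 2.5
```

Splitting \(\mathbf{a} = [\mathbf{a}_1 \,\|\, \mathbf{a}_2]\) gives \(e_{ij} = \mathbf{a}_1^\top W h_i + \mathbf{a}_2^\top W h_j\), so for a fixed seed the first term is shared by every neighbor, whereas the inner product couples the seed and neighbor features multiplicatively.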

aggregate(inputs, index, dim_size=None, aggr='add')

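Aggregates messages from neighbors as \(\square_{j \in \mathcal{N}(i)}\), where \(\square\) denotes a differentiable, permutation-invariant function such as sum, mean or max.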

Takes the output of message() as its first argument, followed by any argument that was initially passed to propagate().

By default, this function delegates to scatter functions that support the “add”, “mean” and “max” reductions, as specified in __init__() by the aggr argument. This override also exposes an aggr parameter in its own signature, defaulting to 'add'.
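The scatter reduction can be sketched with scalar messages: each message is routed to the target node named by `index`, and each node reduces the messages it received. This is a plain-Python illustration of the semantics only (the real method works on tensors along `node_dim`); the function name `scatter_reduce` and the choice to return 0.0 for nodes with no incoming message are assumptions of this sketch.

```python
from collections import defaultdict

def scatter_reduce(inputs, index, dim_size=None, aggr="add"):
    """Reduce scalar messages `inputs` into per-node buckets given by `index`.

    Hypothetical sketch of scatter semantics; nodes with no messages get 0.0.
    """
    if aggr not in ("add", "mean", "max"):
        raise ValueError(f"unsupported aggr: {aggr!r}")
    if dim_size is None:
        # Infer the number of output nodes from the largest target index.
        dim_size = max(index) + 1 if index else 0
    buckets = defaultdict(list)
    for value, target in zip(inputs, index):
        buckets[target].append(value)
    out = []
    for node in range(dim_size):
        values = buckets.get(node, [])
        if not values:
            out.append(0.0)
        elif aggr == "add":
            out.append(sum(values))
        elif aggr == "mean":
            out.append(sum(values) / len(values))
        else:  # "max"
            out.append(max(values))
    return out

messages = [1.0, 2.0, 3.0, 4.0]
targets = [0, 0, 1, 2]  # node 3 receives nothing
print(scatter_reduce(messages, targets, dim_size=4, aggr="add"))   # [3.0, 3.0, 4.0, 0.0]
print(scatter_reduce(messages, targets, dim_size=4, aggr="mean"))  # [1.5, 3.0, 4.0, 0.0]
```

Passing `dim_size` explicitly matters when the highest-numbered nodes have no incoming edges; otherwise the inferred output would be too short.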

forward(x, edge_index: torch_geometric.typing.Adj, size: torch_geometric.typing.Size | None = None, return_attention_weights=None)
Parameters:

return_attention_weights (bool, optional) – If set to True, will additionally return the tuple (edge_index, attention_weights), holding the computed attention weights for each edge. (default: None)
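A sketch of what a per-edge attention weight usually means in GAT-style layers: raw edge scores are normalized with a softmax over all edges that share the same target node, so the weights into each node sum to 1. That PMA normalizes this way is an assumption here, suggested by the index, ptr and size_j arguments of message(); the function `edge_softmax` and the example `edge_index` are hypothetical.

```python
import math
from collections import defaultdict

def edge_softmax(scores, index):
    """Normalize raw edge scores so the weights of edges sharing a target sum to 1.

    scores[k] is the raw score of edge k and index[k] is that edge's target node.
    """
    # Subtract the per-target maximum before exponentiating, for numerical stability.
    max_per_target = {}
    for s, t in zip(scores, index):
        max_per_target[t] = max(s, max_per_target.get(t, s))
    exps = [math.exp(s - max_per_target[t]) for s, t in zip(scores, index)]
    denom = defaultdict(float)
    for e, t in zip(exps, index):
        denom[t] += e
    return [e / denom[t] for e, t in zip(exps, index)]

# Hypothetical edge_index as (source nodes, target nodes): three edges into node 0, one into node 1.
edge_index = ([0, 1, 2, 3], [0, 0, 0, 1])
raw_scores = [1.0, 1.0, 1.0, 5.0]
attention_weights = edge_softmax(raw_scores, edge_index[1])
print(list(zip(zip(*edge_index), attention_weights)))
```

Because the softmax is taken per target, an edge that is a node's only incoming edge always receives weight 1, regardless of its raw score.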

message(x_j, alpha_j, index, ptr, size_j)
reset_parameters()
easygraph.nn.convs.pma.glorot(tensor)
easygraph.nn.convs.pma.zeros(tensor)
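The helper names suggest Glorot (Xavier) uniform initialization and zero initialization, as typically used by reset_parameters(). The sketch below assumes they follow the common PyG-style convention of sampling from \(U(-b, b)\) with \(b = \sqrt{6 / (\text{rows} + \text{cols})}\) and leaving `None` untouched; it operates on lists of lists instead of tensors, and the names `glorot_sketch` and `zeros_sketch` are hypothetical.

```python
import math
import random

def glorot_sketch(matrix, rng=random):
    """Fill a rows x cols list-of-lists in place with U(-b, b), b = sqrt(6 / (rows + cols)).

    Assumed behavior: None is passed through unchanged.
    """
    if matrix is None:
        return None
    bound = math.sqrt(6.0 / (len(matrix) + len(matrix[0])))
    for row in matrix:
        for j in range(len(row)):
            row[j] = rng.uniform(-bound, bound)
    return matrix

def zeros_sketch(matrix):
    """Set every entry of a list-of-lists to 0.0 in place; None is passed through."""
    if matrix is None:
        return None
    for row in matrix:
        for j in range(len(row)):
            row[j] = 0.0
    return matrix

weight = glorot_sketch([[0.0] * 3 for _ in range(2)], rng=random.Random(0))
bias = zeros_sketch([[1.0, 2.0, 3.0]])
```

Scaling the bound by the fan-in plus fan-out keeps the variance of activations roughly stable across layers, which is why attention parameters are commonly initialized this way while biases start at zero.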