easygraph.nn.convs package

Subpackages

Submodules

easygraph.nn.convs.common module

class easygraph.nn.convs.common.MultiHeadWrapper(num_heads: int, readout: str, layer: Module, **kwargs)

Bases: Module

A wrapper to apply multiple heads to a given layer.

Parameters:
  • num_heads (int) – The number of heads.

  • readout (str) – The readout method used to combine the outputs of the heads. Can be "mean", "max", "sum", or "concat".

  • layer (nn.Module) – The layer to which multiple heads are applied.

  • **kwargs – The keyword arguments for the layer.

forward(**kwargs) → Tensor

The forward function.

Note

All inputs must be passed explicitly as keyword arguments, which are forwarded to the layer. For example, if the layer is GATConv, you must pass X=X and g=g.
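
Example

The following is a minimal sketch of the multi-head pattern described above. It is not EasyGraph's implementation. MultiHeadSketch is a hypothetical stand-in class, and it assumes that layer is given as a class that is instantiated once per head with the constructor keyword arguments. torch.nn.Linear is used as the wrapped layer only to keep the example self-contained.

```python
import torch
from torch import nn


class MultiHeadSketch(nn.Module):
    """Hypothetical stand-in that illustrates the multi-head pattern."""

    def __init__(self, num_heads: int, readout: str, layer, **kwargs):
        super().__init__()
        # Assumption: `layer` is a class; build one independent copy per head.
        self.heads = nn.ModuleList([layer(**kwargs) for _ in range(num_heads)])
        self.readout = readout

    def forward(self, **kwargs) -> torch.Tensor:
        # Every head receives the same keyword arguments.
        outs = [head(**kwargs) for head in self.heads]
        if self.readout == "concat":
            # Join the head outputs along the feature dimension.
            return torch.cat(outs, dim=-1)
        stacked = torch.stack(outs)  # shape: (num_heads, ...)
        if self.readout == "mean":
            return stacked.mean(dim=0)
        if self.readout == "max":
            return stacked.max(dim=0).values
        if self.readout == "sum":
            return stacked.sum(dim=0)
        raise ValueError(f"Unknown readout: {self.readout!r}")


x = torch.randn(5, 4)

# "concat" multiplies the feature dimension by the number of heads.
concat_heads = MultiHeadSketch(2, "concat", nn.Linear, in_features=4, out_features=3)
concat_out = concat_heads(input=x)  # inputs are passed by keyword

# "mean", "max" and "sum" keep the shape of a single head's output.
mean_heads = MultiHeadSketch(2, "mean", nn.Linear, in_features=4, out_features=3)
mean_out = mean_heads(input=x)
```

With "concat", the output feature dimension grows with num_heads, so any layer that follows must expect num_heads times the per-head output size. The other readout methods leave the output shape unchanged.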

training: bool

easygraph.nn.convs.pma module

Module contents