easygraph.nn.convs package

Subpackages

Submodules

easygraph.nn.convs.common module

class easygraph.nn.convs.common.MLP(in_channels, hidden_channels, out_channels, num_layers, dropout=0.5, normalization='bn', InputNorm=False)

Bases: Module

A multi-layer perceptron. Adapted from https://github.com/CUAI/CorrectAndSmooth/blob/master/gen_models.py

forward(x)

Defines the computation performed at every call (overrides torch.nn.Module.forward).

Note

Although the forward pass is defined in this method, you should call the Module instance (for example, mlp(x)) rather than forward() directly: calling the instance runs any registered hooks, while calling forward() silently skips them.
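The note above is standard PyTorch behavior and can be checked directly. This sketch uses a plain torch.nn.Linear as a stand-in for MLP (an assumption made so the snippet runs without EasyGraph) and counts how often a forward hook fires when the module is called versus when forward() is called directly.

```python
import torch
from torch import nn

calls = []  # records each time the forward hook fires

layer = nn.Linear(4, 2)  # stand-in for any Module, e.g. MLP
# the hook receives (module, inputs, output) after forward() runs;
# list.append returns None, so the output is left unchanged
layer.register_forward_hook(lambda module, inputs, output: calls.append(output.shape))

x = torch.randn(3, 4)
y1 = layer(x)          # instance call: runs forward() and then the hook
y2 = layer.forward(x)  # direct call: runs forward() only, hook is skipped

print(len(calls))  # 1
```

Both calls compute the same result; only the instance call triggers the hook.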

reset_parameters()

Reinitializes the learnable parameters of the model.

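This page does not document how MLP arranges its layers. The following is a hypothetical re-implementation, not the EasyGraph source: it assumes a common stack of Linear, ReLU, normalization and dropout for each hidden layer, followed by a final Linear with no activation. The class name SketchMLP and the accepted normalization values "bn", "ln" and "None" are assumptions.

```python
import torch
from torch import nn
import torch.nn.functional as F


class SketchMLP(nn.Module):
    """Hypothetical MLP mirroring the documented constructor arguments."""

    def __init__(self, in_channels, hidden_channels, out_channels, num_layers,
                 dropout=0.5, normalization="bn", InputNorm=False):
        super().__init__()
        assert num_layers >= 1
        # assumed options: "bn" -> BatchNorm1d, "ln" -> LayerNorm, anything else -> no norm
        norms = {"bn": nn.BatchNorm1d, "ln": nn.LayerNorm}
        make_norm = norms.get(normalization, lambda dim: nn.Identity())
        dims = [in_channels] + [hidden_channels] * (num_layers - 1) + [out_channels]
        self.lins = nn.ModuleList([nn.Linear(a, b) for a, b in zip(dims[:-1], dims[1:])])
        # norms[0] acts on the raw input (only if InputNorm); the rest follow each hidden layer
        self.norms = nn.ModuleList(
            [make_norm(in_channels) if InputNorm else nn.Identity()]
            + [make_norm(hidden_channels) for _ in range(num_layers - 1)]
        )
        self.dropout = dropout

    def reset_parameters(self):
        # Identity has no parameters, so only reset modules that support it
        for module in list(self.lins) + list(self.norms):
            if hasattr(module, "reset_parameters"):
                module.reset_parameters()

    def forward(self, x):
        x = self.norms[0](x)
        for lin, norm in zip(self.lins[:-1], self.norms[1:]):
            x = F.relu(lin(x))
            x = norm(x)
            x = F.dropout(x, p=self.dropout, training=self.training)
        return self.lins[-1](x)  # output layer: no activation


mlp = SketchMLP(16, 32, 4, num_layers=3)
mlp.eval()  # disables dropout; BatchNorm uses its running statistics
out = mlp(torch.randn(10, 16))
print(out.shape)  # torch.Size([10, 4])
```

With num_layers=1 the sketch reduces to a single Linear layer (plus optional input normalization).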
class easygraph.nn.convs.common.MultiHeadWrapper(num_heads: int, readout: str, layer: Module, **kwargs)

Bases: Module

A wrapper to apply multiple heads to a given layer.

Parameters:
  • num_heads (int) – The number of heads.

  • readout (str) – The readout method. Can be "mean", "max", "sum", or "concat".

  • layer (nn.Module) – The layer to which multiple heads are applied.

  • **kwargs – The keyword arguments for the layer.
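To make the readout options concrete, here is a minimal sketch assuming that layer is a Module class instantiated once per head with **kwargs, that every head receives the same keyword arguments, and that the outputs are merged by readout. SketchMultiHead is a hypothetical name, not the EasyGraph class, and torch.nn.Linear stands in for a graph convolution.

```python
import torch
from torch import nn


class SketchMultiHead(nn.Module):
    """Hypothetical wrapper: builds num_heads copies of `layer` and merges their outputs."""

    def __init__(self, num_heads: int, readout: str, layer: type, **kwargs):
        super().__init__()
        if readout not in ("mean", "max", "sum", "concat"):
            raise ValueError(f"unknown readout: {readout!r}")
        # assumption: `layer` is a Module class, constructed once per head with **kwargs
        self.heads = nn.ModuleList([layer(**kwargs) for _ in range(num_heads)])
        self.readout = readout

    def forward(self, **kwargs) -> torch.Tensor:
        outs = [head(**kwargs) for head in self.heads]
        if self.readout == "concat":
            return torch.cat(outs, dim=-1)  # feature dimensions add up across heads
        stacked = torch.stack(outs)  # shape: (num_heads, *output_shape)
        if self.readout == "mean":
            return stacked.mean(dim=0)
        if self.readout == "max":
            return stacked.max(dim=0).values
        return stacked.sum(dim=0)


x = torch.randn(5, 8)
cat = SketchMultiHead(3, "concat", nn.Linear, in_features=8, out_features=4)
avg = SketchMultiHead(3, "mean", nn.Linear, in_features=8, out_features=4)
cat_out = cat(input=x)  # nn.Linear.forward names its argument `input`
avg_out = avg(input=x)
print(cat_out.shape)  # torch.Size([5, 12])
print(avg_out.shape)  # torch.Size([5, 4])
```

Under these assumptions, "concat" multiplies the output width by num_heads, while "mean", "max" and "sum" keep the width of a single head.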

forward(**kwargs) → Tensor

Runs each head on the given keyword arguments and combines the head outputs using the readout method.

Note

All arguments must be passed as keyword arguments; they are passed through to the wrapped layer. For example, if the layer is GATConv, pass X=X and g=g rather than positional arguments.
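The reason positional arguments fail is plain Python: a method declared as forward(self, **kwargs) has no slots for extra positional arguments. The sketch below shows this with hypothetical stand-ins (ToyHead and ToyWrapper are illustrative names, not EasyGraph classes).

```python
class ToyHead:
    """Stand-in for a conv layer whose forward takes X and g."""

    def __call__(self, X, g):
        return X + g


class ToyWrapper:
    """Mimics a forward(**kwargs) signature that passes arguments to each head."""

    def __init__(self, heads):
        self.heads = heads

    def forward(self, **kwargs):
        return [head(**kwargs) for head in self.heads]


wrapper = ToyWrapper([ToyHead(), ToyHead()])
outputs = wrapper.forward(X=1, g=2)  # keyword arguments reach every head
print(outputs)  # [3, 3]

try:
    wrapper.forward(1, 2)  # positional X and g are rejected
    rejected = False
except TypeError:
    rejected = True
print(rejected)  # True
```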

easygraph.nn.convs.pma module

Module contents