# Source code for easygraph.utils.sparse
__all__ = ["sparse_dropout"]
# torch is imported at module level only for static type checking; sparse_dropout imports it at call time
from typing import TYPE_CHECKING
if TYPE_CHECKING:
import torch
def sparse_dropout(
    sp_mat: "torch.Tensor", p: float, fill_value: float = 0.0
) -> "torch.Tensor":
    r"""Dropout function for sparse matrix.

    Returns a sparse matrix with the same shape as the input sparse matrix,
    but with some of its stored elements dropped out. Each stored element is
    dropped independently with probability ``p`` and replaced by
    ``fill_value``; kept elements are left unchanged (they are not rescaled).
    The sparsity pattern (indices) is preserved.

    Args:
        ``sp_mat`` (``torch.Tensor``): The sparse matrix with format ``torch.sparse_coo_tensor``.
        ``p`` (``float``): Probability of an element to be dropped. Must lie in ``[0, 1]``.
        ``fill_value`` (``float``): The fill value for dropped elements. Defaults to ``0.0``.

    Returns:
        ``torch.Tensor``: A sparse COO tensor with the same indices, size,
        device and dtype as the coalesced input. When ``p == 0`` the coalesced
        input itself is returned.

    Raises:
        ``AssertionError``: If ``p`` is outside ``[0, 1]``.
    """
    # The module only imports torch under TYPE_CHECKING, so bring it into
    # scope here at call time.
    import torch

    assert 0 <= p <= 1, "dropout probability p must be in the range [0, 1]"

    sp_mat = sp_mat.coalesce()
    if p == 0:
        return sp_mat

    device = sp_mat.device
    values = sp_mat._values()

    # One Bernoulli draw per stored element: 1 -> keep, 0 -> drop.
    keep_prob = 1 - torch.ones(sp_mat._nnz(), device=device) * p
    keep_mask = torch.bernoulli(keep_prob).bool()

    # Select rather than multiply-and-add: arithmetic masking yields NaN for
    # ``0 * inf`` / ``0 * nan``, which corrupts kept entries when
    # ``fill_value`` is non-finite and dropped entries when the stored value
    # is non-finite.
    work_dtype = torch.promote_types(values.dtype, keep_prob.dtype)
    values = values.to(work_dtype)
    new_values = torch.where(
        keep_mask, values, torch.full_like(values, fill_value)
    )

    return torch.sparse_coo_tensor(
        sp_mat._indices(),
        new_values,
        size=sp_mat.size(),
        device=device,
        dtype=sp_mat.dtype,
    )