Source code for easygraph.datasets.reddit
import os
import easygraph as eg
import numpy as np
import scipy.sparse as sp
from easygraph.classes.graph import Graph
from .graph_dataset_base import EasyGraphBuiltinDataset
from .utils import data_type_dict
from .utils import download
from .utils import extract_archive
from .utils import tensor
[docs]
class RedditDataset(EasyGraphBuiltinDataset):
r"""Reddit posts graph (Sept 2014) for community (subreddit) classification.
Statistics:
- Nodes: ~232,965
- Edges: ~114 million (approx.)
- Features per node: 602
- Classes: number of subreddit communities
Data are split by post-day: first 20 days train, then validation (30%), test (rest).
Parameters
----------
self_loop : bool
Add self-loop edges if True.
raw_dir, force_reload, verbose, transform : same as EasyGraphBuiltinDataset
"""
def __init__(
self,
self_loop=False,
raw_dir=None,
force_reload=False,
verbose=True,
transform=None,
):
name = "reddit"
url = "https://data.dgl.ai/dataset/reddit.zip"
self.self_loop = self_loop
super().__init__(
name=name,
url=url,
raw_dir=raw_dir,
force_reload=force_reload,
verbose=verbose,
transform=transform,
)
[docs]
def process(self):
# Expect two files extracted: reddit_data.npz & reddit_graph.npz
data = np.load(os.path.join(self.raw_path, "reddit_data.npz"))
feat = data["feature"] # shape [N, 602]
labels = data["label"] # shape [N]
split = data["node_types"] # 1=train,2=val,3=test
# Load adjacency
adj = sp.load_npz(os.path.join(self.raw_path, "reddit_graph.npz"))
src, dst = adj.nonzero()
if self.self_loop:
self_loops = np.arange(adj.shape[0])
src = np.concatenate([src, self_loops])
dst = np.concatenate([dst, self_loops])
edges = list(zip(src, dst))
# Build graph
g = eg.Graph()
g.add_edges_from(edges)
# Assign node features, labels, and masks
for i in range(feat.shape[0]):
g.add_node(
i,
feat=feat[i],
label=int(labels[i]),
train_mask=(split[i] == 1),
val_mask=(split[i] == 2),
test_mask=(split[i] == 3),
)
self._g = g
self._num_classes = int(np.max(labels) + 1)
if self.verbose:
print("Loaded Reddit dataset:")
print(f" NumNodes: {g.number_of_nodes()}")
print(f" NumEdges: {g.number_of_edges()}")
print(f" NumFeats: {feat.shape[1]}")
print(f" NumClasses: {self._num_classes}")
def __getitem__(self, idx):
assert idx == 0, "RedditDataset only contains one graph"
return self._g if self.transform is None else self.transform(self._g)
def __len__(self):
return 1
@property
def num_classes(self):
return self._num_classes