Source code for easygraph.datasets.flickr
import json
import os
import easygraph as eg
import numpy as np
import scipy.sparse as sp
from easygraph.classes.graph import Graph
from .graph_dataset_base import EasyGraphBuiltinDataset
from .utils import data_type_dict
from .utils import tensor
[docs]
class FlickrDataset(EasyGraphBuiltinDataset):
r"""Flickr dataset for node classification.
Nodes are images and edges represent social tags co-occurrence.
Node features are precomputed image embeddings. Labels indicate image categories.
Statistics:
- Nodes: 89,250
- Edges: 899,756
- Classes: 7
- Feature dim: 500
Source: GraphSAINT (https://arxiv.org/abs/1907.04931)
Parameters
----------
raw_dir : str, optional
Custom directory to download the dataset. Default: None (uses standard cache dir).
force_reload : bool, optional
Whether to re-download and reprocess. Default: False.
verbose : bool, optional
Whether to print loading progress. Default: False.
transform : callable, optional
A transform applied to the graph on access.
reorder : bool, optional
Whether to apply graph reordering for locality (requires torch). Default: False.
Examples
--------
>>> from easygraph.datasets import FlickrDataset
>>> ds = FlickrDataset(verbose=True)
>>> g = ds[0]
>>> print(g.number_of_nodes(), g.number_of_edges(), ds.num_classes)
>>> print(g.nodes[0]['feat'].shape, g.nodes[0]['label'])
"""
def __init__(
self,
raw_dir=None,
force_reload=False,
verbose=False,
transform=None,
reorder=False,
):
name = "flickr"
url = self._get_dgl_url("dataset/flickr.zip")
self._reorder = reorder
super(FlickrDataset, self).__init__(
name=name,
url=url,
raw_dir=raw_dir,
force_reload=force_reload,
verbose=verbose,
transform=transform,
)
[docs]
def process(self):
# Load adjacency
coo = sp.load_npz(os.path.join(self.raw_path, "adj_full.npz"))
g = eg.Graph()
g.add_edges_from(list(zip(*coo.nonzero())))
# Load features
feats = np.load(os.path.join(self.raw_path, "feats.npy"))
# Load labels
with open(os.path.join(self.raw_path, "class_map.json")) as f:
class_map = json.load(f)
labels = np.array([class_map[str(i)] for i in range(feats.shape[0])])
# Load train/val/test splits
with open(os.path.join(self.raw_path, "role.json")) as f:
role = json.load(f)
train_mask = np.zeros(feats.shape[0], dtype=bool)
train_mask[role["tr"]] = True
val_mask = np.zeros(feats.shape[0], dtype=bool)
val_mask[role["va"]] = True
test_mask = np.zeros(feats.shape[0], dtype=bool)
test_mask[role["te"]] = True
# Attach node data
for i in range(feats.shape[0]):
g.add_node(i, feat=feats[i].astype(np.float32), label=int(labels[i]))
g.graph["train_mask"] = train_mask
g.graph["val_mask"] = val_mask
g.graph["test_mask"] = test_mask
self._g = g
self._num_classes = int(labels.max() + 1)
if self.verbose:
print("Loaded Flickr dataset")
print(
f" Nodes: {g.number_of_nodes()}, Edges: {g.number_of_edges()}, Features: {feats.shape[1]}, Classes: {self._num_classes}"
)
def __getitem__(self, idx):
assert idx == 0, "FlickrDataset contains only one graph"
g = self._g
# transfer mask info
g.graph["train_mask"] = g.graph.pop("train_mask")
g.graph["val_mask"] = g.graph.pop("val_mask")
g.graph["test_mask"] = g.graph.pop("test_mask")
return self._transform(g) if self._transform else g
def __len__(self):
return 1
@property
def num_classes(self):
return self._num_classes
@staticmethod
def _get_dgl_url(path):
from .utils import _get_dgl_url
return _get_dgl_url(path)