Source code for easygraph.datasets.amazon_photo

import os

import easygraph as eg
import numpy as np
import scipy.sparse as sp

from easygraph.classes.graph import Graph

from .graph_dataset_base import EasyGraphBuiltinDataset
from .utils import data_type_dict
from .utils import download
from .utils import extract_archive
from .utils import tensor


class AmazonPhotoDataset(EasyGraphBuiltinDataset):
    r"""Amazon Electronics Photo co-purchase graph dataset.

    Nodes represent products, and edges link products frequently
    co-purchased. Node features are bag-of-words of product reviews.
    The task is to classify the product category.

    Statistics:

    - Nodes: 7,650
    - Edges: 119,081
    - Number of Classes: 8
    - Features: 745

    Parameters
    ----------
    raw_dir : str, optional
        Raw file directory to download/contains the input data directory.
        Default: None
    force_reload : bool, optional
        Whether to reload the dataset. Default: False
    verbose : bool, optional
        Whether to print out progress information. Default: True
    transform : callable, optional
        A transform that takes in a :class:`~easygraph.Graph` object and
        returns a transformed version. The :class:`~easygraph.Graph` object
        will be transformed before every access.

    Examples
    --------
    >>> from easygraph.datasets import AmazonPhotoDataset
    >>> dataset = AmazonPhotoDataset()
    >>> g = dataset[0]
    >>> print(g.number_of_nodes())
    >>> print(g.number_of_edges())
    >>> print(g.nodes[0]['feat'].shape)
    >>> print(g.nodes[0]['label'])
    >>> print(dataset.num_classes)
    """

    def __init__(self, raw_dir=None, force_reload=False, verbose=True, transform=None):
        # Give the result attributes a defined value *before* delegating to
        # the base constructor. If the base constructor runs ``process()``
        # (presumably it does -- verify against EasyGraphBuiltinDataset) the
        # values are simply overwritten; if it does not, ``__getitem__`` can
        # report a clear ValueError instead of an AttributeError.
        self._g = None
        self._num_classes = None

        name = "amazon_photo"
        url = "https://data.dgl.ai/dataset/amazon_co_buy_photo.zip"
        super().__init__(
            name=name,
            url=url,
            raw_dir=raw_dir,
            force_reload=force_reload,
            verbose=verbose,
            transform=transform,
        )

    def process(self):
        """Build the graph from the raw ``amazon_co_buy_photo.npz`` archive.

        Reads the CSR-encoded adjacency and attribute matrices plus the label
        vector from ``self.raw_path``, creates an :class:`easygraph.Graph`
        with one edge per non-zero adjacency entry, and attaches ``feat`` and
        ``label`` attributes to every node. The graph is stored in
        ``self._g`` and the number of distinct labels in
        ``self._num_classes``.
        """
        path = os.path.join(self.raw_path, "amazon_co_buy_photo.npz")

        # ``np.load`` on an .npz returns an NpzFile that keeps the archive
        # open; the context manager guarantees the handle is released. Each
        # indexed member is read fully into memory, so the arrays remain
        # valid after the file is closed.
        with np.load(path) as data:
            adj = sp.csr_matrix(
                (data["adj_data"], data["adj_indices"], data["adj_indptr"]),
                shape=data["adj_shape"],
            )
            # ``toarray`` yields a plain 2-D ndarray (``todense`` would give
            # the discouraged ``np.matrix``), so each row is already 1-D.
            features = sp.csr_matrix(
                (data["attr_data"], data["attr_indices"], data["attr_indptr"]),
                shape=data["attr_shape"],
            ).toarray()
            labels = data["labels"]

        g = eg.Graph()
        g.add_edges_from(list(zip(*adj.nonzero())))
        # Added after the edges so that isolated nodes are created too and
        # every node receives its feature vector and label.
        for i, feat in enumerate(features):
            g.add_node(i, feat=feat, label=int(labels[i]))

        self._g = g
        self._num_classes = len(np.unique(labels))

        if self.verbose:
            print("Finished loading AmazonPhoto dataset.")
            print(f" NumNodes: {g.number_of_nodes()}")
            print(f" NumEdges: {g.number_of_edges()}")
            print(f" NumFeats: {features.shape[1]}")
            print(f" NumClasses: {self._num_classes}")

    def __getitem__(self, idx):
        """Return the single graph of this dataset.

        Parameters
        ----------
        idx : int
            Index of the graph; only ``0`` is valid.

        Returns
        -------
        The stored graph, passed through ``self._transform`` when a
        transform was supplied.

        Raises
        ------
        IndexError
            If ``idx`` is not ``0``.
        ValueError
            If the graph has not been built yet.
        """
        # An explicit IndexError (rather than ``assert``) survives
        # ``python -O`` and lets the legacy sequence-iteration protocol
        # (``for g in dataset``) terminate normally.
        if idx != 0:
            raise IndexError("AmazonPhotoDataset only contains one graph")
        if self._g is None:
            raise ValueError("Graph has not been loaded or processed correctly.")
        return self._g if self._transform is None else self._transform(self._g)

    def __len__(self):
        """Return the number of graphs in the dataset (always 1)."""
        return 1

    @property
    def num_classes(self):
        """Number of distinct node labels, as computed by :meth:`process`."""
        return self._num_classes