Source code for easygraph.datasets.hypergraph.hypergraph_dataset_base

from pathlib import Path
from typing import Any
from typing import Dict
from typing import List

from easygraph.datapipe import compose_pipes
from easygraph.datasets.hypergraph._global import DATASETS_ROOT
from easygraph.datasets.hypergraph._global import REMOTE_DATASETS_ROOT
from easygraph.datasets.utils import download_and_check


class BaseData:
    r"""The Base Class of all datasets.

    ::

        self._content = {
            'item': {
                'upon': [
                    {'filename': 'part1.pkl', 'md5': 'xxxxx',},
                    {'filename': 'part2.pkl', 'md5': 'xxxxx',},
                ],
                'loader': loader_function,
                'preprocess': [datapipe1, datapipe2],
            },
            ...
        }
    """

    def __init__(self, name: str, data_root=None):
        # configure the data local/remote root
        self.name = name
        if data_root is None:
            self.data_root = DATASETS_ROOT / name
        else:
            self.data_root = Path(data_root) / name
        self.remote_root = REMOTE_DATASETS_ROOT + name + "/"
        # init
        self._content = {}
        self._raw = {}

    def __repr__(self) -> str:
        return (
            f"This is {self.name} dataset:\n"
            + "\n".join(f" -> {k}" for k in self.content)
            + "\nPlease try `data['name']` to get the specified data."
        )

    @property
    def content(self):
        r"""Return the content of the dataset."""
        return list(self._content.keys())
    def needs_to_load(self, item_name: str) -> bool:
        r"""Return whether the ``item_name`` of the dataset needs to be loaded.

        Args:
            ``item_name`` (``str``): The name of the item in the dataset.
        """
        assert item_name in self.content, f"{item_name} is not provided in the Data"
        return (
            isinstance(self._content[item_name], dict)
            and "upon" in self._content[item_name]
            and "loader" in self._content[item_name]
        )
    def __getitem__(self, key: str) -> Any:
        if self.needs_to_load(key):
            cur_cfg = self._content[key]
            if cur_cfg.get("cache", None) is None:
                # get raw data
                item = self.raw(key)
                # preprocess and cache
                pipes = cur_cfg.get("preprocess", None)
                if pipes is not None:
                    cur_cfg["cache"] = compose_pipes(*pipes)(item)
                else:
                    cur_cfg["cache"] = item
            return cur_cfg["cache"]
        else:
            return self._content[key]
    def raw(self, key: str) -> Any:
        r"""Return the ``key`` item of the dataset in its un-preprocessed form."""
        if self.needs_to_load(key):
            cur_cfg = self._content[key]
            if self._raw.get(key, None) is None:
                upon = cur_cfg["upon"]
                if len(upon) == 0:
                    return None
                self.fetch_files(cur_cfg["upon"])
                file_path_list = [
                    self.data_root / u["filename"] for u in cur_cfg["upon"]
                ]
                if len(file_path_list) == 1:
                    self._raw[key] = cur_cfg["loader"](file_path_list[0])
                else:
                    # here, you should implement a multi-file loader
                    self._raw[key] = cur_cfg["loader"](file_path_list)
            return self._raw[key]
        else:
            return self._content[key]
    def fetch_files(self, files: List[Dict[str, str]]):
        r"""Download and verify the files if they do not exist locally.

        Args:
            ``files`` (``List[Dict[str, str]]``): The files to download. Each element in
                the list is a dict with at least two keys: ``filename`` and ``md5``. If
                the extra key ``bk_url`` is provided, it will be used to download the
                file from the backup URL.
        """
        for file in files:
            cur_filename = file["filename"]
            cur_url = file.get("bk_url", None)
            if cur_url is None:
                cur_url = self.remote_root + cur_filename
            download_and_check(cur_url, self.data_root / cur_filename, file["md5"])
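
Below is a minimal usage sketch showing how a concrete dataset might subclass ``BaseData`` and populate ``self._content`` in the structure described in the class docstring. The dataset name ``toy_hypergraph``, the file name, the md5 value, and the ``load_pickle`` helper are hypothetical placeholders for illustration only, not part of EasyGraph's API.

import pickle


def load_pickle(file_path):
    # hypothetical single-file loader: read a pickled object from disk
    with open(file_path, "rb") as f:
        return pickle.load(f)


class ToyHypergraphData(BaseData):
    # illustrative subclass; the name, file name, and md5 below are placeholders
    def __init__(self, data_root=None):
        super().__init__("toy_hypergraph", data_root)
        self._content = {
            "num_vertices": 100,  # plain values are returned as-is by __getitem__
            "edge_list": {
                "upon": [{"filename": "edge_list.pkl", "md5": "xxxxx"}],
                "loader": load_pickle,
                # optional: "preprocess": [datapipe1, datapipe2],
            },
        }


data = ToyHypergraphData()
print(data)                   # lists the available content keys
value = data["num_vertices"]  # -> 100, returned directly (no "upon"/"loader" keys)
# data["edge_list"]           # would download via fetch_files, load, then cache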