# Source code for dbcollection.core.api.load

"""
Load API class.
"""


from __future__ import print_function

from dbcollection.core.manager import CacheManager
from dbcollection.core.loader import DataLoader

from .download import download
from .process import process

from .metadata import MetadataConstructor


def load(name, task='default', data_dir='', verbose=True):
    """Returns a metadata loader of a dataset.

    Builds a ``LoadAPI`` object for the selected dataset/task, runs it and
    returns the loader it produces.

    Parameters
    ----------
    name : str
        Name of the dataset.
    task : str, optional
        Name of the task to load.
    data_dir : str, optional
        Directory path to store the downloaded data.
    verbose : bool, optional
        Displays text information (if true).

    Returns
    -------
    DataLoader
        Data loader class.

    Raises
    ------
    AssertionError
        If ``name`` is empty (or otherwise falsy).

    Notes
    -----
    Errors raised while constructing or running ``LoadAPI`` are not caught
    here and propagate to the caller.

    Examples
    --------
    Load the MNIST dataset.

    >>> import dbcollection as dbc
    >>> mnist = dbc.load('mnist')
    >>> print('Dataset name: ', mnist.db_name)
    Dataset name:  mnist

    """
    # Raise explicitly instead of using an ``assert`` statement: asserts are
    # removed when Python runs with -O, which would silently skip this check.
    # AssertionError is kept so existing callers see the same exception type.
    if not name:
        raise AssertionError('Must input a valid dataset name: {}'.format(name))

    loader = LoadAPI(name=name,
                     task=task,
                     data_dir=data_dir,
                     verbose=verbose)
    return loader.run()
class LoadAPI(object):
    """Dataset load API class.

    This class contains methods to correctly load a dataset's
    metadata as a data loader object.

    Parameters
    ----------
    name : str
        Name of the dataset.
    task : str
        Name of the task to load.
    data_dir : str
        Directory path to store the downloaded data.
    verbose : bool
        Displays text information (if true).

    Attributes
    ----------
    name : str
        Name of the dataset.
    task : str
        Name of the task to load, as returned by the dataset metadata
        object's ``parse_task_name``.
    data_dir : str
        Directory path to store the downloaded data.
    verbose : bool
        Displays text information (if true).
    cache_manager : CacheManager
        Cache manager object.

    Raises
    ------
    AssertionError
        If any constructor argument has the wrong type.

    """

    def __init__(self, name, task, data_dir, verbose):
        """Initialize class."""
        # Explicit raises instead of ``assert`` statements: asserts are
        # stripped under ``python -O``, which would disable the validation.
        # AssertionError and the messages are kept for backward compatibility.
        if not isinstance(name, str):
            raise AssertionError('Must input a valid dataset name.')
        if not isinstance(task, str):
            raise AssertionError('Must input a valid task name.')
        if not isinstance(data_dir, str):
            raise AssertionError('Must input a valid directory.')
        if not isinstance(verbose, bool):
            raise AssertionError("Must input a valid boolean for verbose.")

        self.name = name
        self.data_dir = data_dir
        self.verbose = verbose
        self.cache_manager = self.get_cache_manager()
        # parse_task_name() reads self.name, so it must run after the
        # assignments above.
        self.task = self.parse_task_name(task)

    def get_cache_manager(self):
        """Return a new CacheManager object."""
        return CacheManager()

    def parse_task_name(self, task):
        """Validate the task name."""
        db_metadata = self.get_dataset_metadata_obj(self.name)
        return db_metadata.parse_task_name(task)

    def get_dataset_metadata_obj(self, name):
        """Return a MetadataConstructor object for the dataset ``name``."""
        return MetadataConstructor(name)

    def run(self):
        """Main method.

        Downloads the dataset if it is not in the cache, processes the task's
        metadata if it is not in the cache, and returns the data loader.

        Returns
        -------
        DataLoader
            Data loader for the selected dataset and task.

        """
        if not self.dataset_data_exists_in_cache():
            if self.verbose:
                print('==> Dataset \'{}\' not found in cache.'.format(self.name))
                print('Proceeding to download the data files...')
            self.download_dataset_data()

        if not self.dataset_task_metadata_exists_in_cache():
            if self.verbose:
                print('==> Processed metadata not found for dataset \'{}\', task \'{}\'.'
                      .format(self.name, self.task))
                print('Proceeding to process the metadata for this task...')
            self.process_dataset_task_metadata()

        if self.verbose:
            print('==> Load the dataset\'s metadata.')
        dataset_loader = self.get_data_loader()

        if self.verbose:
            print('==> Dataset loading complete.')

        return dataset_loader

    def dataset_data_exists_in_cache(self):
        """Return the cache manager's ``dataset.exists`` result for this dataset."""
        return self.cache_manager.dataset.exists(self.name)

    def download_dataset_data(self):
        """Download the dataset and then reload the cache."""
        self.download_dataset()
        self.reload_cache()

    def download_dataset(self):
        """Download the dataset to disk."""
        download(name=self.name,
                 data_dir=self.data_dir,
                 extract_data=True,
                 verbose=self.verbose)

    def reload_cache(self):
        """Ask the cache manager to reload its cache."""
        self.cache_manager.manager.reload_cache()

    def dataset_task_metadata_exists_in_cache(self):
        """Return the cache manager's ``task.exists`` result for this dataset/task."""
        return self.cache_manager.task.exists(task=self.task, name=self.name)

    def process_dataset_task_metadata(self):
        """Process the task's metadata and then reload the cache."""
        self.process_dataset()
        self.reload_cache()

    def process_dataset(self):
        """Process the dataset's metadata."""
        process(name=self.name,
                task=self.task,
                verbose=self.verbose)

    def get_data_loader(self):
        """Return a DataLoader object."""
        data_dir_path = self.get_data_dir_path_from_cache()
        hdf5_filepath = self.get_hdf5_file_path_from_cache()
        return self.get_loader_obj(data_dir_path, hdf5_filepath)

    def get_data_dir_path_from_cache(self):
        """Return the ``"data_dir"`` entry of the dataset's cached metadata."""
        dataset_metadata = self.get_dataset_metadata(self.name)
        return dataset_metadata["data_dir"]

    def get_dataset_metadata(self, name):
        """Return the cache manager's metadata for the dataset ``name``."""
        return self.cache_manager.dataset.get(name)

    def get_hdf5_file_path_from_cache(self):
        """Return the ``"filename"`` entry of the task's cached metadata."""
        task_metadata = self.get_task_metadata(self.name, self.task)
        return task_metadata["filename"]

    def get_task_metadata(self, name, task):
        """Return the cache manager's metadata for the given dataset and task."""
        return self.cache_manager.task.get(name, task)

    def get_loader_obj(self, data_dir, hdf5_filepath):
        """Build a DataLoader for this dataset/task from the given paths."""
        return DataLoader(name=self.name,
                          task=self.task,
                          data_dir=data_dir,
                          hdf5_filepath=hdf5_filepath)