"""
Dataset's metadata loader classes.
"""
import numbers

import h5py

from dbcollection.utils.string_ascii import convert_ascii_to_str


class FieldLoader(object):
    """Field metadata loader class.

    This class contains several methods to fetch data from a specific
    field of a set (group) in a hdf5 file. It contains useful information
    about the field and also several methods to fetch data.

    Parameters
    ----------
    hdf5_field : h5py._hl.dataset.Dataset
        hdf5 field object handler.
    obj_id : int, optional
        Position of the field in 'object_fields'.

    Attributes
    ----------
    data : h5py._hl.dataset.Dataset or np.ndarray
        hdf5 dataset handler, or the dataset's contents as a numpy array
        when ``to_memory`` is True.
    hdf5_handler : h5py._hl.dataset.Dataset
        hdf5 dataset handler (always refers to the on-disk dataset).
    set : str
        Name of the set.
    name : str
        Name of the field.
    type : numpy.dtype
        Type of the field's data.
    shape : tuple
        Shape of the field's data.
    fillvalue : int
        Value used to pad arrays when storing the data in the hdf5 file.
    obj_id : int or None
        Position of the field in the 'object_ids' list, or None when the
        field is not part of it.
    """

    def __init__(self, hdf5_field, obj_id=None):
        """Initialize class."""
        assert hdf5_field, 'Must input a valid hdf5 dataset.'
        self.data = hdf5_field  # replaced by a numpy array when to_memory=True
        self.hdf5_handler = hdf5_field  # always keeps the on-disk handler
        self._in_memory = False
        self.set = self._get_set_name()
        self.name = self._get_field_name()
        self.shape = hdf5_field.shape
        self.type = hdf5_field.dtype
        self.fillvalue = hdf5_field.fillvalue
        self.obj_id = obj_id

    def _get_set_name(self):
        """Return the second component of the dataset's hdf5 path (the set name)."""
        return self._get_hdf5_object_str()[1]

    def _get_field_name(self):
        """Return the last component of the dataset's hdf5 path (the field name)."""
        return self._get_hdf5_object_str()[-1]

    def _get_hdf5_object_str(self):
        """Split the dataset's absolute hdf5 path on '/'."""
        return self.hdf5_handler.name.split('/')

    def get(self, index=None, convert_to_str=False):
        """Retrieves data of the field from the dataset's hdf5 metadata file.

        This method retrieves the i'th data from the hdf5 file. Also, it is
        possible to retrieve multiple values by inserting a list/tuple of
        number values as indexes.

        Parameters
        ----------
        index : int/slice/list/tuple, optional
            Index number of the field. If it is a list, returns the data
            for all the value indexes of that list.
        convert_to_str : bool, optional
            Convert the output data into a string.
            Warning: output must be of type np.uint8

        Returns
        -------
        np.ndarray/list/str
            Numpy array containing the field's data.
            If convert_to_str is set to True, it returns a string
            or list of strings.

        Note
        ----
        When using lists/tuples of indexes, this method sorts the list
        and removes duplicate values. This is because the h5py
        api requires the indexing elements to be in increasing order when
        retrieving data.
        """
        if index is None:
            data = self._get_all_idx()
        else:
            data = self._get_range_idx(index)
        if convert_to_str:
            data = convert_ascii_to_str(data)
        return data

    def _get_all_idx(self):
        """Return the full data array."""
        if self._in_memory:
            return self.data
        # ``Dataset[()]`` reads the whole dataset; ``Dataset.value`` was
        # deprecated and removed in h5py 3.0.
        return self.data[()]

    def _get_range_idx(self, idx):
        """Return the rows of the data array selected by ``idx``.

        ``idx`` may be an integer (Python or numpy), a slice, or a sequence of
        integers. Sequences are sorted and de-duplicated (h5py fancy indexing
        requires increasing indices); a one-element sequence returns that
        single row and an empty sequence returns the whole array.
        """
        assert idx is not None
        # numbers.Integral also matches numpy integer scalars (e.g. np.int64),
        # which are not instances of the builtin ``int``.
        if isinstance(idx, (numbers.Integral, slice)):
            return self.data[idx]
        size = len(idx)
        if size > 1:
            return self.data[sorted(set(idx))]
        if size == 1:
            return self.data[idx[0]]
        return self._get_all_idx()

    def size(self):
        """Size of the field.

        Returns the number of the elements of the field.

        Returns
        -------
        tuple
            Returns the size (shape) of the field.
        """
        return self.shape

    def object_field_id(self):
        """Retrieves the index position of the field in the 'object_ids' list.

        This method returns the position of the field in the 'object_ids' object.
        If the field is not contained in this object, it returns a null value.

        Returns
        -------
        int
            Index of the field in the 'object_ids' list.
        """
        return self.obj_id

    def info(self, verbose=True):
        """Prints information about the field.

        Displays information like name, size and shape of the field.

        Parameters
        ----------
        verbose : bool, optional
            If true, display extra information about the field.
        """
        if not verbose:
            return
        # obj_id is always set by __init__, so test its value: only fields that
        # actually belong to 'object_ids' get the position suffix.
        if self.obj_id is not None:
            print('Field: {}, shape = {}, dtype = {}, (in \'object_ids\', position = {})'
                  .format(self.name, str(self.shape), str(self.type), self.obj_id))
        else:
            print('Field: {}, shape = {}, dtype = {}'
                  .format(self.name, str(self.shape), str(self.type)))

    def _set_to_memory(self, is_in_memory):
        """Stores the contents of the field in a numpy array if True.

        Parameters
        ----------
        is_in_memory : bool
            Move the data to memory (if True).
        """
        assert isinstance(is_in_memory, bool), 'Invalid input. Must insert a boolean type.'
        if is_in_memory:
            self.data = self.hdf5_handler[()]  # read the whole dataset
        else:
            self.data = self.hdf5_handler
        self._in_memory = is_in_memory

    def _get_to_memory(self):
        """Modifies how data is accessed and stored.

        Accessing data from a field can be done in two ways: memory or disk.
        To enable data allocation and access from memory requires the user to
        specify a boolean. If set to True, data is allocated to a numpy ndarray
        and all accesses are done in memory. Otherwise, data is kept in disk and
        accesses are done using the HDF5 object handler.
        """
        return self._in_memory

    to_memory = property(_get_to_memory, _set_to_memory)

    def __getitem__(self, index):
        """
        Parameters
        ----------
        index : int
            Index

        Returns
        -------
        np.ndarray
            Numpy data array.
        """
        return self.data[index]

    def __len__(self):
        """
        Returns
        -------
        int
            Number of samples
        """
        return self.shape[0]

    def __str__(self):
        if self._in_memory:
            return 'FieldLoader: <numpy.ndarray "{}": shape {}, type "{}">' \
                .format(self.name, self.data.shape, self.data.dtype)
        return 'FieldLoader: ' + str(self.data)

    def __repr__(self):
        return str(self)


class SetLoader(object):
    """Set metadata loader class.

    This class contains several methods to fetch data from a specific
    set (group) in a hdf5 file. It contains useful information about a
    specific group and also several methods to fetch data.

    Parameters
    ----------
    hdf5_group : h5py._hl.group.Group
        hdf5 group object handler.

    Attributes
    ----------
    hdf5_group : h5py._hl.group.Group
        hdf5 group object handler.
    set : str
        Name of the set.
    fields : dict
        Mapping of every field name of the set to its FieldLoader.
    object_fields : tuple
        Names of the fields referenced by the 'object_ids' list, in order.
    nelems : int
        Number of rows in 'object_ids'.
    """

    def __init__(self, hdf5_group):
        """Initialize class."""
        assert hdf5_group, 'Must input a valid hdf5 group'
        self.hdf5_group = hdf5_group
        self.set = self._get_set_name()
        self.object_fields = self._get_object_fields()
        self.nelems = self._get_num_elements()
        self._fields = self._get_field_names()
        self.fields = self._load_hdf5_fields()  # add all hdf5 datasets as data fields
        # Caches of per-field display info, filled lazily by info().
        self._fields_info = []
        self._lists_info = []

    def _get_set_name(self):
        """Return the last component of the group's hdf5 path."""
        return self.hdf5_group.name.split('/')[-1]

    def _get_object_fields(self):
        """Return the names stored in the 'object_fields' dataset as a tuple."""
        # ``Dataset[()]`` replaces ``Dataset.value``, removed in h5py 3.0.
        object_fields_data = self.hdf5_group['object_fields'][()]
        output = convert_ascii_to_str(object_fields_data)
        # A single name comes back as a plain str; wrap it so membership tests
        # and .index() work on names rather than on characters.
        if isinstance(output, str):
            return (output,)
        return tuple(output)

    def _get_field_names(self):
        """Return the names of all datasets in the group."""
        return tuple(self.hdf5_group.keys())

    def _get_num_elements(self):
        """Return the number of rows of the 'object_ids' dataset."""
        return len(self.hdf5_group['object_ids'])

    def _load_hdf5_fields(self):
        """Build a FieldLoader for every dataset of the group."""
        return {
            field: FieldLoader(self.hdf5_group[field], self._get_obj_id_field(field))
            for field in self._fields
        }

    def _get_obj_id_field(self, field):
        """Return the position of ``field`` in object_fields, or None."""
        if field in self.object_fields:
            return self.object_fields.index(field)
        return None

    def get(self, field, index=None, convert_to_str=False):
        """Retrieves data from the dataset's hdf5 metadata file.

        This method retrieves the i'th data from the hdf5 file with the
        same 'field' name. Also, it is possible to retrieve multiple values
        by inserting a list/tuple of number values as indexes.

        Parameters
        ----------
        field : str
            Field name.
        index : int/list/tuple, optional
            Index number of the field. If it is a list, returns the data
            for all the value indexes of that list.
        convert_to_str : bool, optional
            Convert the output data into a string.
            Warning: output must be of type np.uint8

        Returns
        -------
        np.ndarray/list/str
            Numpy array containing the field's data.
            If convert_to_str is set to True, it returns a string
            or list of strings.

        Raises
        ------
        KeyError
            If the field does not exist in the list.
        """
        assert field, 'Must input a valid field name.'
        # Guard only the lookup so errors raised while reading are not
        # misreported as a missing field.
        try:
            field_loader = self.fields[field]
        except KeyError:
            raise KeyError('\'{}\' does not exist in the \'{}\' set.'.format(field, self.set))
        return field_loader.get(index=index, convert_to_str=convert_to_str)

    def object(self, index=None, convert_to_value=False):
        """Retrieves a list of all fields' indexes/values of an object composition.

        Retrieves the data's ids or contents of all fields of an object.
        It basically works as calling the get() method for each individual field
        and then groups all values into a list w.r.t. the corresponding order of
        the fields.

        Parameters
        ----------
        index : int/list/tuple, optional
            Index number of the field. If it is a list, returns the data
            for all the value indexes of that list. If no index is used,
            it returns the entire data field array.
        convert_to_value : bool, optional
            If False, outputs a list of indexes. If True,
            it outputs a list of arrays/values instead of indexes.

        Returns
        -------
        list
            Returns a list of indexes or, if convert_to_value is True,
            a list of data arrays/values.
        """
        indexes = self._get_object_indexes(index)
        if convert_to_value:
            indexes = self._convert(indexes.tolist())
        return indexes

    def _get_object_indexes(self, index):
        """Return the selected rows of 'object_ids'."""
        return self.get('object_ids', index)

    def _convert(self, index):
        """Retrieve data from the dataset's hdf5 metadata file in the original format.

        This method fetches all indices of an object(s), and then it looks up the
        value of each field in 'object_ids' for a certain index(es), and then it
        groups the fetched data into a single list.

        Parameters
        ----------
        index : list
            List of indexes of data fields (one row), or a list of such lists.

        Returns
        -------
        list
            Value/list of a field from the metadata cache file.

        Raises
        ------
        TypeError
            If index is not a list of ints or a list of lists.
        """
        assert index, 'Must input a valid index.'
        if isinstance(index[0], int):
            return self._convert_to_value_single_object(index)
        if isinstance(index[0], list):
            return [self._convert_to_value_single_object(idx) for idx in index]
        raise TypeError("Invalid input index format.")

    def _convert_to_value_single_object(self, idx):
        """Fetch the value of every object field for one 'object_ids' row."""
        data = []
        for i, field in enumerate(self.object_fields):
            if idx[i] >= 0:
                data.append(self.get(field, idx[i]))
            else:
                data.append([])  # undefined (negative) index retrieves an empty list
        return data

    def size(self, field='object_ids'):
        """Size of a field.

        Returns the number of the elements of a field.

        Parameters
        ----------
        field : str, optional
            Name of the field in the metadata file.

        Returns
        -------
        tuple
            Returns the size of the field.

        Raises
        ------
        KeyError
            If field is invalid or does not exist in the fields dict.
        """
        try:
            field_loader = self.fields[field]
        except KeyError:
            raise KeyError('\'{}\' does not exist in the \'{}\' set.'.format(field, self.set))
        return field_loader.shape

    def list(self):
        """List of all field names.

        Returns
        -------
        tuple
            Names of all data fields of the set.
        """
        return self._fields

    def object_field_id(self, field):
        """Retrieves the index position of a field in the 'object_ids' list.

        This method returns the position of a field in the 'object_ids' object.
        If the field is not contained in this object, it returns a null value.

        Parameters
        ----------
        field : str
            Name of the field in the metadata file.

        Returns
        -------
        int
            Index of the field in the 'object_ids' list.

        Raises
        ------
        KeyError
            If field does not exist in the set.
        """
        assert field, 'Must input a valid field.'
        try:
            field_loader = self.fields[field]
        except KeyError:
            raise KeyError('\'{}\' is not contained in \'object_fields\'.'.format(field))
        return field_loader.object_field_id()

    def info(self):
        """Prints information about the data fields of the set.

        Displays information about all fields of the set, such as the field
        name, shape and dtype, plus the position in 'object_ids' for object
        fields. Fields whose name starts with 'list_' are printed separately
        as pre-ordered lists.
        """
        print('\n> Set: {}'.format(self.set))
        self._set_fields_lists_info()
        self._print_info_fields()
        self._print_info_lists()

    def _set_fields_lists_info(self):
        """Fill the info caches once; later calls reuse them."""
        # Check both caches: a set holding only 'list_' fields would otherwise
        # be re-appended on every info() call.
        if self._fields_info or self._lists_info:
            return
        for field in sorted(self.fields):
            if self._is_field_a_list(field):
                self._lists_info.append(self._get_list_info(field))
            else:
                self._fields_info.append(self._get_field_info(field))

    def _is_field_a_list(self, field):
        """Return True for fields named 'list_*'."""
        assert field
        return field.startswith('list_')

    def _get_list_info(self, field):
        """Return the display strings of a 'list_' field."""
        assert field
        return {
            "name": str(field),
            "shape": 'shape = {}'.format(str(self.fields[field].shape)),
            "type": 'dtype = {}'.format(str(self.fields[field].type))
        }

    def _get_field_info(self, field):
        """Return the display strings of a regular field."""
        assert field
        s_obj = ''
        if field in self.object_fields:
            obj_id = self.object_field_id(field)
            s_obj = "(in 'object_ids', position = {})".format(obj_id)
        return {
            "name": str(field),
            "shape": 'shape = {}'.format(str(self.fields[field].shape)),
            "type": 'dtype = {}'.format(str(self.fields[field].type)),
            "obj": s_obj
        }

    def _print_info_fields(self):
        """Print the regular fields as aligned columns."""
        if not self._fields_info:
            return  # nothing to print; also avoids max() on an empty list
        maxsize_name, maxsize_shape, maxsize_type = self._get_max_sizes_fields()
        for info in self._fields_info:
            s_name = '{:{}}'.format(' - {}, '.format(info["name"]), maxsize_name)
            s_shape = '{:{}}'.format('{}, '.format(info["shape"]), maxsize_shape)
            s_obj = info["obj"]
            if s_obj:
                s_type = '{:{}}'.format('{},'.format(info["type"]), maxsize_type)
            else:
                s_type = '{:{}}'.format('{}'.format(info["type"]), maxsize_type)
            print(s_name + s_shape + s_type + s_obj)

    def _get_max_sizes_fields(self):
        """Return the column widths used by _print_info_fields."""
        maxsize_name = max([len(d["name"]) for d in self._fields_info]) + 8
        maxsize_shape = max([len(d["shape"]) for d in self._fields_info]) + 3
        maxsize_type = max([len(d["type"]) for d in self._fields_info]) + 3
        return maxsize_name, maxsize_shape, maxsize_type

    def _print_info_lists(self):
        """Print the 'list_' fields as aligned columns, if any."""
        if not self._lists_info:
            return
        print('\n   (Pre-ordered lists)')
        maxsize_name, maxsize_shape = self._get_max_sizes_lists()
        for info in self._lists_info:
            s_name = '{:{}}'.format('   - {}, '.format(info["name"]), maxsize_name)
            s_shape = '{:{}}'.format('{}, '.format(info["shape"]), maxsize_shape)
            s_type = info["type"]
            print(s_name + s_shape + s_type)

    def _get_max_sizes_lists(self):
        """Return the column widths used by _print_info_lists."""
        maxsize_name = max([len(d["name"]) for d in self._lists_info]) + 8
        maxsize_shape = max([len(d["shape"]) for d in self._lists_info]) + 3
        return maxsize_name, maxsize_shape

    def __len__(self):
        """
        Returns
        -------
        int
            Number of elements
        """
        return self.nelems

    def __str__(self):
        return 'SetLoader: set<{}>, len<{}>'.format(self.set, self.nelems)

    def __repr__(self):
        return str(self)


class DataLoader(object):
    """Dataset metadata loader class.

    This class contains several methods to fetch data from a hdf5 file
    by using simple, easy to use functions for (meta)data handling.

    Parameters
    ----------
    name : str
        Name of the dataset.
    task : str
        Name of the task.
    data_dir : str
        Path of the dataset's data directory on disk.
    hdf5_filepath : str
        Path of the metadata cache file stored on disk.

    Attributes
    ----------
    db_name : str
        Name of the dataset.
    task : str
        Name of the task.
    data_dir : str
        Path of the dataset's data directory on disk.
    hdf5_filepath : str
        Path of the hdf5 metadata file stored on disk.
    hdf5_file : h5py._hl.files.File
        hdf5 file object handler.
    root_path : str
        Default data group of the hdf5 file.
    sets : dict
        Mapping of set split names (e.g. train, test, val, etc.) to SetLoaders.
    object_fields : dict
        Data field names for each set split.
    """

    def __init__(self, name, task, data_dir, hdf5_filepath):
        """Initialize class."""
        assert name, 'Must input a valid dataset name.'
        assert task, 'Must input a valid task name.'
        assert data_dir, 'Must input a valid path for the data directory.'
        assert hdf5_filepath, 'Must input a valid path for the cache file.'
        self.db_name = name
        self.task = task
        self.data_dir = data_dir
        self.hdf5_filepath = hdf5_filepath
        self.hdf5_file = self._load_hdf5_file()
        self.root_path = '/'
        self._sets = self._get_sets()
        self.object_fields = self._get_object_fields()
        self.sets = self._get_set_loaders()

    def _load_hdf5_file(self):
        """Open the metadata file read-only."""
        return h5py.File(self.hdf5_filepath, 'r', libver='latest')

    def _get_sets(self):
        """Return the sorted names of the root-level groups."""
        return tuple(sorted(self.hdf5_file['/'].keys()))

    def _get_object_fields(self):
        """Fetch, for each set, the field names that compose the object list."""
        object_fields = {}
        for set_name in self._sets:
            # ``Dataset[()]`` replaces ``Dataset.value``, removed in h5py 3.0.
            data = self.hdf5_file['/{}/object_fields'.format(set_name)][()]
            names = convert_ascii_to_str(data)
            # A single name comes back as a plain str; tuple() would split it
            # into characters.
            if isinstance(names, str):
                names = (names,)
            object_fields[set_name] = tuple(names)
        return object_fields

    def _get_set_loaders(self):
        """Return a dictionary mapping each set name to its SetLoader."""
        return {set_name: SetLoader(self.hdf5_file[set_name]) for set_name in self._sets}

    def _get_set_loader(self, set_name):
        """Return the SetLoader of ``set_name``; raise KeyError if the set is unknown."""
        try:
            return self.sets[set_name]
        except KeyError:
            self._raise_error_invalid_set_name(set_name)

    def get(self, set_name, field, index=None, convert_to_str=False):
        """Retrieves data from the dataset's hdf5 metadata file.

        This method retrieves the i'th data from the hdf5 file with the
        same 'field' name. Also, it is possible to retrieve multiple values
        by inserting a list/tuple of number values as indexes.

        Parameters
        ----------
        set_name : str
            Name of the set.
        field : str
            Name of the data field.
        index : int/list/tuple, optional
            Index number of the field. If it is a list, returns the data
            for all the value indexes of that list.
        convert_to_str : bool, optional
            Convert the output data into a string.
            Warning: output must be of type np.uint8

        Returns
        -------
        np.ndarray/list/str
            Numpy array containing the field's data.
            If convert_to_str is set to True, it returns a string
            or list of strings.

        Raises
        ------
        KeyError
            If set name is not valid or does not exist, or if the field does
            not exist in the set.
        """
        assert set_name, 'Must input a set name.'
        assert field, 'Must input a field name.'
        # Look the set up separately so a missing *field* (KeyError from the
        # SetLoader) is not misreported as a missing set.
        set_loader = self._get_set_loader(set_name)
        return set_loader.get(field, index, convert_to_str=convert_to_str)

    def _raise_error_invalid_set_name(self, set_name):
        raise KeyError("'{}' does not exist in the sets list: {}".format(set_name, self._sets))

    def object(self, set_name, index=None, convert_to_value=False):
        """Retrieves a list of all fields' indexes/values of an object composition.

        Retrieves the data's ids or contents of all fields of an object.
        It basically works as calling the get() method for each individual field
        and then groups all values into a list w.r.t. the corresponding order of
        the fields.

        Parameters
        ----------
        set_name : str
            Name of the set.
        index : int/list/tuple, optional
            Index number of the field. If it is a list, returns the data
            for all the value indexes of that list. If no index is used,
            it returns the entire data field array.
        convert_to_value : bool, optional
            If False, outputs a list of indexes. If True,
            it outputs a list of arrays/values instead of indexes.

        Returns
        -------
        list
            List of indexes of the data fields available in 'object_fields'.
            If convert_to_value is set to True, it returns a list of data
            instead of indexes.

        Raises
        ------
        KeyError
            If set name is not valid or does not exist.
        """
        assert set_name, 'Must input a valid set name.'
        return self._get_set_loader(set_name).object(index, convert_to_value)

    def size(self, set_name=None, field='object_ids'):
        """Size of a field.

        Returns the number of the elements of a field.

        Parameters
        ----------
        set_name : str, optional
            Name of the set. If None, returns the sizes for all sets.
        field : str, optional
            Name of the field in the metadata file.

        Returns
        -------
        tuple/dict
            Returns the size of a field (a dict keyed by set name when
            set_name is None).

        Raises
        ------
        KeyError
            If set name is not valid or does not exist.
        """
        if set_name is None:
            return self._get_size_all_sets(field)
        return self._get_size_single_set(set_name, field)

    def _get_size_all_sets(self, field):
        assert field
        return {set_name: loader.size(field) for set_name, loader in self.sets.items()}

    def _get_size_single_set(self, set_name, field):
        assert set_name
        assert field
        return self._get_set_loader(set_name).size(field)

    def list(self, set_name=None):
        """List of all field names of a set.

        Parameters
        ----------
        set_name : str, optional
            Name of the set. If None, returns the field names of all sets.

        Returns
        -------
        tuple/dict
            Names of all data fields of the set (a dict keyed by set name when
            set_name is None).

        Raises
        ------
        KeyError
            If set name is not valid or does not exist.
        """
        if set_name is None:
            return self._get_list_all_sets()
        return self._get_list_single_set(set_name)

    def _get_list_all_sets(self):
        return {set_name: loader.list() for set_name, loader in self.sets.items()}

    def _get_list_single_set(self, set_name):
        assert set_name
        return self._get_set_loader(set_name).list()

    def object_field_id(self, set_name, field):
        """Retrieves the index position of a field in the 'object_ids' list.

        This method returns the position of a field in the 'object_ids' object.
        If the field is not contained in this object, it returns a null value.

        Parameters
        ----------
        set_name : str
            Name of the set.
        field : str
            Name of the field in the metadata file.

        Returns
        -------
        int
            Index of the field in the 'object_ids' list.

        Raises
        ------
        KeyError
            If set name is not valid or does not exist, or if the field does
            not exist in the set.
        """
        assert set_name, 'Must input a valid set name.'
        assert field, 'Must input a valid field name.'
        return self._get_set_loader(set_name).object_field_id(field)

    def info(self, set_name=None):
        """Prints information about all data fields of a set.

        Displays information of all fields of a set group inside the hdf5
        metadata file. This information contains the name of the field, as well
        as the size/shape of the data, the data type and if the field is
        contained in the 'object_ids' list.

        If no 'set_name' is provided, it displays information for all available
        sets.

        This method only shows the most useful information about a set/fields
        internals, which should be enough for most users in helping to
        determine how to use/handle a specific dataset with little effort.

        Parameters
        ----------
        set_name : str, optional
            Name of the set.

        Raises
        ------
        KeyError
            If set name is not valid or does not exist.
        """
        if set_name is None:
            self._print_info_all_sets()
        else:
            self._print_info_single_set(set_name)

    def _print_info_all_sets(self):
        for set_name in sorted(self.sets):
            self.sets[set_name].info()

    def _print_info_single_set(self, set_name):
        assert set_name
        self._get_set_loader(set_name).info()

    def __len__(self):
        """Return the number of sets."""
        return len(self.sets)

    def __str__(self):
        return "DataLoader: {} ('{}' task)".format(self.db_name, self.task)

    def __repr__(self):
        return str(self)