Source code for janggu.data.data

"""Janggu specific dataset class."""

from abc import ABCMeta
from abc import abstractmethod
from abc import abstractproperty

import numpy
from keras.utils import Sequence


class Dataset:
    """Dataset interface.

    All dataset classes in janggu inherit from
    the Dataset class which mimics a numpy array
    and can be used directly with keras.

    Subclasses are expected to provide ``__getitem__``, ``__len__``
    and the ``shape`` property.

    Parameters
    -----------
    name : str
        Name of the dataset

    Attributes
    ----------
    name : str
        Name of the dataset
    shape : tuple
        numpy-style shape of the dataset

    Raises
    ------
    TypeError
        If ``name`` is not a string.
    """

    # NOTE: a ``__metaclass__`` class attribute is only interpreted by
    # Python 2. Python 3 ignores it, so there the abstract markers below
    # document the required interface without enforcing it.
    __metaclass__ = ABCMeta

    # list of data augmentation transformations
    # (defined on the class, hence shared between instances unless a
    # subclass or instance rebinds the attribute)
    transformations = []
    _name = None

    def __init__(self, name):
        # Goes through the property setter so the name is validated.
        self.name = name

    @property
    def name(self):
        """Dataset name"""
        return self._name

    @name.setter
    def name(self, value):
        if not isinstance(value, str):
            # TypeError is still an Exception, so handlers written for
            # the former generic Exception keep working.
            raise TypeError('name must be a string')
        self._name = value

    @abstractmethod
    def __getitem__(self, idxs):  # pragma: no cover
        pass

    @abstractmethod
    def __len__(self):  # pragma: no cover
        pass

    @abstractproperty
    def shape(self):  # pragma: no cover
        """Shape of the dataset"""
        pass
def _data_props(data):
    """Extracts the shape of a provided Input-Dataset.

    Parameters
    ---------
    data : :class:`Dataset` or list(:class:`Dataset`)
        Dataset or list(Dataset). A dict is returned unchanged.

    Returns
    -------
    dict
        Dictionary with dataset names as keys and the corresponding
        shape as value. The shape is stored under the ``'shape'`` key
        and omits the first dimension of the dataset shape.

    Raises
    ------
    TypeError
        If ``data`` is neither a Dataset, a list nor a dict.
    """
    if isinstance(data, Dataset):
        data = [data]

    if isinstance(data, list):
        return {datum.name: {'shape': datum.shape[1:]} for datum in data}

    if isinstance(data, dict):
        # Handed back as is, without inspecting its content.
        return data

    # TypeError is still an Exception, so handlers written for the
    # former generic Exception keep working.
    raise TypeError('inputSpace wrong argument: {}'.format(data))


class JangguSequence(Sequence):
    """JangguSequence class.

    This class is a subclass of keras.utils.Sequence.
    It is used to serve the fit_generator, predict_generator
    and evaluate_generator.

    Parameters
    ----------
    batch_size : int
        Number of datapoints per batch.
    inputs : dict
        Input datasets, accessed as ``inputs[key]``. All of them must
        have the same length.
    outputs : dict or None
        Output datasets, accessed as ``outputs[key]``. If given, they
        must have the same length as the inputs. Default: None.
    sample_weights : indexable or None
        Sample weights, indexed with the list of batch indices.
        Default: None.
    shuffle : bool
        Whether to shuffle the datapoint order at the end of each epoch.
        Default: False.

    Raises
    ------
    ValueError
        If no input dataset is provided or if the datasets differ
        in their number of datapoints.
    """

    def __init__(self, batch_size, inputs, outputs=None, sample_weights=None,
                 shuffle=False):
        self.inputs = inputs
        self.outputs = outputs
        self.sample_weights = sample_weights
        self.batch_size = batch_size

        input_lengths = [len(inputs[k]) for k in inputs]
        if not input_lengths:
            # Without any input the number of datapoints is undefined;
            # fail with a clear message rather than an unbound variable.
            raise ValueError('At least one input dataset is required.')

        # The first input defines the reference length.
        xlen = input_lengths[0]
        if any(length != xlen for length in input_lengths):
            raise ValueError('Datasets contain differing number of datapoints.')

        for k in outputs or []:
            if len(outputs[k]) != xlen:
                raise ValueError('Datasets contain differing number of datapoints.')

        self.indices = list(range(xlen))
        self.shuffle = shuffle

    def __len__(self):
        """Number of batches, counting a partially filled last batch."""
        return int(numpy.ceil(len(self.indices) / float(self.batch_size)))

    def _batch_indices(self, idx):
        """Datapoint indices that make up batch number ``idx``."""
        start = idx * self.batch_size
        return self.indices[start:start + self.batch_size]

    def __getitem__(self, idx):
        """Return batch ``idx`` as ``(inputs, outputs, sample_weights)``.

        ``inputs`` (and ``outputs``, if available) are dicts holding the
        batch of each dataset under its key. ``outputs`` and
        ``sample_weights`` are None if they were not provided.
        """
        batch = self._batch_indices(idx)

        inputs = {k: self.inputs[k][batch] for k in self.inputs}

        outputs = None
        if self.outputs is not None:
            outputs = {k: self.outputs[k][batch] for k in self.outputs}

        sweight = None
        if self.sample_weights is not None:
            sweight = self.sample_weights[batch]

        return inputs, outputs, sweight

    def on_epoch_end(self):
        """Stuff to do after epoch end."""
        if self.shuffle:
            # Reorders the index list in place; the datasets themselves
            # are left untouched.
            numpy.random.shuffle(self.indices)