import numpy as np
from logging import getLogger, Logger
from dataclasses import dataclass
from os.path import join, exists, dirname, basename
from os import makedirs
from glob import glob
from pathlib import Path
from fractions import Fraction
from scipy.signal import resample_poly
from denspp.offline import get_path_to_project_start
from denspp.offline.csv_handler import CsvHandler
from denspp.offline.data_call.owncloud_handler import OwnCloudDownloader
@dataclass
class SettingsData:
"""Class for configuring the dataloader
Attributes:
pipeline: String with name of the pipeline to use
path: Path to data storage
data_set: String with key for used data set
data_point: Number within the dataset
t_range_sec: List of the given time range (in sec.) for cutting the data [x, y]
ch_sel: List of electrodes to use [empty=all]
fs_resample: Resampling frequency of the datapoint (== 0.0, no resampling)
do_mapping: Decision if mapping (if available) is used
is_mapping_str: Boolean if mapping input from csv file is a string (True) or integer (false)
"""
pipeline: str
path: str
data_set: str
data_case: int
data_point: int
t_range_sec: list
ch_sel: list
fs_resample: float
do_mapping: bool
is_mapping_str: bool
DefaultSettingsData = SettingsData(
pipeline='PipelineV0',
path='data',
data_set='quiroga',
data_case=0, data_point=0,
t_range_sec=[], ch_sel=[],
fs_resample=100e3,
do_mapping=True,
is_mapping_str=False
)
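# A minimal configuration sketch (hedged: the values are only illustrative;
# dataclasses.replace is standard-library behavior, not part of this module).
# Overriding single fields keeps the remaining defaults intact:
#   from dataclasses import replace
#   custom_settings = replace(DefaultSettingsData, data_point=1, t_range_sec=[0.0, 10.0], fs_resample=50e3)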
class DataHandler:
"""Class with data and meta information of the used neural dataset
Attributes:
data_name (str): File name of used transient dataset
data_type (str): Type / Source of the transient dataset
fs_orig (float): Sampling rate of the original dataset [Hz]
fs_used (float): Sampling rate used in pipeline processing [Hz]
time_end (float): Float number with time duration of measurement / data [s]
data_raw (np.ndarray): Numpy array with data [shape=(num_elec (opt.), num_samples)]
electrode_id (list): List with integer values for matching data to electrode layout
mapping_used (np.ndarray): Array with 2D electrode mapping (loaded from external file)
label_exist (bool): Boolean if label with event position and ID exists
evnt_xpos (list): List with numpy arrays of event position on transient signal
evnt_id (list): List with numpy arrays of event ID for each electrode
"""
# --- Meta Information
data_name: str
data_type: str
fs_orig: float
fs_used: float = 0.0
time_end: float
# --- Raw data and Electrode Mapping
data_raw: np.ndarray
electrode_id: list
mapping_used: np.ndarray
# --- GroundTruth for Event Signal Processing
label_exist: bool = False
evnt_xpos: list
evnt_id: list
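# Sketch of the expected field layout (values are hypothetical): for a two-channel
# recording of 1 s at 20 kHz, data_raw has shape (2, 20000), electrode_id could be
# [1, 2], and - if labels exist - evnt_xpos holds one numpy array of sample indices
# per channel with evnt_id holding the matching cluster IDs, e.g.:
#   evnt_xpos = [np.array([120, 7500]), np.array([3300])]
#   evnt_id = [np.array([0, 1]), np.array([0])]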
class ControllerData:
__logger: Logger
__download_handler: OwnCloudDownloader
_raw_data: DataHandler
_settings: SettingsData
_path2folder_remote: str = ''
_path2folder_local: str = ''
_path2mapping: str = ''
def __init__(self) -> None:
"""Class for loading and manipulating the used dataset"""
# Collect all methods of the concrete (sub)class so that do_call can dispatch on them
self._methods_available = dir(type(self))
self.__logger = getLogger(__name__)
self.__fill_factor = 1.
self.__default_data_path = join(get_path_to_project_start(), 'data')
self.__config_data_selection = [self.__default_data_path, 0, 0]
self.__download_handler = OwnCloudDownloader()
@staticmethod
def _extract_func(class_obj: object) -> list:
return [method for method in dir(class_obj) if class_obj.__name__ in method]
def do_cut(self) -> None:
"""Cutting all transient electrode signals in the given range"""
if self._raw_data.fs_used == 0.0:
self._raw_data.fs_used = self._raw_data.fs_orig
# --- Getting the position of the used time range
t_range = np.array(self._settings.t_range_sec)
assert len(self._settings.t_range_sec) in [0, 2], f"t_range should be empty or have a length of 2 (not {len(self._settings.t_range_sec)})"
if t_range.size == 2:
rawdata_in = self._raw_data.data_raw
idx0 = int(t_range[0] * self._raw_data.fs_used)
idx1 = int(t_range[1] * self._raw_data.fs_used)
# --- Cutting labeled information
if self._raw_data.label_exist:
evnt_pos_new = list()
evnt_id_new = list()
for evnt_pos, evnt_id in zip(self._raw_data.evnt_xpos, self._raw_data.evnt_id):
# Adapting new data
pos_start = np.argwhere(evnt_pos >= idx0).flatten()
idx2 = int(pos_start[0]) if pos_start.size > 0 else -1
pos_stop = np.argwhere(evnt_pos <= idx1).flatten()
idx3 = int(pos_stop[-1]) if pos_stop.size > 0 else -1
if idx2 == -1 or idx3 == -1:
evnt_pos_new.append(list())
evnt_id_new.append(list())
else:
# Append (do not overwrite) the cut event data and keep the last in-range event (idx3 + 1)
evnt_pos_new.append(evnt_pos[idx2:idx3 + 1] - idx0)
evnt_id_new.append(evnt_id[idx2:idx3 + 1])
self._raw_data.evnt_xpos = evnt_pos_new
self._raw_data.evnt_id = evnt_id_new
# --- Return adapted data
self._raw_data.data_raw = rawdata_in[:, idx0:idx1] if len(rawdata_in.shape) == 2 else rawdata_in[idx0:idx1]
self._raw_data.time_end = float(self._raw_data.data_raw.shape[-1] / self._raw_data.fs_used)
self.__fill_factor = self._raw_data.data_raw.shape[-1] / rawdata_in.shape[-1]
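# Worked example for do_cut (hypothetical numbers): with fs_used = 100 kHz and
# t_range_sec = [1.0, 2.0], the kept sample range is idx0 = 100000 to idx1 = 200000,
# all event positions are shifted by -idx0, time_end becomes 1.0 s, and the fill
# factor stores the kept fraction of the original recording.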
def do_resample(self, u_safe: float = 5e-6) -> None:
"""Do resampling of all transient signals incl. label information
:param u_safe: Voltage range as safety window for resampling
"""
desired_fs = self._settings.fs_resample
do_resampling = desired_fs != 0.0 and desired_fs != self._raw_data.fs_orig
if do_resampling:
self._raw_data.fs_used = desired_fs
(p, q) = Fraction(self._raw_data.fs_used / self._raw_data.fs_orig).limit_denominator(10000).as_integer_ratio()
scaling_value = p / q
# --- Resampling the input
data_out = list()
for data_in in self._raw_data.data_raw:
u_chck = np.mean(data_in[0:10])
# Remove a DC offset above the safety window before resampling and re-add it afterwards
du = u_chck if u_chck >= u_safe else 0.0
data_out.append(du + resample_poly(data_in - du, p, q))
# --- Resampling the labeled information
if self._raw_data.label_exist:
spike_out = list()
for evnt_pos in self._raw_data.evnt_xpos:
spike_out.append(np.array(scaling_value * evnt_pos, dtype=int))
else:
spike_out = list()
self.__logger.debug(f"Resampling done from {self._raw_data.fs_orig} to {self._raw_data.fs_used} ({100 * scaling_value:.2f} %)")
self._raw_data.data_raw = np.array(data_out, dtype=self._raw_data.data_raw.dtype)
self._raw_data.evnt_xpos = spike_out
else:
self._raw_data.fs_used = self._raw_data.fs_orig
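# Worked example for the rational resampling factor (hypothetical rates): for
# fs_orig = 24 kHz and fs_resample = 30 kHz, Fraction(30000 / 24000) reduces to
# 5/4, so resample_poly upsamples by p = 5 and downsamples by q = 4, and event
# positions are scaled by p / q = 1.25 before being cast back to integer indices.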
def get_data(self) -> DataHandler:
"""Calling the raw data with ground truth of the called data"""
return self._raw_data
def __get_data_available_local(self, path_ref: str, folder_name: str, data_type: str) -> str:
"""Function for getting the path to file from remote
:param path_ref: Part to the reference folder from remote
:param folder_name: Part to the folder name to find local
:param data_type: String with the data type to find local
:return: File name
"""
if path_ref:
# --- If the data is available online - checking whether it is also available locally
path2chck = join(self.__config_data_selection[0], dirname(path_ref).lstrip("/"), basename(path_ref))
self._path2folder_local = '' if not exists(path2chck) else dirname(join(*[path_seg for path_seg in Path(path2chck).parts[:2]], ''))
return '' if not exists(path2chck) else path2chck
else:
# --- Routine for finding local data if it is not available online
folder2search = [folder for folder in glob(join(self.__config_data_selection[0], '*')) if folder_name in folder]
if not len(folder2search):
raise FileNotFoundError(f"Folder containing '{folder_name}' not found (locally and remotely)!")
else:
self._path2folder_local = folder2search[0]
path2search = folder2search[0]
folder_content = glob(join(path2search, data_type))
if len(folder_content) == 0:
# --- Go one folder level deeper and look there
folder_structure = glob(join(path2search, '*'))
path2search = join(path2search, self.__config_data_selection[1]) if isinstance(self.__config_data_selection[1], str) else folder_structure[self.__config_data_selection[1]]
folder_content = glob(join(path2search, data_type))
folder_content.sort()
# --- Getting the file name
file_name = join(path2search, self.__config_data_selection[2]) if isinstance(self.__config_data_selection[2], str) else folder_content[self.__config_data_selection[2]]
if file_name:
return file_name
else:
raise FileNotFoundError(f"File is not available (locally and remotely)!")
def __get_path_available_remote(self, folder_name: str, data_type: str) -> str:
"""Function for getting the path to file from remote
:param folder_name: Part of the folder name to find remote
:param data_type: String with the data type to find remote
:return: File name
"""
# --- Try to get the path from remote
try:
overview = self.__download_handler.get_overview_folder(False)
path2folder = [s for s in overview if folder_name in s]
except Exception:
return ""
else:
if len(path2folder) == 0:
return ""
else:
self._path2folder_remote = path2folder[0]
folder_structure = self.__download_handler.get_overview_folder(False, path2folder[0])
if len(folder_structure):
folder_search = join(path2folder[0], self.__config_data_selection[1]) if isinstance(self.__config_data_selection[1], str) else folder_structure[self.__config_data_selection[1]]
folder_content = self.__download_handler.get_overview_data(False, folder_search, data_type)
else:
folder_content = self.__download_handler.get_overview_data(False, path2folder[0], data_type)
folder_content.sort()
return folder_content[self.__config_data_selection[2]]
def _prepare_access_file(self, folder_name: str, data_type: str) -> str:
"""Getting the file of the corresponding trial
:param folder_name: String with folder name where the data is located
:param data_type: String with data type of the data
:return: String with path to file
"""
used_datapath = self.__default_data_path if self._settings.path == '' else self._settings.path
self.__config_data_selection = [used_datapath, self._settings.data_case, self._settings.data_point]
path2remote = self.__get_path_available_remote(folder_name, data_type)
path2local = self.__get_data_available_local(path2remote, folder_name, data_type)
if path2local:
return path2local
elif path2remote and not path2local:
# --- Download into the same base folder that is used for the local availability check
path2data = join(used_datapath, dirname(path2remote[1:]))
path2file = join(used_datapath, path2remote[1:])
makedirs(path2data, exist_ok=True)
self.__download_handler.download_file(use_dataset=False, file_name=path2remote, destination_download=path2file)
return path2file
else:
raise FileNotFoundError("--- File is not available. Please check! ---")
def build_mapping(self, path2csv: str = '', index_search: str = 'Mapping_*.csv') -> None:
"""Transforming the input data to the electrode-array-specific design
(considering electrode format and coordination)
:param path2csv: Path to the csv file with information about the electrode mapping (Default: "")
:param index_search: Glob pattern for finding the mapping file (Default: "Mapping_*.csv")
:return: None
"""
if path2csv:
self._path2mapping = path2csv
else:
# --- Checking if mapping file is available
mapping_local = glob(join(self._path2folder_local, index_search)) if self._path2folder_local else []
mapping_remote = self.__download_handler.get_overview_data(
use_dataset=False,
search_folder=self._path2folder_remote,
format=index_search
) if self._path2folder_remote else []
# --- Getting the file
if len(mapping_local):
self._path2mapping = mapping_local[0]
elif len(mapping_remote):
self._path2mapping = join(self._path2folder_local, basename(mapping_remote[0]))
self.__download_handler.download_file(
use_dataset=False,
file_name=mapping_remote[0],
destination_download=self._path2mapping
)
else:
self._path2mapping = ''
# --- Generating mapping information
if self._settings.do_mapping and exists(self._path2mapping):
self.__logger.info("Apply electrode mapping")
self._raw_data.mapping_used = self._generate_electrode_mapping_from_csv()
else:
self.__logger.info("No electrode mapping is available")
def _generate_electrode_mapping_from_csv(self) -> np.ndarray:
"""Function for reading the CSV file for electrode mapping of used (non-electrodes = 0)"""
mapping_from_csv = CsvHandler(
path=dirname(self._path2mapping),
file_name=basename(self._path2mapping),
delimiter=';'
).read_data_from_csv(
include_chapter_line=False,
start_line=0,
type_load=str if self._settings.is_mapping_str else int
).tolist()
# --- Generating numpy array
mapping_dimension = [len(mapping_from_csv), len(mapping_from_csv[0])]
electrode_mapping = np.zeros((mapping_dimension[0], mapping_dimension[1]), dtype=int)
for row_idx, elec_row in enumerate(mapping_from_csv):
for col_idx, elec_id in enumerate(elec_row):
if row_idx < electrode_mapping.shape[0] and col_idx < electrode_mapping.shape[1]:
if isinstance(elec_id, str):
matches = [idx + 1 for idx, eid in enumerate(self._raw_data.electrode_id) if eid == elec_id]
electrode_mapping[row_idx, col_idx] = matches[0] if len(matches) else 0
else:
electrode_mapping[row_idx, col_idx] = elec_id
return electrode_mapping
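# Illustration (hypothetical CSV content with delimiter ';'): a mapping file with
#   1;2
#   0;3
# yields electrode_mapping = [[1, 2], [0, 3]], where 0 marks a non-electrode site.
# With is_mapping_str=True the cells hold electrode names instead, which are matched
# against electrode_id (1-based index, 0 if no match is found).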
def _load_rawdata_into_pipeline(self, elec_type: str, file_name: str, fs_orig: float,
elec_orn: list, rawdata: np.ndarray, scale_data: float,
evnt_pos: list=(), evnt_id: list=()) -> None:
"""Function for preparing the loaded rawdata for using in pipeline process
:param elec_type: String with type description of the transient data
:param file_name: String with used file name
:param fs_orig: Sampling rate of transient signal [Hz]
:param elec_orn: List with Electrode orientation / mapping to rawdata
:param rawdata: Numpy array with raw data (shape: [num_channels, num_samples])
:param scale_data: Scaling factor for rawdata
:param evnt_pos: List with numpy arrays of the event positions (should have the same length as elec_orn)
:param evnt_id: List with numpy arrays of the event IDs (should have the same length as elec_orn)
:return: None
"""
assert len(rawdata.shape) <= 2, "Variable rawdata must have one (num_samples, ) or two dimensions (num_channels, num_samples)"
if len(rawdata.shape) == 2:
assert rawdata.shape[0] == len(elec_orn), "Number of channels in rawdata must match the length of elec_orn"
else:
assert len(elec_orn) == 1, "If rawdata has one dimension (num_samples, ), elec_orn must have a length of 1"
num_channels = rawdata.shape[0] if len(rawdata.shape) == 2 else 1
if len(evnt_pos):
assert isinstance(evnt_pos, list), "Variable evnt_pos must have type list"
assert len(evnt_pos) == num_channels, "evnt_pos must have one entry per channel"
assert isinstance(evnt_id, list), "Variable evnt_id must have type list"
assert len(evnt_id) == num_channels, "evnt_id must have one entry per channel"
self._raw_data = DataHandler()
# --- Including meta data
self._raw_data.data_type = elec_type
self._raw_data.data_name = basename(file_name)
self._raw_data.fs_orig = fs_orig
# --- Including raw data
self._raw_data.electrode_id = elec_orn
self._raw_data.data_raw = scale_data * rawdata if len(rawdata.shape) == 2 else np.expand_dims(scale_data * rawdata, axis=0)
self._raw_data.time_end = self._raw_data.data_raw.shape[1] / self._raw_data.fs_orig
# --- Including labels
if len(evnt_pos) and len(evnt_id):
self._raw_data.evnt_xpos = evnt_pos
self._raw_data.evnt_id = evnt_id
self._raw_data.label_exist = True
else:
self._raw_data.evnt_xpos = list()
self._raw_data.evnt_id = list()
self._raw_data.label_exist = False
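# Sketch of a custom loader method (all names hypothetical): do_call (below)
# selects the first entry of _methods_available whose name contains the
# configured 'data_set' key. Assuming a method data_mydataset() is registered
# there, it would typically load the file and finish with:
#   self._load_rawdata_into_pipeline(elec_type='MEA', file_name='demo.npy',
#                                    fs_orig=20e3, elec_orn=[1, 2],
#                                    rawdata=raw, scale_data=1e-6)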
def do_call(self):
"""Loading the dataset"""
# --- Getting the function to call
warning_text = "\nPlease select key words in variable 'data_set' for calling methods to read transient data"
warning_text += "\n=========================================================================================="
for method in self._methods_available:
warning_text += f"\n\t{method}"
# --- Call the function
used_data_source_idx = -1
for idx, method in enumerate(self._methods_available):
if self._settings.data_set in method:
used_data_source_idx = idx
break
if not self._settings.data_set or used_data_source_idx == -1:
self.__logger.error(warning_text)
raise ValueError(f"No loader method found for data_set key '{self._settings.data_set}'")
else:
getattr(self, self._methods_available[used_data_source_idx])()
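# End-to-end usage sketch (hypothetical: assumes a subclass that provides a loader
# method matching the configured 'data_set' key and registers it in
# _methods_available; _settings is normally injected by such a subclass):
#   controller = MyController()            # hypothetical subclass of ControllerData
#   controller._settings = DefaultSettingsData
#   controller.do_call()                   # load the dataset selected via 'data_set'
#   controller.do_resample()               # optional resampling to fs_resample
#   controller.do_cut()                    # optional cutting to t_range_sec
#   data = controller.get_data()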