from copy import deepcopy
from datetime import datetime
from glob import glob
from logging import Logger, getLogger
from os import makedirs
from os.path import exists, join
from pathlib import Path
from shutil import rmtree
import numpy as np
from tqdm import tqdm
from denspp.offline import check_keylist_elements_any, get_path_to_project
from denspp.offline.data_call import DataFromFile, SettingsData
from denspp.offline.preprocessing import FrameWaveform
# (stray "[docs]" documentation-link marker removed; it was not valid module code)
class MergeDataset:
    """Build one merged dataset from frames extracted out of several transient recordings.

    Typical usage is two steps: :meth:`get_frames_from_dataset` extracts the frames of every
    data point and stores one temporary ``.npy`` file per data point, then
    :meth:`merge_data_from_all_iteration` combines these files into one final dataset file.
    """

    def __init__(
        self,
        pipeline,
        dataloader,
        settings_data: SettingsData,
        concatenate_id: bool = False,
    ) -> None:
        """Class for handling the merging process for generating datasets from transient input signals
        :param pipeline: Construct of selected pipeline for processing data (called as ``pipeline(fs, False)``)
        :param dataloader: Construct of used Dataloader for getting and handling data (called as ``dataloader(settings)``)
        :param settings_data: Class SettingsData for configuring the data loader
        :param concatenate_id: Do concatenation of the class number with increasing id number (useful for non-biological clusters)
        :return: None
        :raises ImportError: If the pipeline does not pass the attribute check
        """
        self._logger: Logger = getLogger(__name__)
        self._dataloader = dataloader
        self._settings: SettingsData = settings_data
        self._pipeline = pipeline
        self._check_right_pipeline()
        # Temporary folder which holds the per-data-point results until they are merged
        self._path2save = get_path_to_project("temp_merge")
        self._do_label_concatenation = concatenate_id

    def _check_right_pipeline(self) -> None:
        """Validate that the given pipeline exposes the expected attributes
        :return: None
        :raises ImportError: If the attribute check on the pipeline fails
        """
        # NOTE(review): this checks for 'run_preprocessing', but the extraction methods of this
        # class call 'run_preprocessor' - confirm which name the pipelines really provide.
        if not check_keylist_elements_any(
            keylist=dir(self._pipeline),
            elements=["run_preprocessing", "run_classifier", "settings"],
        ):
            raise ImportError(
                "Wrong pipeline is implemented. It should include 'run_preprocessing' and 'run_classifier'."
            )

    def _generate_folder(self) -> None:
        """Create an empty temporary folder (an existing one is deleted with its content)
        :return: None
        """
        if exists(self._path2save):
            rmtree(self._path2save)
        makedirs(self._path2save, exist_ok=True)

    def _save_results(self, data: dict, path2folder: str, file_name: str) -> str:
        """Save a dictionary as pickled numpy file
        :param data: Dictionary with content to save
        :param path2folder: Path to the folder in which the file is stored (created if missing)
        :param file_name: Name of the file without suffix ('.npy' is appended)
        :return: String with path to the written file
        """
        file_name = f"{file_name}.npy"
        path2file = join(path2folder, file_name)
        # Make sure the target folder exists, otherwise np.save fails
        makedirs(path2folder, exist_ok=True)
        np.save(path2file, data, allow_pickle=True)
        self._logger.info(f"Saving file in: {file_name}")
        return path2file

    def _iteration_save_results(self, data: list, data_name: str) -> None:
        """Save the extracted frames of one data point into the temporary folder
        :param data: List with extracted frames
        :param data_name: Name of the processed data, used in the file name
        :return: None
        """
        create_time = datetime.now().strftime("%Y-%m-%d")
        self._save_results(
            data={"frames": data},
            path2folder=self._path2save,
            file_name=f"{create_time}_Dataset-{data_name}",
        )

    def _get_frames_from_labeled_dataset(self, data: DataFromFile, xpos_offset: int = 0) -> list:
        """Extract frames from data with event labels by running the pipeline preprocessor
        :param data: Loaded data, of which data_raw, evnt_xpos and evnt_id are iterated in lockstep
        :param xpos_offset: Integer as position offset handed over to the preprocessor
        :return: List with one labeled FrameWaveform per entry of data.data_raw
        """
        self._logger.info(f"\nProcessing file: {data.data_name}")
        pipeline = self._pipeline(data.fs_used, False)
        # NOTE(review): this ratio is always 1.0, so the event positions are never rescaled.
        # Presumably a ratio of two different sampling rates was intended - confirm against
        # the pipeline definition before changing it.
        xpos_scaler = pipeline.fs_ana / pipeline.fs_ana

        frames_extracted = []
        for rawdata, xposition, label in tqdm(
            zip(data.data_raw, data.evnt_xpos, data.evnt_id),
            ncols=100,
            desc="Progress: ",
        ):
            xpos_updated = np.floor(xpos_scaler * xposition).astype("int")
            result = pipeline.run_preprocessor(rawdata, xpos_updated, xpos_offset)
            frame_new: FrameWaveform = result["frames"]
            frame_new.label = label
            frames_extracted.append(frame_new)
        return frames_extracted

    def _get_frames_from_unlabeled_dataset(self, data: DataFromFile, **kwargs) -> list:
        """Extract frames from data without labels by running the pipeline preprocessor
        :param data: Loaded data, of which data_raw is iterated
        :param kwargs: Not used (accepted for a signature parallel to the labeled variant)
        :return: List with one FrameWaveform per entry of data.data_raw
        """
        self._logger.info(f"\nProcessing file: {data.data_name}")
        pipeline = self._pipeline(data.fs_used, False)

        frames_extracted = []
        for rawdata in tqdm(data.data_raw, ncols=100, desc="Progress: "):
            frame_new: FrameWaveform = pipeline.run_preprocessor(rawdata)["frames"]
            frames_extracted.append(frame_new)
        return frames_extracted

    def get_frames_from_dataset(self, process_points: list = (), xpos_offset: int = 0) -> None:
        """Tool for loading datasets in order to generate one new dataset (Step 1)
        :param process_points: Taking the datapoints of the selected data set to process (if empty, the data points 0, 1, 2, ... are processed until the data loader fails)
        :param xpos_offset: Integer as position offset for shifting label position of an event (only apply if label exists)
        :return: None
        """
        self._generate_folder()
        num_points = len(process_points)
        current_index = 0
        while True:
            # All explicitly selected data points are done
            if num_points and current_index >= num_points:
                break

            try:
                sets0 = deepcopy(self._settings)
                sets0.data_point = process_points[current_index] if num_points else current_index
                datahandler = self._dataloader(sets0)
                datahandler.do_call()
            except Exception as err:
                # A failing data loader marks the end of the available data points. Only
                # Exception is caught so that Ctrl+C / SystemExit still abort the run.
                self._logger.info(
                    "Stopping data extraction at index %d: %s", current_index, err
                )
                break

            datahandler.do_resample()
            datahandler.do_cut()
            data: DataFromFile = datahandler.get_data()
            del datahandler

            if data.label_exist:
                result = self._get_frames_from_labeled_dataset(data, xpos_offset=xpos_offset)
            else:
                result = self._get_frames_from_unlabeled_dataset(data)
            self._iteration_save_results(data=result, data_name=data.data_name)
            current_index += 1

    def _merge_frames(self, dataset_loaded: list) -> dict:
        """Merge the frames of all loaded runs into flat arrays
        :param dataset_loaded: List (one entry per run) with lists of frame objects providing the attributes waveform, xpos and label
        :return: Dictionary with the numpy arrays 'data', 'class' and 'position'
        """
        frames_waveform = []
        frames_position = []
        frames_label = []
        label_offset = 0
        for data_case in dataset_loaded:
            for data_elec in data_case:
                frames_waveform.extend(data_elec.waveform.tolist())
                frames_position.extend(data_elec.xpos)
                frames_label.extend(data_elec.label + label_offset)
                # With concatenation, the next labels start right behind the highest id used
                # so far. frames_label already contains shifted ids, so the offset has to be
                # assigned (not accumulated) to keep the ids contiguous.
                if self._do_label_concatenation and frames_label:
                    label_offset = 1 + max(frames_label)
        return {
            "data": np.array(frames_waveform),
            "class": np.array(frames_label),
            "position": np.array(frames_position),
        }

    def merge_data_from_all_iteration(self) -> str:
        """Merging extracted information from all runs into one file
        :return: String with path to final file
        :raises IndexError: If the temporary folder contains no '.npy' file
        """
        folder_content = sorted(glob(join(self._path2save, "*.npy")))
        # The last file (in sorted order) gives the name of the merged file
        file_name = folder_content[-1]

        # NOTE: allow_pickle executes pickled content - only load files written by this class
        dataset_loaded = [
            np.load(file, allow_pickle=True).item()["frames"] for file in folder_content
        ]

        # --- Merging data from different sources
        data_merged = self._merge_frames(dataset_loaded)
        data_merged["create_time"] = datetime.now().strftime("%Y-%m-%d")

        # --- Save output
        return self._save_results(
            data=data_merged,
            path2folder=get_path_to_project("dataset"),
            file_name=Path(file_name).stem + "_Merged",
        )