# Source code for denspp.offline.dnn.dataset.spike_detection
import numpy as np
from torch import is_tensor
from torch.utils.data import Dataset, DataLoader
class DatasetSDA(Dataset):
    """Dataset preparator for training spike-detection classification with a neural network.

    Each item pairs one transient frame with its per-sample detection labels and a
    binary frame-level decision ('Spike' / 'Non-Spike') derived from a label-count
    threshold.
    """

    def __init__(self, frame: np.ndarray, sda: np.ndarray, threshold: int):
        """
        :param frame: 2D array [num_frames, num_points] with the transient frames
                      (cast to float32 for training)
        :param sda: label array aligned with ``frame`` (cast to bool); True marks a
                    sample flagged by the spike-detection algorithm
        :param threshold: minimum number of True labels within a frame for the whole
                          frame to be classified as a spike
        """
        self.__frame_slice = np.array(frame, dtype=np.float32)
        self.__sda_class = np.array(sda, dtype=bool)
        self.__sda_thr = threshold

    def __len__(self):
        """Return the number of available frames."""
        return self.__frame_slice.shape[0]

    def __getitem__(self, idx):
        """Return a sample dict with keys 'in' (frame), 'sda' (labels) and 'out' (decision)."""
        if is_tensor(idx):
            idx = idx.tolist()
        # A frame counts as a spike when at least __sda_thr of its samples are labelled True
        decision = int(np.sum(self.__sda_class[idx]) >= self.__sda_thr)
        return {'in': self.__frame_slice[idx], 'sda': self.__sda_class[idx],
                'out': np.array(decision, dtype=np.uint8)}

    @property
    def get_dictionary(self) -> list:
        """Getting the dictionary of labeled inputs"""
        return ['Non-Spike', 'Spike']

    @property
    def get_topology_type(self) -> str:
        """Getting the information of the used topology"""
        return 'Spike Detection Algorithm'

    @property
    def get_cluster_num(self) -> int:
        """Getting the number of distinct label classes present in the dataset"""
        return int(np.unique(self.__sda_class).size)
def prepare_training(rawdata: dict, threshold: int) -> DatasetSDA:
    """Prepare a dataset for spike-detection-based training (no pre-processing or augmentation applied).

    :param rawdata: dict with keys "data" (2D array [num_frames, num_points] of frames)
                    and "label" (per-sample label array aligned with "data")
    :param threshold: minimum number of positive labels within a frame for the frame
                      to be classified as a spike
    :return: DatasetSDA wrapping the given frames and labels
    """
    frames_in = rawdata["data"]
    frames_cl = rawdata["label"]
    # Report the class balance so heavily skewed training data is visible up front
    check = np.unique(frames_cl, return_counts=True)
    print(f"... for training are {frames_in.shape[0]} frames with each {frames_in.shape[1]} points available")
    print(f"... used data points for training: class = {check[0]} and num = {check[1]}")
    return DatasetSDA(
        frame=frames_in,
        sda=frames_cl,
        threshold=threshold
    )