from dataclasses import dataclass
from logging import Logger, getLogger
import numpy as np
from scipy.signal import iirfilter, lfilter
from .frame_generator import FrameGenerator, FrameWaveform, SettingsFrame
[docs]
@dataclass
class SettingsSDA:
    """Configuration class for defining the Spike Detection Algorithm (SDA)

    Attributes:
        mode_sda: Key of the spike detection algorithm (SDA) applied on the transient
            signal; it selects the method ``SpikeDetection._sda_<key>``:
            'normal' (input returned unchanged),
            'neo' (Non-Linear Energy Operator / Teager-Kaiser operator with step
            dx_sda[0]; dx_sda[0] == 1 gives NEO, dx_sda[0] > 1 gives k-NEO),
            'mteo' (Multiresolution Teager Energy Operator: element-wise maximum of the
            k-NEO over all values in dx_sda),
            'ado' (absolute difference operator with step dx_sda[0]),
            'aso' (amplitude slope operator with step dx_sda[0]),
            'eed' (enhanced energy-derivation operator; requires the high-pass corner
            frequency ``f_hp`` [Hz] as keyword argument at detection time),
            'spb' (spike band-power estimation [Nason et al., 2020]; requires ``f_bp``
            with two band-pass corner frequencies [Hz] as keyword argument)
        mode_thr: String with used method for thresholding ['const': constant given value,
            'abs_mean': absolute mean value, 'mad': median absolute deviation,
            'mavg': moving average, 'mavg_abs': absolute mean absolute value,
            'rms_norm': Root-Mean-Squared, 'rms_move': Moving RMS,
            'rms_black': RMS method used in Blackrock Neurotechnology Systems,
            'welford': Welford Online Algorithm for STD Calculation]
        mode_align: Aligning mode of the detected spike frames [none, max, min,
            ptp (Positive turning point), ntp (Negative turning point), abs-max (Absolute maximum)]
        sampling_rate: Sampling rate [Hz]
        dx_sda: List of integer position differences (steps) for the SDA operators.
            'neo', 'ado' and 'aso' use only dx_sda[0] (which must be >= 1); 'mteo' uses
            every entry. With a single entry: dx = 1 --> NEO, dx > 1 --> k-NEO
        t_frame_length: Floating value with total window length [s]
        t_frame_start: Floating value with time point for aligned position [s]
        dt_offset: Time offset [s] for the first larger spike window, applied as
            [neg, pos] offset (the total offset in samples is twice this value)
        thr_gain: Floating value with amplification factor on SDA output
            (forwarded to the frame-generator settings as ``thr_gain``)
    """

    mode_sda: str
    mode_thr: str
    mode_align: str
    dx_sda: list
    sampling_rate: float
    t_frame_length: float
    t_frame_start: float
    dt_offset: float
    thr_gain: float

    @property
    def get_integer_offset(self) -> int:
        """Offset dt_offset converted to samples (rounded); used once per side of the spike window"""
        return round(self.dt_offset * self.sampling_rate)

    @property
    def get_integer_offset_total(self) -> int:
        """Total offset in samples for building the spike window (negative + positive side)"""
        return 2 * self.get_integer_offset

    @property
    def get_integer_spike_frame(self) -> int:
        """Total length of a spike window in samples (rounded)"""
        return round(self.t_frame_length * self.sampling_rate)

    @property
    def get_integer_spike_start(self) -> int:
        """Sample index inside a spike window at which the aligned position is placed (rounded)"""
        return round(self.t_frame_start * self.sampling_rate)
# Default SDA configuration: eED detector on 20 kHz data with a constant threshold,
# 'min' alignment, 1.6 ms frames aligned at 0.4 ms and a 0.1 ms window offset.
# Note: mode_sda="eed" needs the keyword argument f_hp when running the detection.
DefaultSettingsSDA = SettingsSDA(
    mode_sda="eed",
    mode_thr="const",
    mode_align="min",
    dx_sda=[1],
    sampling_rate=20e3,
    t_frame_length=1.6e-3,
    t_frame_start=0.4e-3,
    dt_offset=0.1e-3,
    thr_gain=1.0,
)
[docs]
class SpikeDetection:
    """Spike detection algorithms (SDA) applied to a transient neural signal.

    Every method named ``_sda_<key>`` implements one detection operator; ``<key>``
    is the (case-insensitive) value expected in ``SettingsSDA.mode_sda``.
    """

    def __init__(self, settings: "SettingsSDA") -> None:
        """Class SpikeDetection for extracting Spike Waveforms from neural transient input
        :param settings: Class SettingsSDA for configuring the accelerator
        :return: None
        """
        self._logger: Logger = getLogger(__name__)
        self._settings_sda = settings
        # The frame generator is configured from the same SDA settings
        # (threshold mode, alignment, sampling rate, window/offset/align times, gain).
        self._settings_thr = SettingsFrame(
            mode_thr=self._settings_sda.mode_thr,
            mode_align=self._settings_sda.mode_align,
            sampling_rate=self._settings_sda.sampling_rate,
            window_sec=self._settings_sda.t_frame_length,
            offset_sec=self._settings_sda.dt_offset,
            align_sec=self._settings_sda.t_frame_start,
            thr_gain=self._settings_sda.thr_gain,
        )
        self._frame_generator = FrameGenerator(
            settings=self._settings_thr,
        )

    @staticmethod
    def _teager_energy(xin: np.ndarray, step: int) -> np.ndarray:
        """Return the k-NEO x[n]^2 - x[n-k] * x[n+k] (k = step), padded to len(xin).

        The first and last ``step`` samples cannot be computed; they are filled with
        copies of the first / last ``step`` computed values.
        """
        core = xin[step:-step] ** 2 - xin[: -2 * step] * xin[2 * step :]
        return np.concatenate([core[:step], core, core[-step:]], axis=None)

    @staticmethod
    def _sda_normal(xin: np.ndarray) -> np.ndarray:
        """Identity operator: return the input signal unchanged."""
        return xin

    def _sda_neo(self, xin: np.ndarray) -> np.ndarray:
        """Non-Linear Energy Operator (NEO / k-NEO) with k = dx_sda[0]."""
        return self._teager_energy(xin, self._settings_sda.dx_sda[0])

    def _sda_mteo(self, xin: np.ndarray) -> np.ndarray:
        """Multiresolution TEO: element-wise maximum of the k-NEO over all k in dx_sda."""
        steps = self._settings_sda.dx_sda
        # Pre-allocated float matrix (one row per step) keeps the float output dtype.
        x_mteo = np.zeros(shape=(len(steps), xin.size))
        for idx, step in enumerate(steps):
            x_mteo[idx, :] = self._teager_energy(xin, step)
        return np.max(x_mteo, axis=0)

    def _sda_ado(self, xin: np.ndarray) -> np.ndarray:
        """Absolute difference operator |x[n] - x[n-k]| with k = dx_sda[0].

        The first k outputs are copies of the first k computed values so the
        output has the same length as the input.
        """
        step = self._settings_sda.dx_sda[0]
        x_sda = np.absolute(xin[step:] - xin[:-step])
        return np.concatenate([x_sda[:step], x_sda], axis=None)

    def _sda_aso(self, xin: np.ndarray) -> np.ndarray:
        """Amplitude slope operator x[n] * (x[n] - x[n-k]) with k = dx_sda[0].

        The first k outputs are copies of the first k computed values so the
        output has the same length as the input.
        """
        step = self._settings_sda.dx_sda[0]
        x_sda = xin[step:] * (xin[step:] - xin[:-step])
        return np.concatenate([x_sda[:step], x_sda], axis=None)

    def _sda_eed(self, xin: np.ndarray, f_hp: float) -> np.ndarray:
        """Enhanced energy-derivation operator: squared output of a 2nd-order
        digital Butterworth high-pass with corner frequency f_hp [Hz]."""
        # analog=False: the coefficients are applied with lfilter (a digital filter)
        # and Wn is normalized to the Nyquist frequency. An analog prototype used
        # with lfilter does not high-pass the signal.
        b_coeff, a_coeff = iirfilter(
            N=2,
            Wn=2 * f_hp / self._settings_sda.sampling_rate,
            ftype="butter",
            btype="highpass",
            analog=False,
            output="ba",
        )
        return np.square(np.array(lfilter(b_coeff, a_coeff, xin)))

    def _sda_spb(self, xin: np.ndarray, f_bp: list) -> np.ndarray:
        """Spike band-power estimate: absolute value of a 2nd-order digital
        Butterworth band-pass between f_bp[0] and f_bp[1] [Hz]."""
        b_coeff, a_coeff = iirfilter(
            N=2,
            Wn=2 * np.array(f_bp) / self._settings_sda.sampling_rate,
            ftype="butter",
            btype="bandpass",
            analog=False,
            output="ba",
        )
        return np.abs(lfilter(b_coeff, a_coeff, xin))

    def get_methods_sda(self) -> list:
        """Function for getting a list with all methods for spike detection

        :return: Sorted list of mode keys, i.e. the names of all ``_sda_<key>`` methods
            without the ``_sda_`` prefix
        """
        prefix = "_sda_"
        return [name[len(prefix):] for name in dir(self) if name.startswith(prefix)]

    def apply_spike_detection(self, xraw: np.ndarray, **kwargs) -> np.ndarray:
        """Applying spike detection algorithm (SDA) on transient raw signal
        :param xraw: Numpy array with transient raw data
        :param kwargs: Method-specific arguments ('f_hp' for 'eed', 'f_bp' for 'spb')
        :return: Numpy array with transient threshold value for extracting spike waveforms
        :raises ValueError: if dx_sda is empty, dx_sda[0] < 1 or mode_sda is unknown
        :raises TypeError: if the keyword argument required by 'eed' / 'spb' is missing
        """
        dx_sda = self._settings_sda.dx_sda
        if len(dx_sda) < 1:
            raise ValueError("Length of dx_sda must be at least 1")
        if dx_sda[0] < 1:
            raise ValueError("Value of dx_sda[0] must be at least 1")

        # Normalize once so validation and dispatch agree (e.g. "EED" == "eed").
        mode = self._settings_sda.mode_sda.lower()
        available = self.get_methods_sda()
        if mode not in available:
            raise ValueError(
                f"Spike Detection Method '{mode}' is not in {available}. Please change!"
            )
        if mode == "eed" and "f_hp" not in kwargs:
            raise TypeError(
                "EED method needs the definition of 'f_hp' (high-pass corner "
                "frequency as float, like f_hp=150.) as kwargs"
            )
        if mode == "spb" and "f_bp" not in kwargs:
            raise TypeError(
                "SPB method needs the definition of 'f_bp' (band-pass corner "
                "frequencies as tuple/list[float, float], like f_bp=[100., 1000.]) as kwargs"
            )
        return getattr(self, f"_sda_{mode}")(xraw, **kwargs)