Source code for denspp.offline.preprocessing.thresholding

from dataclasses import dataclass
from logging import Logger, getLogger

import numpy as np


[docs] @dataclass class SettingsThreshold: """Dataclass for defining the funcs for determining properties to calculate thresholding Attributes: method: Applied method for thresholding ['const': constant given value, 'abs_mean': absolute mean value, 'mad': median absolute derivation, 'mavg', moving average, 'mavg_abs': absolute mean absolute value, 'rms_norm': Root-Mean-Squared, 'rms_move': Moving RMS, 'rms_black': RMS method used in Blackrock Neurotechnology Systems, 'welford': Welford Online Algorithm for STD Calculation] sampling_rate: Sampling rate of the transient signal [Hz] gain: Applied gain on threshold output window_sec: Window length in sec. """ method: str sampling_rate: float gain: float window_sec: float @property def window_steps(self) -> int: """Getting the stepsize of the window""" return int(self.window_sec * self.sampling_rate)
DefaultSettingsThreshold = SettingsThreshold( method="const", sampling_rate=20e3, gain=1.0, window_sec=10e-3 )
[docs] class Thresholding: def __init__(self, settings: SettingsThreshold) -> None: """Class for calculating the thresholding values based on the transient input signal :param settings: Class SettingsThreshold for configuring the properties :return: None """ self._logger: Logger = getLogger(__name__) self._settings: SettingsThreshold = settings self._methods = { "const": "_constant", "abs_mean": "_absolute_median", "mad": "_median_absolute_derivation", "mavg": "_moving_average", "mavg_abs": "_moving_absolute_average", "rms_norm": "_root_mean_squared_normal", "rms_black": "_root_mean_squared_blackrock", "welford": "_welford_online", "wins": "_winsorization", }
[docs] def get_overview(self) -> list: """Getting an overview of available thresholding methods :return: List with names of available methods """ avai_methods = [key.lower() for key in self._methods.keys()] return avai_methods
[docs] def print_overview(self) -> None: self._logger.info(f"Available Thresholding methods: {self.get_overview()}")
[docs] def get_threshold(self, xin: np.ndarray, do_abs: bool = False, **kwargs) -> np.ndarray: """Function for getting the thresholding value from input :param xin: Numpy array with transient raw signal :param do_abs: Apply absolute xin for thresholding or not :return: Numpy array with thresholding value from applied method """ if self._settings.method.lower() == "const" and "thr_val" not in kwargs.keys(): raise TypeError( "Constant threshold method needs the definition of 'thr_val' (threshold value) " "as float, like thr_val=0.5 in kwargs" ) if self._settings.method.lower() not in self.get_overview(): raise ValueError( f"Thresholding method {self._settings.method} not available - Please change to {self.get_overview()}" ) xin0 = np.abs(xin) if do_abs else xin return getattr(self, self._methods[self._settings.method])(xin0, **kwargs)
[docs] def get_threshold_position( self, xin: np.ndarray, pre_time: float = 0.0, do_abs: bool = False, **kwargs ) -> np.ndarray: """Function for getting the crosspoints of thresholding value and transient input :param xin: Numpy array with transient raw signal :param pre_time: Floating value with pre-time in the window before event is detected [s] :param do_abs: Boolean for applying absolute xin for getting position and threshold :return: Numpy array with thresholding value from applied method """ xin0 = np.abs(xin) if do_abs else xin xthr = self.get_threshold(xin0, do_abs, **kwargs) if xthr.min() < 0: pos = np.argwhere(xin0 < xthr).flatten() else: pos = np.argwhere(xin0 >= xthr).flatten() pos_pre = int(self._settings.sampling_rate * pre_time) return np.array(self._get_values_non_incremented_change(pos)) - pos_pre
@staticmethod def _get_values_non_incremented_change(data: np.ndarray) -> list: """Returns values that are not incremented by one from the previous value. Always includes the first element. """ data0 = data.tolist() if not data0: return [] else: return [data0[0]] + [data0[i] for i in range(1, len(data0)) if data0[i] != data0[i - 1] + 1] def _constant(self, xin: np.ndarray, thr_val: float) -> np.ndarray: return np.zeros_like(xin) + thr_val def _absolute_median(self, xin: np.ndarray) -> np.ndarray: return np.zeros_like(xin) + self._settings.gain * np.median(np.abs(xin), axis=0) def _median_absolute_derivation(self, xin: np.ndarray) -> np.ndarray: median = np.median(xin, axis=0, keepdims=True) mad = np.median(np.abs(xin - median), axis=0, keepdims=True) std_estimate = mad / 0.6745 threshold = self._settings.gain * std_estimate return np.zeros_like(xin) + threshold def _moving_average(self, xin: np.ndarray) -> np.ndarray: M = self._settings.window_steps conv = np.convolve(xin, np.ones(M) / M, mode="same") return self._settings.gain * conv def _moving_absolute_average(self, xin: np.ndarray) -> np.ndarray: M = self._settings.window_steps conv = np.convolve(np.abs(xin), np.ones(M) / M, mode="same") return self._settings.gain * conv def _root_mean_squared_normal(self, xin: np.ndarray) -> np.ndarray: return np.zeros_like(xin) + self._settings.gain * np.sqrt(np.sum(xin**2) / xin.size) def _root_mean_squared_blackrock(self, xin: np.ndarray) -> np.ndarray: return 4.5 * self._root_mean_squared_normal(xin) def _welford_online(self, xin: np.ndarray) -> np.ndarray: n = 0 mean = 0.0 sigma = 0.0 std_out = np.zeros_like(xin) for idx, x in enumerate(xin): n += 1 mean_old = mean mean += (x - mean) / n sigma += ((x - mean) * (x - mean_old) - sigma) / n std_out[idx] = sigma std_out[0:1] = std_out[2] return self._settings.gain * np.sqrt(std_out)