Source code for denspp.offline.metric.snr

import numpy as np
from torch import Tensor, max, min, sum, log10, sub


def calculate_snr(yin: np.ndarray, ymean: np.ndarray) -> np.ndarray:
    """Calculating the signal-to-noise ratio [dB] of the input signal compared to mean waveform

    The signal term is the squared peak-to-peak amplitude of ``ymean``; the noise term is
    the sum of squared deviations between ``yin`` and ``ymean``. Both are evaluated along
    the last (sample) axis, so every waveform in ``yin`` receives its own SNR value.

    :param yin: Numpy array with all spike waveforms (raw data), samples on the last axis
    :param ymean: Numpy array with mean waveform of corresponding spike frame cluster,
        broadcastable against ``yin``
    :return: Numpy array with SNR of all spike waveforms (a scalar when ``yin`` and
        ``ymean`` are single 1-D waveforms)
    """
    yin = np.asarray(yin)
    ymean = np.asarray(ymean)
    # Reduce along the sample axis only, so a batch of waveforms yields one SNR per
    # waveform instead of a single aggregate value over the whole batch.
    a0 = (np.max(ymean, axis=-1) - np.min(ymean, axis=-1)) ** 2
    b0 = np.sum((yin - ymean) ** 2, axis=-1)
    return 10 * np.log10(a0 / b0)
def calculate_snr_tensor(data: Tensor, mean: Tensor, *, dim: int = 1) -> Tensor:
    """Calculating the Signal-to-Noise (SNR) ratio of the input data

    The signal term is the squared peak-to-peak amplitude of ``mean``; the noise term is
    the summed squared error between ``data`` and ``mean``. Both are reduced over ``dim``.

    Args:
        data: Tensor with raw data / frame
        mean: Tensor with class-specific mean data / frame (broadcastable against ``data``)
        dim: Axis holding the samples of one frame. Defaults to ``1`` (original behaviour);
            pass ``-1`` or ``0`` to evaluate a single 1-D frame.
    Return:
        Tensor with SNR value [dB], reduced over ``dim``
    """
    # Tensor methods avoid depending on the module-level ``max``/``min``/``sum`` imports,
    # which shadow the Python builtins; ``.values`` discards the argmax/argmin indices.
    max_values = mean.max(dim=dim).values
    min_values = mean.min(dim=dim).values
    a0 = (max_values - min_values) ** 2
    b0 = ((data - mean) ** 2).sum(dim=dim)
    return 10 * (a0 / b0).log10()
def calculate_snr_tensor_waveform(input_waveform: Tensor, mean_waveform: Tensor) -> Tensor:
    """Calculation of metric Signal-to-Noise ratio (SNR) of defined input and reference waveform

    Thin wrapper around ``calculate_snr_tensor``: returns the plain (not differential) SNR
    of ``input_waveform`` measured against ``mean_waveform``.

    Args:
        input_waveform: Tensor array with input waveform
        mean_waveform: Tensor array with real mean waveform from dataset
    Return:
        Tensor with Signal-to-Noise ratio (SNR) [dB] of the input waveform
    """
    return calculate_snr_tensor(input_waveform, mean_waveform)
def calculate_dsnr_tensor_waveform(input_waveform: Tensor, pred_waveform: Tensor, mean_waveform: Tensor) -> Tensor:
    """Calculation of the differential Signal-to-Noise ratio (dSNR) of a predicted waveform

    Both the input and the predicted waveform are scored against the same reference
    mean waveform, and the input's SNR is subtracted from the prediction's SNR. A positive
    result therefore means the prediction achieved a higher SNR than the input.

    Args:
        input_waveform: Tensor array with input waveform
        pred_waveform: Tensor array with predicted waveform from model
        mean_waveform: Tensor array with real mean waveform from dataset
    Return:
        Tensor with differential Signal-to-Noise ratio (SNR_pred - SNR_input) of applied waveforms
    """
    snr_of_input = calculate_snr_tensor(input_waveform, mean_waveform)
    snr_of_prediction = calculate_snr_tensor(pred_waveform, mean_waveform)
    return snr_of_prediction - snr_of_input