Source code for denspp.offline.metric.snr
import numpy as np
from torch import Tensor, max, min, sum, log10, sub, div
[docs]
def calculate_snr(data: np.ndarray, mean: np.ndarray) -> np.ndarray:
"""Calculating the signal-to-noise ratio [dB] of the input signal compared to mean waveform
:param data: Numpy array with all spike waveforms (raw data)
:param mean: Numpy array with mean waveform of corresponding spike frame cluster
:return: Numpy array with SNR of all spike waveforms
"""
a0 = (np.max(mean) - np.min(mean)) ** 2
b0 = np.sum((data - mean) ** 2)
return 10 * np.log10(a0 / b0)
[docs]
def calculate_snr_tensor(data: Tensor, mean: Tensor) -> Tensor:
"""Calculating the Signal-to-Noise (SNR) ratio of the input data
Args:
data: Tensor with raw data / frame
mean: Tensor with class-specific mean data / frame
Return:
Tensor with SNR value
"""
max_values, _ = max(mean, dim=1)
min_values, _ = min(mean, dim=1)
a0 = (max_values - min_values) ** 2
b0 = sum((data - mean) ** 2, dim=1)
return 10 * log10(div(a0, b0))
[docs]
def calculate_dsnr_tensor(data: Tensor, pred: Tensor, mean: Tensor) -> Tensor:
"""Calculation of metric different Signal-to-Noise ratio (SNR) between defined input and predicted to reference waveform
Args:
data: Tensor array with input waveform
pred: Tensor array with predicted waveform from model
mean: Tensor array with real mean waveform from dataset
Return:
Tensor with differential Signal-to-Noise ratio (SNR) of applied waveforms
"""
snr_in = calculate_snr_tensor(data, mean)
snr_out = calculate_snr_tensor(pred, mean)
return sub(snr_out, snr_in)
[docs]
def calculate_snr_cluster(frames_in: np.ndarray, frames_cl: np.ndarray, frames_mean: np.ndarray) -> np.ndarray:
"""Calculating the cluster-specific Signal-to-Noise Ratio (SNR) for all frames
:param frames_in: Numpy array with spike frames
:param frames_cl: Numpy array with cluster label to each spike frame
:param frames_mean: Numpy array with mean waveforms of cluster
:return: Numpy array with SNR value for each sample for {min, mean, max}
"""
id_cluster, num_cluster = np.unique(frames_cl, return_counts=True)
cluster_snr = np.zeros(shape=(num_cluster.size, 4), dtype=float)
for idx, id in enumerate(id_cluster):
indices = np.where(frames_cl == id)[0]
snr0 = np.zeros(shape=(indices.size,), dtype=float)
for i, frame in enumerate(frames_in[indices, :]):
snr0[i] = calculate_snr(frame, frames_mean[id, :])
cluster_snr[idx, 0] = np.min(snr0)
cluster_snr[idx, 1] = np.mean(snr0)
cluster_snr[idx, 2] = np.max(snr0)
return cluster_snr