Source code for denspp.offline.dnn.plots.plot_dataset

import numpy as np
from matplotlib import pyplot as plt
from denspp.offline.plot_helper import cm_to_inch
from denspp.offline.metric import calculate_snr


[docs] def plot_frames_dataset(data: dict, take_samples: int=500, do_norm: bool=False, plot_norm: bool=False, plot_show: bool=False, add_subtitle: bool=False) -> None: """Plotting the frames for different classes into one figure Args: data: Dictionary with spike frames, peak amplitudes and labels take_samples: Only take random N samples from each class do_norm: Do normalization (minmax) plot_norm: Plot option for do normalization on input data plot_show: Plot option for blocking and showing the plots add_subtitle: Adding a subtitle with further informations Return: None """ frame_raw = data['data'] frame_true = data['label'] frame_dict = data['dict'] frame_peak = data['peak'] / np.abs(frame_raw.min()) if 'peak' in data.keys() else np.max(np.abs(data['data']), 1) # --- Figure #1: Spike Frames num_rows = 2 num_cols = int(np.ceil(len(frame_dict)/num_rows)) cluster_id, cluster_cnt = np.unique(frame_true, return_counts=True) _, axs = plt.subplots(num_rows, num_cols, sharex=True, sharey=True, figsize=(cm_to_inch(14), cm_to_inch(16))) for id, key in enumerate(frame_dict): pos_id = np.argwhere(frame_true == id).flatten() pos_random = pos_id[np.random.randint(0, pos_id.size, take_samples)] scale = 1 if not do_norm else np.repeat(np.expand_dims(frame_peak[pos_random], -1), frame_raw.shape[-1], -1) frame_raw0 = frame_raw[pos_random, :] / scale frame_mean = np.mean(frame_raw0, axis=0) snr_class = list() for frame in frame_raw0: snr_class.append(calculate_snr(frame, frame_mean)) axs[int(id/num_cols), id % num_cols].plot(np.transpose(frame_raw0), linewidth=0.5) axs[int(id/num_cols), id % num_cols].plot(frame_mean, 'k', linewidth=2.0) axs[int(id/num_cols), id % num_cols].grid() if not add_subtitle: axs[int(id/num_cols), id % num_cols].set_title(key, fontsize=13) else: text = f'{key}\nnum={cluster_cnt[id]}, median(SNR) = {np.mean(np.array(snr_class)):.2f} dB' axs[int(id / num_cols), id % num_cols].set_title(text, fontsize=13) axs[0, 0].set_xlim([0, frame_raw.shape[-1]-1]) axs[0, 0].set_xticks(np.linspace(0, frame_raw.shape[-1]-1, 7, endpoint=True, dtype=np.uint16)) if plot_norm: axs[0, 0].set_ylabel("Spike Norm. Value", fontsize=13) axs[1, 0].set_ylabel("Spike Norm. Value", fontsize=13) else: axs[0, 0].set_ylabel("Spike Voltage [µV]", fontsize=13) axs[1, 0].set_ylabel("Spike Voltage [µV]", fontsize=13) for idx in range(num_cols): axs[1, idx].set_xlabel("Spike Frame Position", fontsize=13) plt.tight_layout() # --- Figure #2: Histogram - Spike Frame Peak Amplitude _, axs = plt.subplots(1, 2, sharex=True) axs[0].hist(frame_peak, bins=101) axs[1].hist(frame_peak, bins=101, density=True, cumulative=True) axs[0].set_xlim([0, frame_peak.max()]) axs[0].set_ylabel('Bins', fontsize=13) axs[0].set_xlabel('Spike Peak Amplitude [µV]', fontsize=13) axs[1].set_ylabel('Density', fontsize=13) axs[1].set_xlabel('Spike Peak Amplitude [µV]', fontsize=13) for ax in axs: ax.grid() plt.tight_layout() if plot_show: plt.show(block=True)