Source code for denspp.offline.dnn.plots.plot_dataset

import numpy as np
from matplotlib import pyplot as plt

from denspp.offline.dnn import DatasetFromFile
from denspp.offline.metric import calculate_snr
from denspp.offline.plot_helper import (
    cm_to_inch,
    get_textsize_paper,
    save_figure,
    scale_auto_value,
)


def plot_frames_dataset(
    data: DatasetFromFile,
    take_samples: int = 500,
    do_norm: bool = False,
    add_subtitle: bool = False,
    path2save: str = "",
    show_plot: bool = False,
) -> None:
    """Plot the frames of every class into one figure and a peak-amplitude histogram.

    Two figures are created: a 2-row grid with one panel per entry of
    ``data.dict`` (random frames plus their mean in black) and a histogram /
    cumulative histogram of the absolute peak value of all frames.

    Args:
        data: Dataset object; its ``data`` (frames), ``label`` (class ids) and
            ``dict`` (class names) attributes are read
        take_samples: Number of frames drawn at random (with replacement) from each class
        do_norm: Divide every drawn frame by its absolute peak value before plotting
        add_subtitle: Append sample count and median SNR of the drawn frames to each panel title
        path2save: Path to save the frame figure; nothing is saved if empty
        show_plot: Block and show the plots
    Return:
        None
    """
    frame_raw = data.data
    frame_true = data.label
    frame_dict = data.dict
    frame_peak = np.max(np.abs(frame_raw), axis=1)
    scaley, unity = scale_auto_value(frame_raw)
    textsize = get_textsize_paper()
    num_points = frame_raw.shape[-1]

    # --- Figure #1: Spike Frames
    num_rows = 2
    num_cols = int(np.ceil(len(frame_dict) / num_rows))
    cluster_id, cluster_cnt = np.unique(frame_true, return_counts=True)
    # Look counts up by class id: positional indexing is wrong when a class has no samples.
    count_of = dict(zip(cluster_id.tolist(), cluster_cnt.tolist()))

    # squeeze=False keeps axs two-dimensional even when there is only one column.
    fig, axs = plt.subplots(
        num_rows,
        num_cols,
        sharex=True,
        sharey=True,
        figsize=(cm_to_inch(14), cm_to_inch(16)),
        squeeze=False,
    )
    for class_idx, key in enumerate(frame_dict):
        ax = axs[class_idx // num_cols, class_idx % num_cols]
        ax.grid()

        pos_id = np.flatnonzero(frame_true == class_idx)
        if pos_id.size == 0:
            # Nothing to draw for this class; randint(0, 0) would raise.
            ax.set_title(f"{key}\nnum=0" if add_subtitle else key, fontsize=textsize)
            continue

        pos_random = pos_id[np.random.randint(0, pos_id.size, take_samples)]
        # Per-frame peak as a column vector broadcasts over the frame length.
        scale = frame_peak[pos_random, np.newaxis] if do_norm else 1
        frame_sel = frame_raw[pos_random, :] / scale
        frame_mean = np.mean(frame_sel, axis=0)

        # NOTE(review): scaley is applied to normalised frames as well - confirm this is intended.
        ax.plot(scaley * np.transpose(frame_sel), linewidth=0.5)
        ax.plot(scaley * frame_mean, "k", linewidth=2.0)

        if add_subtitle:
            # The SNR is only needed for the subtitle, so it is only computed here.
            snr_class = [calculate_snr(frame, frame_mean) for frame in frame_sel]
            text = (
                f"{key}\nnum={count_of.get(class_idx, 0)}, "
                f"median(SNR) = {np.median(snr_class):.2f} dB"
            )
            ax.set_title(text, fontsize=textsize)
        else:
            ax.set_title(key, fontsize=textsize)

    # Axes are shared, so limits and ticks set on one panel apply to all of them.
    axs[0, 0].set_xlim([0, num_points - 1])
    axs[0, 0].set_xticks(np.linspace(0, num_points - 1, 7, endpoint=True, dtype=np.uint16))

    ylabel = "Spike Norm. Value" if do_norm else f"Spike Voltage [{unity}V]"
    for row in range(num_rows):
        axs[row, 0].set_ylabel(ylabel, fontsize=textsize)
    for ax in axs[-1, :]:
        ax.set_xlabel("Spike Frame Position", fontsize=textsize)
    plt.tight_layout()

    # --- Figure #2: Histogram - Spike Frame Peak Amplitude
    # NOTE(review): frame_peak is plotted unscaled although the label uses the unit prefix - confirm.
    _, axs_hist = plt.subplots(1, 2, sharex=True)
    axs_hist[0].hist(frame_peak, bins=101)
    axs_hist[1].hist(frame_peak, bins=101, density=True, cumulative=True)
    axs_hist[0].set_xlim([0, frame_peak.max()])
    axs_hist[0].set_ylabel("Bins", fontsize=textsize)
    axs_hist[0].set_xlabel(f"Spike Peak Amplitude [{unity}V]", fontsize=textsize)
    axs_hist[1].set_ylabel("Density", fontsize=textsize)
    axs_hist[1].set_xlabel(f"Spike Peak Amplitude [{unity}V]", fontsize=textsize)
    for ax in axs_hist:
        ax.grid()
    plt.tight_layout()

    # Only the frame figure (Figure #1) is handed to save_figure.
    if path2save:
        save_figure(fig, path=path2save, name="dataset_frames")
    if show_plot:
        plt.show(block=True)
def plot_mnist_dataset(
    data: np.ndarray,
    label: np.ndarray,
    title: str = "",
    path2save: str = "",
    show_plot: bool = False,
) -> None:
    """Plotting one example image for every label found in the dataset

    The subplot grid is sized from the number of distinct labels (3x3 for nine
    labels, 3x4 for ten), so no class is dropped; surplus panels are hidden.

    :param data:        Numpy array with data content, indexed by sample along the first axis
    :param label:       Numpy array with label content
    :param title:       String with title appendix (appended to the file name)
    :param path2save:   Path to save figure; nothing is saved if empty
    :param show_plot:   Boolean for showing plot
    :return:            None
    :raises ValueError: If ``label`` is empty
    """
    label_arr = np.asarray(label)
    classes = np.unique(label_arr)
    if classes.size == 0:
        raise ValueError("Cannot plot dataset examples: `label` is empty")

    # Near-square grid that holds all classes.
    num_cols = int(np.ceil(np.sqrt(classes.size)))
    num_rows = int(np.ceil(classes.size / num_cols))
    fig, axs = plt.subplots(num_rows, num_cols, squeeze=False)
    axs_flat = axs.flatten()

    for ax, class_value in zip(axs_flat, classes):
        # The first sample carrying this label serves as the example.
        pos_num = int(np.flatnonzero(label_arr == class_value)[0])
        ax.imshow(data[pos_num], cmap=plt.get_cmap("gray"))
        ax.set_title(f"Label = {class_value}")
        ax.set_xticks([])
        ax.set_yticks([])

    # Hide grid cells that have no class assigned.
    for ax in axs_flat[classes.size:]:
        ax.axis("off")

    plt.tight_layout()
    if path2save:
        save_figure(fig, path=path2save, name=f"dataset_mnist{title}")
    if show_plot:
        plt.show(block=True)
def plot_waveforms_dataset(
    dataset: DatasetFromFile,
    num_samples_class: int = 1,
    path2save: str = "",
    show_plot: bool = True,
) -> None:
    """Plotting examples of all labels from waveform dataset

    Every class gets one panel showing randomly chosen waveforms (gray) and the
    mean over all waveforms of that class (black).

    :param dataset:             Dataset object; its ``data``, ``label`` and ``dict`` attributes are read
    :param num_samples_class:   Integer with number of samples to plot from dataset (per class);
                                capped at the number of samples available in the class
    :param path2save:           Path to save figure; nothing is saved if empty
    :param show_plot:           Boolean for showing plot
    :return:                    None
    """
    class_id = np.unique(dataset.label)
    num_rows = 3
    num_cols = int(np.ceil(class_id.size / num_rows))
    # squeeze=False keeps axs two-dimensional even when there is only one column.
    fig, axs = plt.subplots(num_rows, num_cols, sharex=True, squeeze=False)
    axs_flat = axs.flatten()  # row-major order

    # Panels are assigned by position, so class ids need not be contiguous or start at 0.
    for ax, class_value in zip(axs_flat, class_id):
        class_name = dataset.dict[class_value]
        val_range = np.flatnonzero(dataset.label == class_value)

        # Draw only from samples of this class, without duplicates.
        num_take = min(num_samples_class, val_range.size)
        sample_idx = np.random.choice(val_range, size=num_take, replace=False)

        ax.plot(dataset.data[sample_idx, :].T, label=class_name, color="gray")
        ax.plot(
            np.mean(dataset.data[val_range, :], axis=0),
            label=class_name,
            color="k",
        )
        ax.set_title(f"{class_name}")
        ax.axis("off")

    # Hide grid cells that have no class assigned.
    for ax in axs_flat[class_id.size:]:
        ax.axis("off")

    plt.tight_layout()
    if path2save:
        save_figure(fig, path=path2save, name="dataset_waveforms")
    if show_plot:
        plt.show(block=True)