import numpy as np
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D  # noqa: F401 -- registers the '3d' projection on older Matplotlib versions
from denspp.offline.plot_helper import (
cm_to_inch,
save_figure,
get_plot_color
)
def results_autoencoder_training(
path: str,
yin: np.ndarray, ypred: np.ndarray, ymean: np.ndarray,
feat: np.ndarray, yclus: np.ndarray,
cl_dict=None, xframes: int=50, num_feat: int=3, show_plot: bool=False
) -> None:
"""Plotting results from Autoencoder Training for Neural Spike Sorting
Args:
path: Path for saving the results of plots
yin: Input signal with neural spike frames
        ypred: Reconstructed spike frames predicted by the autoencoder
ymean: Mean waveform of all spike classes
feat: Numpy array with features for plotting feature space
yclus: Numpy array with cluster results
        cl_dict: Optional list with class label names used for the plot legends
        xframes: Number of spike frames to plot per cluster
num_feat: Number of features
show_plot: Showing plots [Default: False]
Returns:
None
"""
data_labeled = True
# --- Pre-Processing
    yclus0 = yclus.flatten() if yclus.ndim == 2 else yclus
cluster_no = np.unique(yclus0)
mark_feat = [[] for idx in range(0, num_feat)]
take_frames = list()
    for cluster_id in cluster_no:
        pos = np.where(yclus0 == cluster_id)[0]
# Take only X frames per cluster
np.random.shuffle(pos)
take_frames.append(pos[:xframes])
# Separating the features for plotting
for idx in range(0, num_feat):
mark_feat[idx].append(feat[pos, idx])
# --- Plotting: Inference model
    for feat_pair in ([0, 1], [0, 2], [1, 2]):
        plot_autoencoder_run(
            mark_feat, feat_pair, yin, ypred, ymean,
            cluster_no, take_frames, data_labeled=data_labeled, data_classname=cl_dict, path2save=path
        )
plot_3d_featspace(
yclus0, mark_feat, [0, 1, 2], data_classname=cl_dict, path2save=path,
show_plot=show_plot
)
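

# --- Hypothetical usage sketch (not part of the original module); the array shapes are
# --- assumptions inferred from how the inputs are indexed above: yin/ypred are
# --- (n_frames, frame_len), ymean is (n_clusters, frame_len), feat is (n_frames, num_feat)
# --- and yclus is (n_frames,).
def _example_results_autoencoder_training() -> None:
    """Run the full plotting pipeline on synthetic spike frames (illustrative sketch only)."""
    rng = np.random.default_rng(42)
    n_frames, frame_len, n_clusters = 300, 32, 3
    yclus = rng.integers(0, n_clusters, size=n_frames)
    ymean = rng.standard_normal((n_clusters, frame_len))
    yin = ymean[yclus, :] + 0.3 * rng.standard_normal((n_frames, frame_len))
    ypred = ymean[yclus, :] + 0.1 * rng.standard_normal((n_frames, frame_len))
    feat = rng.standard_normal((n_frames, 3))
    results_autoencoder_training(
        path='', yin=yin, ypred=ypred, ymean=ymean, feat=feat, yclus=yclus,
        cl_dict=['Unit A', 'Unit B', 'Unit C'], xframes=20, num_feat=3, show_plot=False
    )
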
def plot_3d_featspace(
labels: np.ndarray, mark_feat: list,
idx: list=(0, 1, 2), data_classname=None,
path2save: str='', show_plot: bool=False, show_ticks: bool=False
) -> None:
"""Plotting the feature space of the autoencoder
:param labels: Numpy array with labels of the dataset
    :param mark_feat: Nested list with feature values per feature and cluster (mark_feat[feature][cluster])
    :param idx: List with the three feature indices to plot (indexes into mark_feat)
    :param data_classname: Optional list with class label names used for the legend
:param path2save: Path to save the figure
:param show_plot: If true, show plot
:param show_ticks: If true, show ticks
:return: None
"""
    fig = plt.figure(figsize=(cm_to_inch(14), cm_to_inch(10)))
    ax = fig.add_subplot(projection='3d')
fontsize_label = 12
cluster_no = np.unique(labels)
for i, id in enumerate(cluster_no):
ax.scatter3D(mark_feat[idx[0]][i], mark_feat[idx[1]][i], mark_feat[idx[2]][i],
color=get_plot_color(i), marker='.')
    ax.set_xlabel(f'Feat[{idx[0]}]', fontsize=fontsize_label)
    ax.set_ylabel(f'Feat[{idx[1]}]', fontsize=fontsize_label)
    ax.set_zlabel(f'Feat[{idx[2]}]', fontsize=fontsize_label)
if not show_ticks:
ax.set_xticklabels([])
ax.set_yticklabels([])
ax.set_zticklabels([])
    if isinstance(data_classname, list) and len(data_classname) > 0:
        ax.legend(data_classname)
plt.tight_layout()
# --- saving plots
if path2save:
save_figure(plt, path2save, "ai_training_feat")
if show_plot:
plt.show(block=True)
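

# --- Hypothetical direct call of plot_3d_featspace (illustrative sketch, not original API usage);
# --- mark_feat is a nested list where mark_feat[feature][cluster] holds the feature values of one
# --- cluster, matching the structure built in results_autoencoder_training.
def _example_plot_3d_featspace() -> None:
    """Scatter a small synthetic 3D feature space with two clusters (illustrative sketch only)."""
    rng = np.random.default_rng(0)
    labels = np.repeat([0, 1], 50)
    mark_feat = [
        [rng.standard_normal(50) + offset for offset in (0.0, 2.0)]  # two clusters per feature
        for _ in range(3)  # three features
    ]
    plot_3d_featspace(labels, mark_feat, idx=[0, 1, 2], data_classname=['Cluster 0', 'Cluster 1'])
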
def plot_autoencoder_run(
mark_feat: list, mark_idx: list,
frames_in: np.ndarray, frames_out: np.ndarray, frames_mean: np.ndarray,
cluster_no: np.ndarray, take_frames: list,
data_classname=None, data_labeled: bool=False,
path2save: str='', show_plot: bool=False
) -> None:
""""""
plt.figure(figsize=(cm_to_inch(16), cm_to_inch(8)))
plt.rcParams.update({'font.size': 10})
plt.subplots_adjust(hspace=0, wspace=0.5)
row = 1
col = 3
axs = list()
for idx in range(0, row * col):
axs.append(plt.subplot(row, col, 1+idx))
# Noisy input
for pos in take_frames:
axs[0].plot(np.transpose(frames_in[pos, :]), linewidth=0.5)
axs[0].set_title('Input')
axs[0].set_xlabel('Frame position')
axs[0].set_xticks(np.linspace(0, frames_in.shape[1]-1, num=6, endpoint=True, dtype=int))
# Feature extraction
for i, id in enumerate(cluster_no):
axs[1].scatter(mark_feat[mark_idx[0]][i], mark_feat[mark_idx[1]][i], color=get_plot_color(i), marker='.')
axs[1].set_title('Feature Space')
    axs[1].set_xlabel(f'Feat[{mark_idx[0]}]')
    axs[1].set_ylabel(f'Feat[{mark_idx[1]}]')
    if isinstance(data_classname, list) and len(data_classname) > 0:
        axs[1].legend(data_classname)
# Denoised output
if data_labeled:
for i, id in enumerate(cluster_no):
axs[2].plot(frames_mean[id, :], color=get_plot_color(i), linewidth=2)
for pos in take_frames:
axs[2].plot(np.transpose(frames_out[pos, :]), linewidth=0.5)
axs[2].set_title('Output')
axs[2].set_xlabel('Frame position')
axs[2].set_xticks(np.linspace(0, frames_mean.shape[1]-1, num=6, endpoint=True, dtype=int))
for ax in axs:
ax.grid()
ax.margins(x=0)
plt.tight_layout(pad=0.5)
# --- Saving plots
if path2save:
save_figure(plt, path2save, f"ai_training_out{mark_idx[0]}-{mark_idx[1]}")
if show_plot:
plt.show(block=True)
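

# --- Hypothetical stand-alone demo of plot_autoencoder_run (illustrative sketch, assumptions only):
# --- take_frames holds the frame indices to draw per cluster and frames_mean is indexed by the
# --- cluster label, mirroring the pre-processing done in results_autoencoder_training.
if __name__ == "__main__":
    rng = np.random.default_rng(1)
    n_frames, frame_len = 100, 32
    labels = rng.integers(0, 2, size=n_frames)
    frames_mean = rng.standard_normal((2, frame_len))
    frames_in = frames_mean[labels, :] + 0.3 * rng.standard_normal((n_frames, frame_len))
    frames_out = frames_mean[labels, :] + 0.1 * rng.standard_normal((n_frames, frame_len))
    feats = rng.standard_normal((n_frames, 2))
    mark_feat = [[feats[labels == cid, k] for cid in (0, 1)] for k in range(2)]
    take_frames = [np.where(labels == cid)[0][:25] for cid in (0, 1)]
    plot_autoencoder_run(
        mark_feat, [0, 1], frames_in, frames_out, frames_mean,
        np.unique(labels), take_frames, data_classname=['Unit 0', 'Unit 1'],
        data_labeled=True, show_plot=True
    )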