import matplotlib.pyplot as plt
import numpy as np
from denspp.offline.plot_helper import (
cm_to_inch,
get_plot_color,
get_plot_color_inactive,
get_textsize_paper,
save_figure,
scale_auto_value,
)
[docs]
def plot_frames_feature(
    signals: dict,
    no_electrode: int,
    take_feat_dim: list = (0, 1),
    path: str = "",
    show_plot: bool = False,
) -> None:
    """Plotting the aligned spike frames, a 2d feature space and the mean frame of each cluster
    :param signals: Dictionary with pipeline results; reads signals["frames"][0] (frames, one per row),
        signals["frames"][2] (cluster ID of each frame) and signals["features"] (feature matrix, one row per frame)
    :param no_electrode: Electrode number (only used for the file name of the saved figure)
    :param take_feat_dim: Two column indices of the feature matrix (first = x-axis, second = y-axis)
    :param path: Path to save the figures (empty string = figure is not saved)
    :param show_plot: If true, show plot (blocking)
    :return: None
    """
    assert len(take_feat_dim) == 2, "take_feat_dim must be 2 dimensional"
    frames_out = signals["frames"][0]
    cluster = signals["frames"][2]
    assert frames_out.shape[0] == cluster.size, (
        "Dimensions between number of frames and corresponding cluster ID are not equal"
    )
    feat = signals["features"]

    # Group the frame indices once per cluster; both the feature plot and the mean frames reuse it.
    cluster_ids = np.unique(cluster)
    members = [np.flatnonzero(cluster == cluster_id) for cluster_id in cluster_ids]

    # Rows are addressed by position, not by cluster ID, so IDs do not have to be 0..N-1.
    frames_mean = np.zeros(shape=(cluster_ids.size, frames_out.shape[1]))
    for row, sel in enumerate(members):
        frames_mean[row, :] = np.mean(frames_out[sel], axis=0)

    plt.figure(figsize=(cm_to_inch(20), cm_to_inch(10)))
    plt.subplots_adjust(hspace=0)
    ax1 = plt.subplot(131)
    ax2 = plt.subplot(132)
    ax3 = plt.subplot(133, sharex=ax1)

    ax1.set_title("Aligned Frames")
    ax1.plot(np.transpose(frames_out), marker=".", markersize=4, drawstyle="steps-post")

    ax2.set_title("Feature Space")
    for cluster_id, sel in zip(cluster_ids, members):
        ax2.plot(
            feat[sel, take_feat_dim[0]],
            feat[sel, take_feat_dim[1]],
            color=get_plot_color(cluster_id),
            marker=".",
            linestyle="none",
        )
    # Label each axis with the feature column that is actually drawn on it.
    ax2.set_xlabel(f"Feat. {take_feat_dim[0]}")
    ax2.set_ylabel(f"Feat. {take_feat_dim[1]}")

    ax3.set_title("Mean Frames (Clustered)")
    for cluster_id, frame in zip(cluster_ids, frames_mean):
        ax3.plot(
            np.transpose(frame),
            # Same color key as in the feature space, so both subplots match per cluster.
            color=get_plot_color(cluster_id),
            marker=".",
            markersize=4,
            drawstyle="steps-post",
        )
    plt.tight_layout()

    # --- saving plots
    if path:
        save_figure(plt, path, f"pipeline_features_elec{str(no_electrode)}")
    if show_plot:
        plt.show(block=True)
[docs]
def plot_transient_highlight_spikes(
    signals: dict,
    no_electrode: int,
    path: str = "",
    time_cut: list = (),
    show_noise: bool = False,
    show_plot: bool = False,
) -> None:
    """Plotting the detected spike activity from transient data (highlighted, noise in gray)
    :param signals: Dictionary with pipeline results; reads signals["fs_dig"] (sampling rate),
        signals["x_adc"] (transient signal), signals["frames"][1] (spike positions as sample index)
        and signals["frames"][2] (cluster ID of each spike)
    :param no_electrode: Electrode number (only used for the file name of the saved figure)
    :param path: Path to save the figures (empty string = figure is not saved)
    :param time_cut: Empty for the full signal, otherwise [t_start, t_stop] in the unit of the time axis;
        limits the x-axis and selects the samples with t_start <= t < t_stop for the histogram
    :param show_noise: If true, show noise (otherwise flat line)
    :param show_plot: If true, show plot (blocking)
    :return: None
    """
    fs_dig = signals["fs_dig"]
    xadc = signals["x_adc"]
    time = np.arange(0, xadc.size, 1) / fs_dig
    ticks = signals["frames"][1]
    ticks_id = signals["frames"][2]

    # Samples highlighted before / after each spike position.
    pre_samples, post_samples = 12, 30

    def _between_spikes(segment):
        # Signal between two spike windows: raw data or a flat zero line.
        return segment if show_noise else np.zeros(shape=(len(segment),), dtype=int)

    # --- Cutting the transient signal into (time, value, color) pieces
    segments = list()
    tick_old = 0
    for tick, tick_id in zip(ticks, ticks_id):
        # Clamp at zero: a negative bound would be read as "counted from the end" by the slicing.
        start = max(int(tick) - pre_samples, 0)
        stop = max(int(tick) + post_samples, 0)
        segments.append(
            (time[tick_old:start], _between_spikes(xadc[tick_old:start]), get_plot_color_inactive())
        )
        segments.append((time[start:stop], xadc[start:stop], get_plot_color(tick_id)))
        tick_old = stop
    # Remaining signal after the last spike window (the whole signal if there are no spikes).
    segments.append(
        (time[tick_old:], _between_spikes(xadc[tick_old:]), get_plot_color_inactive())
    )

    # --- Plot generation
    plt.figure(figsize=(cm_to_inch(16), cm_to_inch(13)))
    ax_signal = plt.subplot(1, 2, 1)
    ax_hist = plt.subplot(1, 2, 2, sharey=ax_signal)

    # Subplot 1: Transient signal (colored)
    for seg_time, seg_value, seg_color in segments:
        ax_signal.plot(seg_time, seg_value, linewidth=1, color=seg_color, drawstyle="steps-post")

    # --- Subplot 2: Histogram (from Subplot 1)
    # The number of bins has to be an integer, also for non-integer input data.
    no_bins = int(1 + abs(np.max(xadc)) + abs(np.min(xadc)))
    if len(time_cut) != 0:
        # A mask stays valid if time_cut exceeds the recorded time range.
        in_window = (time >= time_cut[0]) & (time < time_cut[1])
        x_bins = xadc[in_window]
    else:
        x_bins = xadc
    # Zero samples are excluded; the mask is applied to the selected window itself.
    x_hist = x_bins[x_bins != 0]
    if x_hist.size:
        ax_hist.hist(
            x_hist,
            color="k",
            density=True,
            log=True,
            bins=no_bins,
            orientation="horizontal",
        )

    # --- Axis labels
    ax_signal.set_xlabel("Time t [s]")
    ax_signal.set_ylabel("x_adc(t) [ ]")
    ax_signal.grid()
    ax_hist.set_xlabel("Density")
    ax_hist.grid()

    # --- Zooming
    if len(time_cut) != 0:
        ax_signal.set_xlim(time_cut)
        addon_zoom = "_zoom"
    else:
        ax_signal.set_xlim([time[0], time[-1]])
        addon_zoom = ""
    plt.tight_layout()

    # --- saving plots
    if path:
        save_figure(plt, path, f"pipeline_spikes_elec{str(no_electrode)}{addon_zoom}")
    if show_plot:
        plt.show(block=True)
[docs]
def plot_mea_transient_total(
    mea_data: np.ndarray,
    mapping: np.ndarray,
    fs_used: float,
    path2save: str = "",
    do_global_limit: bool = False,
    do_show: bool = False,
) -> None:
    """Plotting the transient signals of the transient numpy signal with electrode information
    Args:
        mea_data: Transient numpy array with neural signal [row, column, transient]
        mapping: Numpy array with electrode mapping information (entries > 0 mark used electrodes)
        fs_used: Sampling rate of the signal [Hz]
        path2save: Path for saving the figures (empty string = figure is not saved)
        do_global_limit: Doing a global y-range setting (otherwise each electrode uses its own min/max)
        do_show: Show the plots
    Returns:
        None
    """
    assert mea_data.shape[0:2] == mapping.shape, (
        "Shape mismatch, please apply_mapping() using PipelineCMDs"
    )
    num_rows, num_cols = mapping.shape

    # Sample k is taken at k / fs_used.
    time_array = np.arange(mea_data[0, 0].size) / fs_used
    scale_yaxis = scale_auto_value(mea_data)
    scale_xaxis = scale_auto_value(time_array)

    # One mask decides everywhere whether a grid position holds an electrode.
    active = mapping > 0

    # Extract the (already scaled) y-limits of each used electrode
    y_lower = np.zeros(mapping.shape, dtype=float)
    y_upper = np.zeros(mapping.shape, dtype=float)
    for i, j in np.argwhere(active):
        y_lower[i, j] = scale_yaxis[0] * np.min(mea_data[i, j])
        y_upper[i, j] = scale_yaxis[0] * np.max(mea_data[i, j])
    if active.any():
        mea_yglobal = [float(np.min(y_lower[active])), float(np.max(y_upper[active]))]
    else:
        # No electrode mapped: nothing to scale, keep a defined value for the label below.
        mea_yglobal = [0.0, 0.0]

    # --- Create the figure
    # squeeze=False keeps axes two-dimensional, also for a single row or a single column.
    _, axes = plt.subplots(num_rows, num_cols, figsize=(12, 10), squeeze=False)
    plt.subplots_adjust(left=0.02, right=0.98, top=0.95, bottom=0.02, hspace=0.07, wspace=0.07)
    ax_empty = list()
    for i in range(num_rows):
        for j in range(num_cols):
            ax = axes[i, j]
            if active[i, j]:
                ax.plot(
                    scale_xaxis[0] * time_array,
                    scale_yaxis[0] * mea_data[i, j],
                    "k-",
                    linewidth=1.0,
                )
                ax.set_xlim([scale_xaxis[0] * time_array[0], scale_xaxis[0] * time_array[-1]])
                ax.set_ylim(mea_yglobal if do_global_limit else [y_lower[i, j], y_upper[i, j]])
            else:
                # Placeholder for unused positions; the first one carries the scale arrows.
                ax.plot([0], "k-", linewidth=0.1)
                ax_empty.append(ax)

            # Remove x-/y-axis ticks and labels
            ax.set_yticklabels([])
            ax.set_yticks([])
            ax.set_xticklabels([])
            ax.set_xticks([])
            # Remove subplot border
            for side in ("top", "right", "left", "bottom"):
                ax.spines[side].set_visible(False)

    # --- Scale arrows (time and amplitude) in the first unused position
    if len(ax_empty):
        ax_empty[0].arrow(
            x=-0.35,
            y=0.5,
            dx=1,
            dy=0,
            length_includes_head=True,
            head_width=0.08,
            head_length=0.00002,
        )
        ax_empty[0].text(
            x=0.15,
            y=0.55,
            s=f"{scale_xaxis[0] * time_array[-1]:.1f} {scale_xaxis[1]}s",
            ha="center",
        )
        ax_empty[0].arrow(
            x=-0.35,
            y=0,
            dx=0,
            dy=1,
            length_includes_head=True,
            head_width=0.08,
            head_length=0.00002,
        )
        ax_empty[0].text(
            x=-0.4,
            y=0.45,
            # mea_yglobal is already scaled with scale_yaxis[0], so it is not scaled again here.
            s=f"{mea_yglobal[1] - mea_yglobal[0]:.1f} {scale_yaxis[1]}V",
            ha="center",
            rotation=90,
        )

    if path2save:
        save_figure(plt, path2save, "mea_data" + ("_global" if do_global_limit else "_local"))
    if do_show:
        plt.show()