Source code for elasticai.preprocessor.waveform_generator
from dataclasses import dataclass
from datetime import datetime
from logging import Logger, getLogger
from pathlib import Path
from typing import Callable
import numpy as np
from elasticai.creator.arithmetic import FxpArithmetic, FxpParams, int_converter
from elasticai.creator_plugins.bram.utils import translate_path_to_int, write_mem_file
from scipy import signal
import elasticai.creator_plugins.waveform.utils as hw_utils
from elasticai.creator_plugins.waveform.c.waveform_lut_c import generate_waveform_lut_template
from elasticai.preprocessor._check_funcs import check_keylist_elements_any
from elasticai.preprocessor.translation.ir2c import (
generate_c_files,
get_embedded_datatype,
replace_variables_with_parameters,
)
[docs]
@dataclass(frozen=True)
class WaveformSignal:
    """Immutable container bundling a generated waveform with its time base.

    Attributes:
        time: Numpy array holding the timestamp of every sample
        signal: Numpy array holding the sample values
        fs: Sampling rate as float
        rms: Root mean square value as float
    """

    # --- Sample data
    time: np.ndarray
    signal: np.ndarray
    # --- Scalar metadata
    fs: float
    rms: float
[docs]
class WaveformGenerator:
    """Generator for normalized stimulation waveforms and their hardware designs.

    Waveform templates are selected by string label (see
    :meth:`get_dictionary_classes`) and are built for a time window given in
    seconds at the sampling rate passed to the constructor.
    """

    _logger: Logger

    def __init__(self, sampling_rate: float):
        """Class for generating the transient stimulation signal

        :param sampling_rate: Sampling rate of the signal
        """
        self._logger = getLogger(__name__)
        self._sampling_rate: float = sampling_rate
        # Length of the window that the template builders generate next;
        # it is overwritten before every template generation.
        self._time_duration: float = 1.0
        # Dispatch table: waveform label -> template builder
        self.__func_dict: dict[str, Callable] = {
            "RECT_HALF": self.__generate_rectangular_half,
            "RECT_FULL": self.__generate_rectangular_full,
            "LIN_RISE": self.__generate_linear_rising,
            "LIN_FALL": self.__generate_linear_falling,
            "SINE_HALF": self.__generate_sinusoidal_half,
            "SINE_HALF_INV": self.__generate_sinusoidal_half_inverse,
            "SINE_FULL": self.__generate_sinusoidal_full,
            "TRI_HALF": self.__generate_triangle_half,
            "TRI_FULL": self.__generate_triangle_full,
            "SAW_POS": self.__generate_sawtooth_positive,
            "SAW_NEG": self.__generate_sawtooth_negative,
            "GAUSS": self.__generate_gaussian,
            "ZERO": self.__generate_zero,
            "EAP": self.__generate_spike_waveform,
        }

    @property
    def _num_samples(self) -> int:
        """Calculating the number of samples of the transient window"""
        return int(self._time_duration * self._sampling_rate)

    @property
    def _build_time_cycle(self) -> np.ndarray:
        """Phase axis covering one period [0, 2*pi) with one entry per sample of the window"""
        return np.linspace(
            start=0.0,
            stop=2 * np.pi,
            num=self._num_samples,
            endpoint=False,
            dtype=float,
        )

    @staticmethod
    def __switching_polarity(signal_in: np.ndarray, do_cathodic: bool) -> np.ndarray:
        """Switching the polarity for cathodic-first (True) or anodic-first (False) waveform"""
        return (-1) * signal_in if do_cathodic else signal_in

    def __get_charge_balancing_factor(self, waveforms: list) -> float:
        """Getting the coefficient for area-related comparison for charge balancing the biphasic waveform

        :param waveforms: List with two (phases) or three (phases with idle gap) numpy arrays
        :return: Absolute ratio of the first to the last phase area, or 1.0 if no ratio can be built
        """
        if len(waveforms) not in (2, 3):
            self._logger.info("It is not a biphasic waveform available - Please check!")
            return 1.0

        area_first = np.trapezoid(waveforms[0])
        area_second = np.trapezoid(waveforms[-1])
        if area_second == 0.0:
            # A phase without net area cannot be rescaled to match the first one;
            # dividing would only produce inf/nan samples.
            self._logger.warning("Second phase has no net area - charge balancing is skipped")
            return 1.0
        return float(np.abs(area_first / area_second))

    def check_charge_balancing(self, signal: np.ndarray) -> float:
        """Checking if stimulation signal is charge balanced

        :param signal: Numpy array with the stimulation signal
        :return: Trapezoidal integral of the signal (0.0 for a balanced signal)
        """
        dq = np.trapezoid(signal)
        self._logger.info(f"... waveform has an error of {dq:.6f}")
        return float(dq)

    def __generate_zero(self) -> np.ndarray:
        """All-zero window"""
        return np.zeros((self._num_samples,), dtype=float)

    def __generate_rectangular_half(self) -> np.ndarray:
        """Constant window at 1.0"""
        return 1.0 + self.__generate_zero()

    def __generate_rectangular_full(self) -> np.ndarray:
        """One period of a square wave with 50 % duty cycle"""
        return signal.square(self._build_time_cycle, duty=0.5)

    def __generate_linear_rising(self) -> np.ndarray:
        """Ramp from 0.0 to 1.0 (both included)"""
        return np.linspace(0.0, 1.0, self._num_samples, endpoint=True, dtype=float)

    def __generate_linear_falling(self) -> np.ndarray:
        """Ramp from 1.0 to 0.0 (both included)"""
        return np.linspace(1.0, 0.0, self._num_samples, endpoint=True, dtype=float)

    def __generate_sinusoidal_half(self) -> np.ndarray:
        """First half period of a sine"""
        return np.sin(0.5 * self._build_time_cycle, dtype=float)

    def __generate_sinusoidal_half_inverse(self) -> np.ndarray:
        """One minus the first half period of a sine"""
        return 1.0 - np.sin(0.5 * self._build_time_cycle, dtype=float)

    def __generate_sinusoidal_full(self) -> np.ndarray:
        """Full period of a sine"""
        return np.sin(self._build_time_cycle, dtype=float)

    def __generate_triangle_half(self) -> np.ndarray:
        """Symmetric sawtooth (width=0.5) over half a period, phase shifted by pi/2"""
        return signal.sawtooth(0.5 * self._build_time_cycle + np.pi / 2, width=0.5)

    def __generate_triangle_full(self) -> np.ndarray:
        """Symmetric sawtooth (width=0.5) over a full period, phase shifted by pi/2"""
        return signal.sawtooth(self._build_time_cycle + np.pi / 2, width=0.5)

    def __generate_sawtooth_positive(self) -> np.ndarray:
        """Ramp from -1.0 to 1.0"""
        return 2 * self.__generate_linear_rising() - 1.0

    def __generate_sawtooth_negative(self) -> np.ndarray:
        """Ramp from 1.0 to -1.0"""
        return 2 * self.__generate_linear_falling() - 1.0

    def __generate_gaussian(self) -> np.ndarray:
        """Rescaled envelope of a Gaussian pulse evaluated on the axis [-1, 1]"""
        time = self.__generate_sawtooth_positive()
        # Index 1 of the returned tuple is the envelope (retenv=True)
        out = signal.gausspulse(time, fc=np.pi, retenv=True)[1]
        scale_amp = (out.max() + out.min()) / (out.max())
        return out * scale_amp - out.min()

    def __generate_spike_waveform(self) -> np.ndarray:
        """Spike template built from two Gaussians, scaled so that its minimum is -1.0

        The template always spans 1.6 ms, independent of the configured time window.
        """
        t_end_ms = 1.6
        t = np.linspace(
            start=0.0,
            stop=t_end_ms,
            # Sampling rate is per second, the template axis is in milliseconds
            num=int(t_end_ms * self._sampling_rate * 1e-3),
            endpoint=True,
        )
        eap0 = -np.exp(-((t - 0.45) ** 2) / 0.03)
        eap1 = 0.5 * np.exp(-((t - 0.86) ** 2) / 0.08)
        eap = eap0 + eap1
        # eap.min() is negative, so this keeps the sign and maps the trough to -1.0
        eap = -eap / eap.min()
        return eap

    def get_dictionary_classes(self) -> list:
        """Getting a list with class names / labels of waveforms

        :return: List with class names
        """
        return list(self.__func_dict)

    def __select_waveform_template(
        self, time_duration: float, sel_wfg: str, do_cathodic: bool = False
    ) -> np.ndarray:
        """Selection for generating a waveform template

        Args:
            time_duration: Time window for the waveform
            sel_wfg: Label of the waveform type, one of the entries returned by
                get_dictionary_classes() (e.g. "RECT_HALF", "SINE_FULL", "GAUSS")
            do_cathodic: Boolean for cathodic-first impulse (inverts the sign)
        Returns:
            Numpy array with selected waveform
        Raises:
            NotImplementedError: If the label is not registered
        """
        if sel_wfg not in self.__func_dict:
            raise NotImplementedError("Waveform is not implemented!")

        # The template builders read the window length from the instance state
        self._time_duration = time_duration
        template = self.__func_dict[sel_wfg]()
        waveform = self.__switching_polarity(template, do_cathodic)
        self._logger.debug(
            f"Selected waveform type {sel_wfg} is generated with shape {waveform.shape}"
        )
        return waveform

    def generate_waveform(
        self,
        time_points: list,
        time_duration: list,
        waveform_select: list,
        polarity_cathodic: list,
    ) -> "WaveformSignal":
        """Generating the signal with waveforms for stimulation

        :param time_points: List of time points for applying a stimulation waveform
        :param time_duration: List of stimulation waveform duration
        :param waveform_select: List of selected waveforms
        :param polarity_cathodic: List for performing cathodic-first generation (empty list: never inverted)
        :returns: Dataclass WaveformSignal with ['time', 'signal', 'fs', 'rms'], where 'rms' is the
            root mean square of the last generated waveform template
        :raises RuntimeError: If the input lists are empty or their lengths do not match
        """
        num_events = len(time_points)
        num_polarity = len(polarity_cathodic)
        if not num_events == len(waveform_select) == len(time_duration):
            raise RuntimeError("Please check input! --> Length is not equal")
        if num_events == 0:
            raise RuntimeError("Please check input! --> No waveform is defined")
        if 0 < num_polarity < num_events:
            raise RuntimeError("Please check input! --> Length of polarity_cathodic is not equal")

        # --- Building all templates first, so the output buffer can be sized to hold them
        placed = []
        for idx, (time_off, time_sec, wvf_type) in enumerate(
            zip(time_points, time_duration, waveform_select)
        ):
            do_polarity = polarity_cathodic[idx] if num_polarity else False
            waveform = self.__select_waveform_template(time_sec, wvf_type, do_polarity)
            placed.append((int(time_off * self._sampling_rate), waveform))

        # --- Allocating the output: nominal window, extended if any template would not fit
        num_nominal = int((2 * time_points[-1] + time_duration[-1]) * self._sampling_rate)
        num_total = max([num_nominal] + [start + wvf.size for start, wvf in placed])
        out = np.zeros((num_total,), dtype=float)
        for start, wvf in placed:
            out[start : start + wvf.size] += wvf

        last_waveform = placed[-1][1]
        rms_value = float(np.sqrt(np.sum(np.square(last_waveform)) / last_waveform.size))
        time = np.arange(out.size, dtype=float) / self._sampling_rate
        return WaveformSignal(
            time=time,
            signal=out,
            fs=self._sampling_rate,
            rms=rms_value,
        )

    def generate_waveform_quant_fxp(
        self,
        time_points: list,
        time_duration: list,
        waveform_select: list,
        polarity_cathodic: list,
        bitwidth: int,
        bitfrac: int,
        signed: bool,
        do_opt: bool = False,
    ) -> "WaveformSignal":
        """Generating the signal with waveforms for stimulation in quantized manner

        :param time_points: List of time points for applying a stimulation waveform
        :param time_duration: List of stimulation waveform duration
        :param waveform_select: List of selected waveforms
        :param polarity_cathodic: List for performing cathodic-first generation
        :param bitwidth: Integer with total bitwidth
        :param bitfrac: Integer with fraction bitwidth
        :param signed: If quantized output should be signed integer
        :param do_opt: Boolean for taking quarter signal (optimized version for hardware implementation)
        :returns: Dataclass WaveformSignal with quantized signals ['time', 'signal', 'fs', 'rms']
        """
        supported_waveform_types = ["SINE_FULL", "RECT_FULL", "TRI_FULL"]
        assert check_keylist_elements_any(waveform_select, supported_waveform_types), (
            f"Only 'waveform_select' with {supported_waveform_types} are allowed!"
        )
        wvf_norm = self.generate_waveform(
            time_points=time_points,
            time_duration=time_duration,
            waveform_select=waveform_select,
            polarity_cathodic=polarity_cathodic,
        )

        if do_opt:
            val_in = wvf_norm.signal
        else:
            # Unsigned output: scale by 0.5 and shift by 0.5 before quantization
            scale = 1.0 if signed else 0.5
            offset = 0.0 if signed else 0.5
            val_in = wvf_norm.signal * scale + offset

        arith = FxpArithmetic(
            fxp_params=FxpParams(total_bits=bitwidth, frac_bits=bitfrac, signed=signed)
        )
        wvf_fxp = np.asarray(arith.round_to_rational(val_in.tolist()))

        if do_opt:
            # Only the first quarter of the samples (plus one) is kept
            num_keep = wvf_fxp.size // 4 + 1
            return WaveformSignal(
                time=wvf_norm.time[:num_keep],
                signal=wvf_fxp[:num_keep],
                fs=wvf_norm.fs,
                rms=wvf_norm.rms,
            )
        return WaveformSignal(
            time=wvf_norm.time,
            signal=wvf_fxp,
            fs=wvf_norm.fs,
            rms=wvf_norm.rms,
        )

    def generate_biphasic_waveform(
        self,
        anodic_wvf: str,
        anodic_duration: float,
        cathodic_wvf: str,
        cathodic_duration: float,
        intermediate_duration: float = 0.0,
        do_cathodic_first: bool = False,
        do_charge_balancing: bool = False,
    ) -> dict:
        """Generating the waveform for stimulation

        Args:
            anodic_wvf: String with waveform type for anodic phase
            anodic_duration: Time window of the anodic phase
            cathodic_wvf: String with waveform type for cathodic phase
            cathodic_duration: Time window of the cathodic phase
            intermediate_duration: Time window for the intermediate idle time between anodic and cathodic phase
            do_cathodic_first: Starting with cathodic phase
            do_charge_balancing: Performing a charge balancing on second phase (same area)
        Returns:
            Dictionary with the numpy arrays "t" (time) and "y" (output signal); the signal is
            terminated with one additional zero sample
        """
        if do_cathodic_first:
            width = [cathodic_duration, anodic_duration]
            mode = [cathodic_wvf, anodic_wvf]
            poly = [True, False]
        else:
            width = [anodic_duration, cathodic_duration]
            mode = [anodic_wvf, cathodic_wvf]
            poly = [False, True]

        # --- Creating the waveforms
        waveforms = []
        for idx, (window, wvf_type, inverter) in enumerate(zip(width, mode, poly)):
            if idx == 1 and not intermediate_duration == 0.0:
                # Idle gap between the two phases
                self._time_duration = intermediate_duration
                waveforms.append(self.__generate_zero())
            waveforms.append(self.__select_waveform_template(window, wvf_type, inverter))

        if do_charge_balancing:
            waveforms[-1] = self.__get_charge_balancing_factor(waveforms) * waveforms[-1]

        # --- Creating the output signal
        out = np.concatenate(waveforms + [np.zeros((1,))], axis=0)
        # One timestamp per sample with a spacing of exactly one sampling period
        time = np.arange(out.size, dtype=float) / self._sampling_rate
        return {"t": time, "y": out}

    @staticmethod
    def build_random_timestamps(count: int, min_gap: float = 0.002, max_gap: float = 0.01) -> list:
        """Function for building random and sorted timestamps for generating waveforms

        :param count: Number of timestamps to generate
        :param min_gap: Minimum gap between timestamps [s]
        :param max_gap: Maximum gap between timestamps [s]
        :return: List with ascending timestamps, each one uniformly drawn gap after the previous
        """
        values = []
        current = 0.0
        for _ in range(count):
            current += np.random.uniform(min_gap, max_gap)
            values.append(current)
        return values

    def create_design(
        self,
        waveform: str,
        num_params: int,
        is_signed: bool,
        target: str,
        bitwidth: int,
        id: str,
        path2save: Path,
        use_bram: bool = False,
        do_opt: bool = False,
    ) -> list[int]:
        """Creating the hardware design for executing on specific target

        :param waveform: String with waveform type
        :param num_params: Number of params for the waveform
        :param is_signed: Boolean indicating whether to use signed or not
        :param target: String with target name ["mcu", "pc", "fpga"] (case-insensitive)
        :param bitwidth: Integer with total bitwidth
        :param id: String with unique identifier of device (appended to the name)
        :param path2save: Path to save the hardware files
        :param use_bram: Boolean indicating whether to use bram or not (FPGA only)
        :param do_opt: Boolean indicating whether to do opt or not
        :return: List with the waveform values used in the design
        :raises ValueError: If the target is not supported
        :raises AttributeError: If BRAM is requested for an MCU or PC target
        """
        supported_targets = ["mcu", "pc", "fpga"]
        target_key = target.lower()
        if target_key not in supported_targets:
            raise ValueError(f"Target {target} is not supported: only {supported_targets}")

        if target_key in ["mcu", "pc"]:
            if use_bram:
                raise AttributeError("BRAM is not supported for MCU and PC")
            return self._create_design_c(
                waveform=waveform,
                num_params=num_params,
                is_signed=is_signed,
                id=id,
                bitwidth=bitwidth,
                path2save=path2save,
                do_opt=do_opt,
            )
        return self._create_design_verilog(
            waveform=waveform,
            num_params=num_params,
            is_signed=is_signed,
            id=id,
            bitwidth=bitwidth,
            path2save=path2save,
            do_opt=do_opt,
            use_bram=use_bram,
        )

    def _create_design_verilog(
        self,
        waveform: str,
        num_params: int,
        is_signed: bool,
        id: str,
        bitwidth: int,
        path2save: Path,
        do_opt: bool,
        use_bram: bool,
    ) -> list[int]:
        """Creating the Verilog design (LUT or BRAM based) in the folder path2save

        :return: List with the waveform values used in the design (reversed in place
            for the optimized BRAM variant)
        """
        self._logger.debug("Creating Verilog design for Waveform Player")
        path2save.mkdir(parents=True, exist_ok=True)
        # Optimized variant: one bit less and always unsigned
        conv = int_converter(
            total_bits=bitwidth if not do_opt else bitwidth - 1,
            signed=is_signed if not do_opt else False,
        )
        wvf = hw_utils.prepare_waveform(
            waveform=waveform,
            bitwidth=bitwidth,
            num_params=num_params,
            do_opt=do_opt,
            is_signed=is_signed,
        )

        if use_bram:
            path2mem = path2save / "data.mem"
            if do_opt:
                wvf.reverse()
            # NOTE(review): the bit width passed here is `bitwidth` when do_opt and
            # `bitwidth - 1` otherwise, which is the opposite of the converter
            # above - confirm which of the two is intended.
            write_mem_file(path=path2mem, data=wvf, bitwidth=bitwidth if do_opt else bitwidth - 1)
            verilog_type = "waveform_ram_full" if not do_opt else "waveform_ram_opt"
            params = {
                "BITWIDTH": bitwidth,
                "WAIT_WIDTH": bitwidth,
                "RAMWIDTH": len(wvf),
                "PATH2MEM": translate_path_to_int(path2mem),
            }
        else:
            verilog_type = "waveform_lut_full" if not do_opt else "waveform_lut_opt"
            params = {
                "BITWIDTH": bitwidth,
                "WAIT_WIDTH": bitwidth,
                "LUTWIDTH": len(wvf),
                "LUT_DATA": conv.integer_to_hex_string_array_verilog(wvf),
            }
        if do_opt:
            params.update({"SIGNED_OUT": 1 if is_signed else 0})

        self._logger.debug(f"Building Verilog design at {path2save.as_posix()}")
        hw_utils.load_and_plugin(
            type=verilog_type,
            id=id,
            params=params,
            packages=["waveform"],
            path2save=path2save,
            use_bram=use_bram,
        )
        return wvf

    def _create_design_c(
        self,
        waveform: str,
        num_params: int,
        is_signed: bool,
        id: str,
        bitwidth: int,
        path2save: Path,
        do_opt: bool,
    ) -> list[int]:
        """Creating the C design (look-up table) in the folder path2save

        :return: List with the waveform values written into the look-up table
        """
        self._logger.debug("Creating C design for Waveform Player")
        # --- Step #1: Generating the waveform
        datatype_data_ext = get_embedded_datatype(bitwidth=bitwidth, signed=is_signed)
        # Extracting the number between "int" and "_" of the datatype name
        bitwidth_mcu = int(datatype_data_ext.split("int")[-1].split("_")[0])
        wvf = hw_utils.prepare_waveform(
            waveform=waveform,
            bitwidth=bitwidth_mcu,
            num_params=num_params,
            do_opt=do_opt,
            is_signed=is_signed,
        )

        # --- Step #2: Generating the values for parameter dict
        # Only the optimized unsigned variant gets a mid-scale offset
        lut_offset = 2 ** (bitwidth_mcu - 1) if (do_opt and not is_signed) else 0
        params = {
            "datetime_created": datetime.now().strftime("%m/%d/%Y, %H:%M:%S"),
            "path2include": "src",
            "template_name": "waveform_lut_template.h",
            "device_id": str(id.upper()),
            # NOTE(review): the number of LUT entries is passed as `bitwidth` here -
            # confirm whether the bit length of that count was intended.
            "datatype_cnt": get_embedded_datatype(bitwidth=len(wvf), signed=False),
            "datatype_int": get_embedded_datatype(bitwidth, signed=is_signed),
            "num_lutsine": str(len(wvf)),
            "lut_offset": str(lut_offset),
            "lut_data": ", ".join(map(str, wvf)),
        }

        # --- Step #3: Replace string parameters with real values
        path2template = Path(hw_utils.__file__).parent / "c"
        self._logger.debug(f"Building C design at {path2save.as_posix()}")
        template = generate_waveform_lut_template(do_opt)
        generate_c_files(
            path2save=path2save,
            template_name=params["template_name"],
            file_name="waveform_lut",
            module_id=id.lower(),
            proto_file=replace_variables_with_parameters(template["head"], params),
            impl_file=replace_variables_with_parameters(template["func"], params),
            path2template=path2template,
        )
        return wvf