Source code for elasticai.creator.nn.fixed_point.lstm.design.lstm
from functools import partial
from elasticai.creator.nn.fixed_point.linear.layer import Linear as FPLinear1d
from ._common_imports import (
Design,
InProjectTemplate,
Path,
Port,
Signal,
calculate_address_width,
module_to_package,
std_signals,
)
[docs]
class LSTMNetworkDesign(Design):
    """Design named ``lstm_network`` that wraps an LSTM design and a linear layer.

    The constructor fills the ``lstm_network.tpl.vhd`` template with the names
    of the wrapped designs, the fixed-point widths, the layer sizes and the
    derived address widths, and builds the port exposed by this design.
    ``save_to`` writes the wrapped designs and the rendered template.
    """

    def __init__(
        self,
        lstm: Design,
        linear_layer: FPLinear1d,
        total_bits: int,
        frac_bits: int,
        hidden_size: int,
        input_size: int,
    ) -> None:
        """Collect the template parameters and build the port.

        Args:
            lstm: Design of the LSTM cell. Its ``name`` is passed to the
                template and its port supplies the widths of ``x``,
                ``addr_in`` and ``d_out``.
            linear_layer: Linear layer placed into the template. Only
                ``name``, ``in_feature_num``, ``out_feature_num`` and
                ``save_to`` are used.
                NOTE(review): annotated as ``FPLinear1d`` but used like a
                design object -- confirm the expected type.
            total_bits: Passed to the template as ``data_width``.
            frac_bits: Passed to the template as ``frac_width``.
            hidden_size: Passed to the template as ``hidden_size``; also
                used to derive the address widths.
            input_size: Passed to the template as ``input_size``; also used
                to derive the address widths.
        """
        super().__init__(name="lstm_network")
        self._linear_layer = linear_layer
        self._lstm = lstm

        # Address widths are handed to the template as strings.
        concat_size = hidden_size + input_size
        concat_addr_width = f"{calculate_address_width(concat_size)}"
        weight_addr_width = f"{calculate_address_width(concat_size * hidden_size)}"

        template_parameters = {
            "name": self.name,
            "lstm_cell_name": self._lstm.name,
            "linear_name": self._linear_layer.name,
            "data_width": str(total_bits),
            "frac_width": str(frac_bits),
            "hidden_size": str(hidden_size),
            "input_size": str(input_size),
            "linear_in_features": str(self._linear_layer.in_feature_num),
            "linear_out_features": str(self._linear_layer.out_feature_num),
            # NOTE(review): hidden_addr_width uses hidden_size + input_size,
            # the same value as x_h_addr_width -- confirm this is intended.
            "hidden_addr_width": concat_addr_width,
            "x_h_addr_width": concat_addr_width,
            "w_addr_width": weight_addr_width,
            # NOTE(review): fixed value, independent of input_size -- confirm.
            "in_addr_width": "4",
        }
        self.template = InProjectTemplate(
            module_to_package(self.__module__),
            file_name="lstm_network.tpl.vhd",
            parameters=template_parameters,
        )

        # Data and address signals take their widths from the LSTM port.
        lstm_port = lstm.port
        inputs = [
            std_signals.clock(),
            std_signals.enable(),
            Signal("x_we", width=0),
            Signal("x", width=lstm_port["x_data"].width),
            Signal("addr_in", width=lstm_port["h_out_addr"].width),
        ]
        outputs = [
            std_signals.done(),
            Signal("d_out", width=lstm_port["h_out_data"].width),
        ]
        self._port = Port(incoming=inputs, outgoing=outputs)
        self._subpath_name = "lstm_network"

    @property
    def port(self) -> Port:
        """Return the port built in the constructor."""
        return self._port

    def save_to(self, destination: Path) -> None:
        """Write the LSTM, the linear layer and the network template.

        The LSTM is saved directly to ``destination``, the linear layer to a
        subpath named after the layer, and the template to the
        ``lstm_network`` subpath as a ``.vhd`` file.

        Args:
            destination: Path object under which all files are created.
        """
        linear = self._linear_layer
        self._lstm.save_to(destination)

        linear_destination = destination.create_subpath(linear.name)
        linear.save_to(linear_destination)

        network_file = destination.create_subpath(self._subpath_name).as_file(".vhd")
        network_file.write(self.template)