import warnings
from functools import partial
from elasticai.creator.file_generation.savable import Path
from elasticai.creator.file_generation.template import (
InProjectTemplate,
module_to_package,
)
from elasticai.creator.vhdl.code_generation.code_abstractions import (
to_vhdl_binary_string,
)
from elasticai.creator.vhdl.design.ports import Port
class Skeleton:
    """VHDL skeleton that wraps a network design for the middleware.

    The skeleton is rendered from an in-project VHDL template chosen by
    ``skeleton_version`` and written to a single ``.vhd`` file by
    :meth:`save_to`.

    Args:
        x_num_values: Number of input values. At most 100 for ``"v1"`` and
            at most 19983 for ``"v2"``.
        y_num_values: Number of output values. Same limits as
            ``x_num_values``.
        network_name: Name of the network, forwarded to the template.
        port: Port of the network design. The ``"x"`` and ``"y"`` signals
            must be at most 8 bits wide. The ``"x_address"`` and
            ``"y_address"`` widths are also forwarded to the template.
        id: Skeleton id as a sequence of byte values. A single ``int`` is
            wrapped into a one-element list. ``"v1"`` expects exactly 1
            byte and ``"v2"`` expects exactly 16 bytes.
        skeleton_version: ``"v1"`` (emits a ``FutureWarning``) or ``"v2"``.

    Raises:
        Exception: If the version is unknown, the id has the wrong length,
            the number of inputs/outputs exceeds the version's limit, or the
            ``x``/``y`` port width exceeds 8 bits.
    """

    def __init__(
        self,
        x_num_values: int,
        y_num_values: int,
        network_name: str,
        port: Port,
        id: list[int] | int,
        skeleton_version: str = "v1",
    ):
        self.name = "skeleton"
        self._network_name = network_name
        self._port = port
        # Kept as strings because save_to substitutes them directly into the
        # template parameters.
        self._x_num_values = str(x_num_values)
        self._y_num_values = str(y_num_values)
        # `id` shadows the builtin but is part of the public keyword API.
        # Copy it so later mutation of the caller's list cannot change the
        # id that gets rendered.
        self._id = [id] if isinstance(id, int) else list(id)
        if skeleton_version == "v1":
            warnings.warn(
                (
                    "Skeleton V1 might be deprecated in the future. Consider using"
                    " Skeleton V2 instead."
                ),
                FutureWarning,
                # Attribute the warning to the code constructing the skeleton.
                stacklevel=2,
            )
            self._template_file_name = "network_skeleton.tpl.vhd"
            if len(self._id) != 1:
                raise Exception(
                    f"should give an id of 1 byte. Actual length is {len(self._id)}"
                )
            self._check_num_values(x_num_values, y_num_values, limit=100)
        elif skeleton_version == "v2":
            self._template_file_name = "network_skeleton_v2.tpl.vhd"
            if len(self._id) != 16:
                raise Exception(
                    f"should give an id of 16 byte. Actual length is {len(self._id)}"
                )
            self._check_num_values(x_num_values, y_num_values, limit=19983)
        else:
            raise Exception(f"Skeleton version {skeleton_version} does not exist")
        if port["x"].width > 8:
            raise Exception(
                "port x width should not be bigger than 8. You assigned "
                f" {port['x'].width=}"
            )
        if port["y"].width > 8:
            raise Exception(
                "port y width should not be bigger than 8. You assigned "
                f" {port['y'].width=}"
            )

    @staticmethod
    def _check_num_values(x_num_values: int, y_num_values: int, limit: int) -> None:
        """Raise if the number of inputs or outputs exceeds ``limit``."""
        if x_num_values > limit:
            raise Exception(
                f"Not more than {limit} input values allowed. Actual num of inputs"
                f" {x_num_values} ."
            )
        if y_num_values > limit:
            raise Exception(
                f"Not more than {limit} output values allowed. Actual num of"
                f" outputs {y_num_values} ."
            )

    def save_to(self, destination: Path):
        """Render the selected skeleton template into ``destination``.

        The id bytes are rendered as comma-separated 8-bit VHDL binary
        strings. The result is written to ``destination.as_file(".vhd")``.
        """
        template = InProjectTemplate(
            package=module_to_package(self.__module__),
            file_name=self._template_file_name,
            parameters=dict(
                name=self.name,
                network_name=self._network_name,
                data_width_in=str(self._port["x"].width),
                x_addr_width=str(self._port["x_address"].width),
                x_num_values=self._x_num_values,
                y_num_values=self._y_num_values,
                data_width_out=str(self._port["y"].width),
                y_addr_width=str(self._port["y_address"].width),
                id=", ".join(
                    map(partial(to_vhdl_binary_string, number_of_bits=8), self._id)
                ),
            ),
        )
        file = destination.as_file(".vhd")
        file.write(template)
class LSTMSkeleton:
    """Skeleton for LSTM networks, rendered from a fixed template.

    Only the skeleton name (``"skeleton"``) and the network name are passed
    to the template ``lstm_network_skeleton.tpl.vhd``.

    Args:
        network_name: Name of the network, forwarded to the template.
    """

    def __init__(self, network_name: str):
        self.name = "skeleton"
        self._network_name = network_name

    def save_to(self, destination: Path):
        """Render the LSTM skeleton and write it to ``destination.as_file(".vhd")``."""
        template_parameters = {
            "name": self.name,
            "network_name": self._network_name,
        }
        rendered = InProjectTemplate(
            package=module_to_package(self.__module__),
            file_name="lstm_network_skeleton.tpl.vhd",
            parameters=template_parameters,
        )
        destination.as_file(".vhd").write(rendered)
class EchoSkeletonV2:
    """V2 skeleton rendered from ``network_skeleton_v2_echo.tpl.vhd``.

    The 16-byte id is fixed to the ASCII bytes of ``"240823ECHOSERVER"``.

    Args:
        num_values: Number of values, forwarded to the template as
            ``num_values``.
        bitwidth: Data width in bits, forwarded as ``data_width``. Must not
            exceed 8.

    Raises:
        Exception: If ``bitwidth`` is larger than 8.
    """

    def __init__(self, num_values: int, bitwidth: int):
        self._num_values = num_values
        self._bitwidth = bitwidth
        self._name = "skeleton"
        if bitwidth > 8:
            raise Exception(
                "Not more than 8 bit supported by middleware. You assigned"
                f" {bitwidth} bits"
            )
        self._template_file_name = "network_skeleton_v2_echo.tpl.vhd"
        # Same 16 values as the explicit list
        # [50, 52, 48, 56, 50, 51, 69, 67, 72, 79, 83, 69, 82, 86, 69, 82].
        self._id = list(b"240823ECHOSERVER")

    def save_to(self, destination: Path):
        """Render the echo skeleton and write it to ``destination.as_file(".vhd")``.

        The id bytes are rendered as comma-separated 8-bit VHDL binary strings.
        """
        byte_to_vhdl = partial(to_vhdl_binary_string, number_of_bits=8)
        id_literal = ", ".join(byte_to_vhdl(byte) for byte in self._id)
        rendered = InProjectTemplate(
            package=module_to_package(self.__module__),
            file_name=self._template_file_name,
            parameters={
                "name": self._name,
                "data_width": str(self._bitwidth),
                "num_values": str(self._num_values),
                "id": id_literal,
            },
        )
        destination.as_file(".vhd").write(rendered)