Source code for elasticai.creator.nn.fixed_point.lstm.layer

from collections import OrderedDict
from typing import cast

import torch

from elasticai.creator.base_modules.lstm import LSTM
from elasticai.creator.base_modules.lstm_cell import LSTMCell
from elasticai.creator.nn.design_creator_module import DesignCreatorModule
from elasticai.creator.nn.fixed_point.hard_sigmoid import HardSigmoid
from elasticai.creator.nn.fixed_point.hard_tanh import HardTanh
from elasticai.creator.nn.fixed_point.lstm.design.fp_lstm_cell import FPLSTMCell
from elasticai.creator.nn.fixed_point.lstm.design.lstm import LSTMNetworkDesign
from elasticai.creator.nn.fixed_point.two_complement_fixed_point_config import (
    FixedPointConfig,
)
from elasticai.creator.vhdl.design.design import Design

from ..math_operations import MathOperations
from .design.testbench import LSTMTestBench


class LSTMNetwork(DesignCreatorModule):
    """Model consisting of one LSTM layer followed by at most one further layer.

    The first entry of ``layers`` is stored as ``self.lstm``; the remaining
    entries are wrapped in a ``torch.nn.Sequential`` stored as ``self.layers``
    and registered under the names ``fp_linear_<index>``.
    """

    def __init__(self, layers: list[torch.nn.Module]):
        """Store the LSTM layer and wrap the trailing layers in a Sequential.

        Args:
            layers: ``layers[0]`` is used as the LSTM; ``layers[1:]`` are the
                layers applied to the LSTM output.

        Raises:
            NotImplementedError: If more than one layer follows the LSTM.
        """
        super().__init__()
        self.lstm = layers[0]
        trailing_layers = layers[1:]
        self.layer_names = [f"fp_linear_{i}" for i, _ in enumerate(trailing_layers)]
        if len(self.layer_names) > 1:
            raise NotImplementedError(
                "LSTMNetwork supports at most one layer after the LSTM, "
                f"got {len(self.layer_names)}"
            )
        self.layers = torch.nn.Sequential(
            OrderedDict(zip(self.layer_names, trailing_layers))
        )

    def create_design(self, name: str) -> LSTMNetworkDesign:
        """Build the hardware design for the LSTM plus its trailing layer.

        Bit widths and sizes are read from ``self.lstm``, which is treated as a
        ``FixedPointLSTMWithHardActivations`` via ``typing.cast`` (no runtime
        check is performed).

        Args:
            name: Accepted for interface compatibility; this method does not
                use it. The LSTM design is created with its own default name
                and the trailing layer with its ``fp_linear_<index>`` name.

        Returns:
            The ``LSTMNetworkDesign`` combining both sub-designs.

        NOTE(review): indexing ``self.layers[0]`` fails when the network was
        constructed without any layer after the LSTM.
        """
        first_lstm = cast(FixedPointLSTMWithHardActivations, self.lstm)
        fp_config = first_lstm.fixed_point_config
        return LSTMNetworkDesign(
            lstm=first_lstm.create_design(),
            linear_layer=self.layers[0].create_design(self.layer_names[0]),
            total_bits=fp_config.total_bits,
            frac_bits=fp_config.frac_bits,
            hidden_size=first_lstm.hidden_size,
            input_size=first_lstm.input_size,
        )

    def create_testbench(self, test_bench_name, uut: Design) -> LSTMTestBench:
        """Return an ``LSTMTestBench`` named ``test_bench_name`` for the design ``uut``."""
        return LSTMTestBench(test_bench_name, uut)

    def forward(self, x):
        """Run the LSTM, select index -1 along axis 1, and apply the trailing layers.

        The second element returned by the LSTM (a pair) is discarded.
        """
        outputs, (_, _) = self.lstm(x)
        # presumably the last time step of each sequence, given that
        # FixedPointLSTMWithHardActivations uses batch_first=True -- TODO confirm
        last_output = outputs[:, -1]
        return self.layers(last_output)
class FixedPointLSTMWithHardActivations(DesignCreatorModule, LSTM):
    """
    Use only together with the above `LSTMNetwork`.
    There is no single hw design corresponding to this sw layer.
    Instead, the design of the `LSTMNetwork` handles most of the tasks
    that are performed by `FixedPointLSTMWithHardActivations`.
    """

    def __init__(
        self,
        total_bits: int,
        frac_bits: int,
        input_size: int,
        hidden_size: int,
        bias: bool,
    ) -> None:
        """Configure the base ``LSTM`` with fixed-point cells and hard activations.

        Args:
            total_bits: Total bit width passed to the fixed-point config and
                to the activation constructors.
            frac_bits: Fractional bit width passed to the same places.
            input_size: Forwarded to the base ``LSTM`` and the ``LSTMCell``.
            hidden_size: Forwarded to the base ``LSTM`` and the ``LSTMCell``.
            bias: Forwarded to the base ``LSTM`` and the ``LSTMCell``.
        """
        config = FixedPointConfig(total_bits=total_bits, frac_bits=frac_bits)

        class LayerFactory:
            """Factory handed to the base ``LSTM`` to build its cell."""

            def lstm(self, input_size: int, hidden_size: int, bias: bool) -> LSTMCell:
                def activation(constructor):
                    # Defer construction: the cell receives zero-argument
                    # factories with the bit widths already bound.
                    def wrapped_constructor():
                        return constructor(total_bits=total_bits, frac_bits=frac_bits)

                    return wrapped_constructor

                return LSTMCell(
                    operations=MathOperations(config=config),
                    sigmoid_factory=activation(HardSigmoid),
                    tanh_factory=activation(HardTanh),
                    input_size=input_size,
                    hidden_size=hidden_size,
                    bias=bias,
                )

        super().__init__(
            input_size=input_size,
            hidden_size=hidden_size,
            bias=bias,
            batch_first=True,
            layers=LayerFactory(),
        )
        self._config = config

    @property
    def fixed_point_config(self) -> FixedPointConfig:
        """The fixed-point configuration created in ``__init__``."""
        return self._config

    def create_design(self, name: str = "lstm_cell") -> Design:
        """Build the ``FPLSTMCell`` design from the cell's current parameters.

        Weights and biases of ``self.cell.linear_ih`` / ``self.cell.linear_hh``
        are converted to integers with the fixed-point config's ``as_integer``.

        Args:
            name: Name of the cell design; also used as prefix for the
                ``_hardtanh`` and ``_hardsigmoid`` sub-designs.

        Returns:
            The ``FPLSTMCell`` design.
        """

        def float_to_signed_int(value: float | list) -> int | list:
            # Recurse through nested lists and convert every leaf value.
            if isinstance(value, list):
                return [float_to_signed_int(item) for item in value]
            return self._config.as_integer(value)

        def quantized_weights(linear) -> list[list[list[int]]]:
            return cast(
                list[list[list[int]]], float_to_signed_int(linear.weight.tolist())
            )

        def quantized_bias(linear) -> list[list[int]]:
            if linear.bias is None:
                # A layer without bias previously crashed on ``None.tolist()``.
                # Use an all-zero bias instead, which leaves the computed
                # result unchanged.
                # assumes the weight's first dimension equals the bias length,
                # as in torch.nn.Linear -- TODO confirm for the project's Linear
                bias_values = [0.0] * linear.weight.shape[0]
            else:
                bias_values = linear.bias.tolist()
            return cast(list[list[int]], float_to_signed_int(bias_values))

        cell = self.cell
        return FPLSTMCell(
            name=name,
            hardtanh=cell.tanh.create_design(f"{name}_hardtanh"),
            hardsigmoid=cell.sigmoid.create_design(f"{name}_hardsigmoid"),
            total_bits=self._config.total_bits,
            frac_bits=self._config.frac_bits,
            w_ih=quantized_weights(cell.linear_ih),
            w_hh=quantized_weights(cell.linear_hh),
            b_ih=quantized_bias(cell.linear_ih),
            b_hh=quantized_bias(cell.linear_hh),
        )