Source code for elasticai.creator.nn.fixed_point.precomputed.precomputed_module
from typing import cast

import torch

from elasticai.creator.nn.design_creator_module import DesignCreatorModule
from elasticai.creator.nn.fixed_point.math_operations import MathOperations
from elasticai.creator.nn.fixed_point.two_complement_fixed_point_config import (
    FixedPointConfig,
)
from elasticai.creator.vhdl.shared_designs.precomputed_scalar_function import (
    PrecomputedScalarFunction,
)

from .identity_step_function import IdentityStepFunction
class PrecomputedModule(DesignCreatorModule):
    _xoffset: float
    _lut_input: torch.nn.Buffer

    def __init__(
        self,
        base_module: torch.nn.Module,
        total_bits: int,
        frac_bits: int,
        num_steps: int,
        sampling_intervall: tuple[float, float],
    ) -> None:
        super().__init__()
        self._base_module = base_module
        self._config = FixedPointConfig(total_bits=total_bits, frac_bits=frac_bits)
        self._operations = MathOperations(self._config)
        self._build_lut(in_range=sampling_intervall, lut_size=num_steps)

    def _build_lut(self, in_range: tuple[float, float], lut_size: int) -> None:
        # Clamp infinite sampling bounds to the representable fixed-point range.
        range_neg = (
            self._config.minimum_as_rational
            if abs(in_range[0]) == float("inf")
            else in_range[0]
        )
        range_pos = (
            self._config.maximum_as_rational
            if abs(in_range[1]) == float("inf")
            else in_range[1]
        )
        # The LUT cannot hold more entries than the input width can address.
        lut_num_steps = min(lut_size, 2**self._config.total_bits)
        self._lut_input = torch.nn.Buffer(
            self._operations.round(
                torch.linspace(start=range_neg, end=range_pos, steps=lut_num_steps)
            ),
            persistent=False,
        )
        # Offset inputs by roughly half a LUT step: the midrange of the
        # per-gap half-widths, converted back to a rational value.
        lut_diff = (
            torch.abs(torch.diff(self._lut_input))
            / self._config.minimum_step_as_rational
            / 2
        )
        self._xoffset = (
            float((lut_diff.max() + lut_diff.min()) / 2)
            * self._config.minimum_step_as_rational
        )
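
    # Illustrative sketch (hypothetical numbers, not taken from the library):
    # with total_bits=4, frac_bits=2, num_steps=4 and
    # sampling_intervall=(-1.0, 1.0), linspace yields roughly
    # [-1.0, -0.333, 0.333, 1.0]; assuming MathOperations.round snaps values
    # to the 0.25-wide fixed-point grid, the LUT becomes
    # [-1.0, -0.25, 0.25, 1.0], and _xoffset centres each step on the
    # interval it covers.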

    def get_lut_integer(self) -> tuple[list[int], list[int]]:
        # LUT input values as integer codes, paired with the corresponding
        # module outputs.
        return (
            list(map(self._config.cut_as_integer, self._lut_input.tolist())),
            [self._forward_nograd(val) for val in self._lut_input.tolist()],
        )

    def _stepped_inputs(self, x: torch.Tensor) -> torch.Tensor:
        # Quantize inputs to the LUT grid via IdentityStepFunction, which
        # passes gradients through unchanged during training.
        return cast(torch.Tensor, IdentityStepFunction.apply(x, self._lut_input))

    def _forward_nograd(self, x: int) -> int:
        # Convert the integer code to its rational value, run the forward
        # pass without autograd, and convert the result back to a code.
        fxp_input = self._config.as_rational(x)
        if not isinstance(fxp_input, float):
            raise ValueError(f"expected a single float, got {type(fxp_input)}")
        with torch.no_grad():
            output = self.forward(torch.tensor(fxp_input))
        return self._config.round_to_integer(float(output.item()))
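
    # For example (hypothetical values, assuming the usual two's-complement
    # scaling of 2**-frac_bits): with frac_bits=4, the integer code 8 maps
    # to the rational 0.5, is passed through forward(), and the result is
    # rounded back to an integer code for the lookup table.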

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Snap to the LUT grid, centre the step, apply the wrapped module,
        # and quantize the result to the fixed-point grid.
        x = self._stepped_inputs(x)
        y = self._base_module(x - self._xoffset)
        return self._operations.round(y)

    def create_design(self, name: str) -> PrecomputedScalarFunction:
        # Emit a VHDL design that evaluates the function via the precomputed
        # lookup table; input and output share the configured bit width.
        q_input = self.get_lut_integer()[0]
        return PrecomputedScalarFunction(
            name=name,
            input_width=self._config.total_bits,
            output_width=self._config.total_bits,
            inputs=q_input,
            function=self._forward_nograd,
        )
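
# Example usage (a minimal sketch; the numbers and the wrapped module are
# illustrative assumptions, not values prescribed by the library):
#
#     import torch
#     from elasticai.creator.nn.fixed_point.precomputed.precomputed_module import (
#         PrecomputedModule,
#     )
#
#     sigmoid = PrecomputedModule(
#         base_module=torch.nn.Sigmoid(),
#         total_bits=8,
#         frac_bits=4,
#         num_steps=32,
#         sampling_intervall=(-4.0, 4.0),
#     )
#     y = sigmoid(torch.tensor([0.5, -1.25]))    # quantized sigmoid outputs
#     design = sigmoid.create_design("sigmoid")  # precomputed VHDL design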