# Source code for elasticai.creator.nn.fixed_point.precomputed.precomputed_module
from typing import cast
import torch
from elasticai.creator.nn.design_creator_module import DesignCreatorModule
from elasticai.creator.nn.fixed_point.math_operations import MathOperations
from elasticai.creator.nn.fixed_point.two_complement_fixed_point_config import (
FixedPointConfig,
)
from elasticai.creator.vhdl.shared_designs.precomputed_scalar_function import (
PrecomputedScalarFunction,
)
from .identity_step_function import IdentityStepFunction
[docs]
class PrecomputedModule(DesignCreatorModule):
    """Fixed-point wrapper that evaluates ``base_module`` on a sampled input grid.

    In ``forward``, inputs are passed through ``IdentityStepFunction`` together
    with a fixed lookup table of ``num_steps`` sample points and quantized; the
    wrapped module's outputs are quantized as well. ``create_design`` turns the
    same sample points into a ``PrecomputedScalarFunction`` design whose values
    are obtained by running this module on each (fixed-point) sample point.
    """

    def __init__(
        self,
        base_module: torch.nn.Module,
        total_bits: int,
        frac_bits: int,
        num_steps: int,
        sampling_intervall: tuple[float, float],
    ) -> None:
        """Store the wrapped module and build the fixed-point config and step LUT.

        Args:
            base_module: Module applied to the stepped, quantized inputs.
            total_bits: Total width of the fixed-point format; also used as the
                input and output width of the generated design.
            frac_bits: Number of fractional bits of the fixed-point format.
            num_steps: Number of sample points in the step lookup table.
            sampling_intervall: ``(start, end)`` bounds handed to
                ``torch.linspace``; both endpoints are included.
        """
        super().__init__()
        self._base_module = base_module
        self._config = FixedPointConfig(total_bits=total_bits, frac_bits=frac_bits)
        self._operations = MathOperations(self._config)
        # Kept as a non-trainable Parameter: it moves together with the module
        # on ``.to(...)``/``.cuda()`` and is part of the state dict, but never
        # receives gradients.
        self._step_lut = torch.nn.Parameter(
            torch.linspace(*sampling_intervall, num_steps),
            requires_grad=False,
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Step and quantize ``x``, apply the wrapped module, quantize the result.

        Args:
            x: Input tensor.

        Returns:
            The quantized output of ``base_module`` for the stepped inputs.
        """
        stepped = self._stepped_inputs(x)
        outputs = self._base_module(stepped)
        return self._operations.quantize(outputs)

    def create_design(self, name: str) -> PrecomputedScalarFunction:
        """Build a precomputed scalar-function design from the step LUT.

        Each LUT sample point is converted to its fixed-point integer
        representation; the design is given ``_quantized_inference`` as the
        function used to compute the output for those inputs.

        Args:
            name: Name of the generated design.

        Returns:
            A ``PrecomputedScalarFunction`` whose input and output widths equal
            ``total_bits``.
        """
        quantized_inputs = [
            self._config.as_integer(value) for value in self._step_lut.tolist()
        ]
        return PrecomputedScalarFunction(
            name=name,
            input_width=self._config.total_bits,
            output_width=self._config.total_bits,
            inputs=quantized_inputs,
            function=self._quantized_inference,
        )

    def _stepped_inputs(self, x: torch.Tensor) -> torch.Tensor:
        """Map ``x`` through ``IdentityStepFunction`` with the LUT, then quantize.

        NOTE(review): presumably ``IdentityStepFunction`` snaps values onto the
        LUT grid in the forward pass (and passes gradients through) — confirm
        against its definition.
        """
        step_inputs = cast(torch.Tensor, IdentityStepFunction.apply(x, self._step_lut))
        return self._operations.quantize(step_inputs)

    def _quantized_inference(self, x: int) -> int:
        """Evaluate the module on one fixed-point integer and return an integer.

        Args:
            x: Input in the integer representation of ``self._config``.

        Returns:
            The module's output converted back to the integer representation.
        """
        fxp_input = self._config.as_rational(x)
        # Create the input on the device the module already lives on. The
        # previous ``self.cpu()`` call moved every parameter to the CPU in place,
        # silently relocating a GPU-resident model whenever a design was built.
        device = self._step_lut.device
        with torch.no_grad():
            output = self(torch.tensor(fxp_input, device=device))
        return self._config.as_integer(float(output.item()))