# Source code for elasticai.creator.nn.fixed_point.linear.layer.linear
from typing import Any, cast
from elasticai.creator.arithmetic import (
FxpArithmetic,
FxpParams,
)
from elasticai.creator.base_modules.linear import Linear as LinearBase
from elasticai.creator.nn.design_creator_module import DesignCreatorModule
from elasticai.creator.nn.fixed_point.linear.design import LinearDesign
from elasticai.creator.nn.fixed_point.math_operations import MathOperations
class Linear(DesignCreatorModule, LinearBase):
    """Fixed-point linear (fully connected) layer that can emit a hardware design.

    Combines the project's base ``Linear`` behaviour with signed fixed-point
    arithmetic configured by ``total_bits``/``frac_bits``; ``create_design``
    exports the quantized parameters as a ``LinearDesign``.
    """

    def __init__(
        self,
        in_features: int,
        out_features: int,
        total_bits: int,
        frac_bits: int,
        bias: bool = True,
        device: Any = None,
    ) -> None:
        """Create the layer with a signed fixed-point number format.

        Args:
            in_features: number of input features.
            out_features: number of output features.
            total_bits: total width of the fixed-point representation.
            frac_bits: number of fractional bits within ``total_bits``.
            bias: whether the layer has a learnable bias.
            device: device passed through to the base layer.
        """
        # Fixed-point format shared by all math operations of this layer.
        self._params = FxpParams(
            total_bits=total_bits, frac_bits=frac_bits, signed=True
        )
        self._config = FxpArithmetic(self._params)
        super().__init__(
            in_features=in_features,
            out_features=out_features,
            operations=MathOperations(config=self._config),
            bias=bias,
            device=device,
        )

    def get_params(self) -> tuple[list[list[float]], list[float]]:
        """Return ``(weights, bias)`` as plain Python lists of floats.

        When the layer has no bias, a zero bias of length ``out_features`` is
        substituted so callers always receive both values.
        """
        # Use float zeros so the declared return type (list[float]) holds
        # even for bias-free layers.
        bias = [0.0] * self.out_features if self.bias is None else self.bias.tolist()
        weights = self.weight.tolist()
        return weights, bias

    def get_params_quant(self) -> tuple[list[list[int]], list[int]]:
        """Return ``(weights, bias)`` quantized to fixed-point integer values."""
        weights, bias = self.get_params()
        # cut_as_integer maps floats to their fixed-point integer encoding;
        # cast() only narrows the static type for the checker.
        q_weights = cast(list[list[int]], self._config.cut_as_integer(weights))
        q_bias = cast(list[int], self._config.cut_as_integer(bias))
        return q_weights, q_bias

    def create_design(self, name: str) -> LinearDesign:
        """Build a ``LinearDesign`` from the quantized parameters.

        Args:
            name: name given to the generated design.
        """
        q_weights, q_bias = self.get_params_quant()
        return LinearDesign(
            frac_bits=self._config.frac_bits,
            total_bits=self._config.total_bits,
            in_feature_num=self.in_features,
            out_feature_num=self.out_features,
            weights=q_weights,
            bias=q_bias,
            name=name,
        )