Source code for elasticai.creator.nn.quantized_grads.fixed_point.param_quantization

from typing import Callable

from torch import Tensor
from torch.nn import Module

from .quantize_to_fixed_point import quantize_to_fxp_hte, quantize_to_fxp_stochastic
from .two_complement_fixed_point_config import FixedPointConfigV2


[docs] def get_quantize_to_fixed_point( func: Callable[[Tensor, Tensor, Tensor, Tensor], Tensor], ) -> tuple[type[Module], type[Module]]: class QuantizeToFixedPoint(Module): def __init__(self, config: FixedPointConfigV2): super().__init__() self.register_buffer("minimum_as_rational", config.minimum_as_rational) self.register_buffer("maximum_as_rational", config.maximum_as_rational) self.register_buffer("resolution_per_int", config.resolution_per_int) @staticmethod def forward(x: Tensor) -> Tensor: return x def right_inverse(self, x: Tensor) -> Tensor: return func( x, self.resolution_per_int, self.minimum_as_rational, self.maximum_as_rational, ) class QuantizeToFixedPointSTE(Module): def __init__(self, config: FixedPointConfigV2): super().__init__() self.register_buffer("minimum_as_rational", config.minimum_as_rational) self.register_buffer("maximum_as_rational", config.maximum_as_rational) self.register_buffer("resolution_per_int", config.resolution_per_int) def forward(self, x: Tensor) -> Tensor: return func( x, self.resolution_per_int, self.minimum_as_rational, self.maximum_as_rational, ) def right_inverse(self, x: Tensor) -> Tensor: return x return QuantizeToFixedPoint, QuantizeToFixedPointSTE
(QuantizeParamToFixedPointHTE, QuantizeParamSTEToFixedPointHTE) = ( get_quantize_to_fixed_point(quantize_to_fxp_hte) ) (QuantizeParamToFixedPointStochastic, QuantizeParamSTEToFixedPointStochastic) = ( get_quantize_to_fixed_point(quantize_to_fxp_stochastic) )
[docs] class QuantizeTensorToFixedPointHTE(QuantizeParamSTEToFixedPointHTE): """ This Modules can be used for Tensor quantization """
[docs] class QuantizeTensorToFixedPointStochastic(QuantizeParamSTEToFixedPointStochastic): """ This Modules can be used for Tensor quantization """