Source code for elasticai.creator.nn.quantized_grads.fixed_point.module_quantization
import torch
from torch import Tensor
from torch.nn import Module

from .autograd import (
    QuantizeForwHTEAutograd,
    QuantizeForwHTEBackwHTEAutograd,
    QuantizeForwHTEBackwStochasticAutograd,
    QuantizeForwStochasticAutograd,
    QuantizeForwStochasticBackwStochasticAutograd,
)
from .two_complement_fixed_point_config import FixedPointConfigV2


# Typing-only stubs describing the constructor signatures of the module
# classes returned by the two factory functions below.
class _ForwModule(Module):
    def __init__(self, forward_conf: FixedPointConfigV2): ...


class _ForwBackwModule(Module):
    def __init__(
        self, forward_conf: FixedPointConfigV2, backward_conf: FixedPointConfigV2
    ): ...


def get_fxp_forw_quantization_module(
    autograd_func: type[torch.autograd.Function],
) -> type[_ForwModule]:
    """Build a Module class that quantizes its input with `autograd_func`
    on the forward pass only."""

    class FxPForwQuantizationModule(Module):
        def __init__(self, forward_conf: FixedPointConfigV2) -> None:
            super().__init__()
            self.autograd_func = autograd_func
            # Register the config tensors as buffers so they move with the
            # module across devices and are saved in the state dict.
            self.register_buffer(
                "forw_resolution_per_int", forward_conf.resolution_per_int
            )
            self.register_buffer(
                "forw_minimum_as_rational", forward_conf.minimum_as_rational
            )
            self.register_buffer(
                "forw_maximum_as_rational", forward_conf.maximum_as_rational
            )

        def forward(self, x: Tensor) -> Tensor:
            return self.autograd_func.apply(
                x,
                self.forw_resolution_per_int,
                self.forw_minimum_as_rational,
                self.forw_maximum_as_rational,
            )

    return FxPForwQuantizationModule
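

# For illustration only: a minimal sketch (not part of this module) of an
# autograd function with the four-argument call signature the factory above
# expects. The quantization formula and the straight-through backward pass are
# assumptions about what the imported QuantizeForw*Autograd functions do,
# inferred from their names.
class _ExampleForwSTEAutograd(torch.autograd.Function):  # hypothetical
    @staticmethod
    def forward(ctx, x, resolution_per_int, minimum, maximum):
        # Snap to the fixed-point grid (torch.round rounds half to even),
        # then clamp into the representable range.
        quantized = torch.round(x * resolution_per_int) / resolution_per_int
        return torch.clamp(quantized, min=minimum, max=maximum)

    @staticmethod
    def backward(ctx, grad_output):
        # Straight-through estimator: the gradient passes unchanged; the
        # three configuration tensors receive no gradient.
        return grad_output, None, None, None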


def get_fxp_forwbackw_quantization_module(
    autograd_func: type[torch.autograd.Function],
) -> type[_ForwBackwModule]:
    """Build a Module class that quantizes its input with `autograd_func` on
    the forward pass and quantizes the gradient on the backward pass, each
    with its own fixed-point configuration."""

    class FxPForwBackwQuantizationModule(Module):
        def __init__(
            self, forward_conf: FixedPointConfigV2, backward_conf: FixedPointConfigV2
        ) -> None:
            super().__init__()
            self.autograd_func = autograd_func
            # Buffers describing the forward-pass fixed-point grid ...
            self.register_buffer(
                "forw_resolution_per_int", forward_conf.resolution_per_int
            )
            self.register_buffer(
                "forw_minimum_as_rational", forward_conf.minimum_as_rational
            )
            self.register_buffer(
                "forw_maximum_as_rational", forward_conf.maximum_as_rational
            )
            # ... and the backward-pass grid.
            self.register_buffer(
                "backw_resolution_per_int", backward_conf.resolution_per_int
            )
            self.register_buffer(
                "backw_minimum_as_rational", backward_conf.minimum_as_rational
            )
            self.register_buffer(
                "backw_maximum_as_rational", backward_conf.maximum_as_rational
            )
            # Convenience mapping over all registered buffers; not used by
            # forward() itself.
            self._kwargs = {
                "forw_resolution_per_int": self.forw_resolution_per_int,
                "forw_minimum_as_rational": self.forw_minimum_as_rational,
                "forw_maximum_as_rational": self.forw_maximum_as_rational,
                "backw_resolution_per_int": self.backw_resolution_per_int,
                "backw_minimum_as_rational": self.backw_minimum_as_rational,
                "backw_maximum_as_rational": self.backw_maximum_as_rational,
            }

        def forward(self, x: Tensor) -> Tensor:
            return self.autograd_func.apply(
                x,
                self.forw_resolution_per_int,
                self.forw_minimum_as_rational,
                self.forw_maximum_as_rational,
                self.backw_resolution_per_int,
                self.backw_minimum_as_rational,
                self.backw_maximum_as_rational,
            )

    return FxPForwBackwQuantizationModule
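

# Again for illustration: a sketch (an assumption, not the library's code) of
# a forward/backward autograd function matching the seven-argument signature
# used above. The backward pass quantizes the incoming gradient onto its own
# fixed-point grid before passing it on to the input.
class _ExampleForwBackwAutograd(torch.autograd.Function):  # hypothetical
    @staticmethod
    def forward(
        ctx, x, forw_res, forw_min, forw_max, backw_res, backw_min, backw_max
    ):
        ctx.save_for_backward(backw_res, backw_min, backw_max)
        quantized = torch.round(x * forw_res) / forw_res
        return torch.clamp(quantized, min=forw_min, max=forw_max)

    @staticmethod
    def backward(ctx, grad_output):
        backw_res, backw_min, backw_max = ctx.saved_tensors
        grad = torch.clamp(
            torch.round(grad_output * backw_res) / backw_res,
            min=backw_min,
            max=backw_max,
        )
        # Only x receives a gradient; the six configuration tensors do not.
        return grad, None, None, None, None, None, None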


# Ready-to-use module classes. "HTE" appears to abbreviate round-half-to-even
# rounding, "Stochastic" stochastic rounding; "Forw"/"Backw" name the pass in
# which the corresponding rounding is applied.
QuantizeForwHTE: type[_ForwModule] = get_fxp_forw_quantization_module(
    QuantizeForwHTEAutograd
)
QuantizeForwStochastic: type[_ForwModule] = get_fxp_forw_quantization_module(
    QuantizeForwStochasticAutograd
)
QuantizeForwHTEBackwHTE: type[_ForwBackwModule] = get_fxp_forwbackw_quantization_module(
    QuantizeForwHTEBackwHTEAutograd
)
QuantizeForwHTEBackwStochastic: type[_ForwBackwModule] = (
    get_fxp_forwbackw_quantization_module(QuantizeForwHTEBackwStochasticAutograd)
)
QuantizeForwStochasticBackwStochastic: type[_ForwBackwModule] = (
    get_fxp_forwbackw_quantization_module(QuantizeForwStochasticBackwStochasticAutograd)
)
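

# Usage sketch for the exported classes. The FixedPointConfigV2 constructor
# arguments shown here (total_bits, frac_bits) are an assumption about its
# interface; consult two_complement_fixed_point_config for the actual
# signature.
#
#     forw_conf = FixedPointConfigV2(total_bits=8, frac_bits=4)
#     backw_conf = FixedPointConfigV2(total_bits=16, frac_bits=8)
#     quantize = QuantizeForwHTEBackwHTE(forw_conf, backw_conf)
#     y = quantize(torch.randn(4, requires_grad=True))  # forward quantized
#     y.sum().backward()                                # gradient quantized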