Source code for elasticai.creator_plugins.quantized_grads.base_modules.batchnorm2d
from typing import Any
from torch import Tensor
from torch.nn import BatchNorm2d as TorchBatchNorm2d
from torch.nn import Module
from torch.nn.utils import parametrize as P
class BatchNorm2d(TorchBatchNorm2d):
"""A BatchNorm2d.
The output of the batchnorm is fake quantized. The weights and bias are fake quantized during initialization.
Make sure that math_ops is a module where all needed tensors are part of it,
so they can be moved to the same device.
Make sure that weight_quantization and bias_quantization are modules that implement the forward function.
If you want to quantize during initialization or only apply quantized updates make sure to use a quantized optimizer
and implement the right_inverse method for your module.
"""
    def __init__(
        self,
        math_ops: Module,
        weight_quantization: Module,
        bias_quantization: Module,
        num_features: int,
        eps: float = 1e-5,
        momentum: float = 0.1,
        affine: bool = True,
        track_running_stats: bool = True,
        device: Any = None,
        dtype: Any = None,
    ) -> None:
        super().__init__(
            num_features=num_features,
            eps=eps,
            momentum=momentum,
            affine=affine,
            track_running_stats=track_running_stats,
            device=device,
            dtype=dtype,
        )
        # The parametrizations make every access to .weight and .bias return
        # the fake-quantized view of the underlying parameters.
        P.register_parametrization(self, "weight", weight_quantization)
        P.register_parametrization(self, "bias", bias_quantization)
        self.add_module("math_ops", math_ops)

    def forward(self, x: Tensor) -> Tensor:
        # Apply the batch norm, then fake quantize its output.
        return self.math_ops(super().forward(x))
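

# ---------------------------------------------------------------------------
# Usage sketch (illustrative only, not part of this module). It assumes a
# simple fixed-point rounding fake quantizer with a straight-through
# estimator; the class name FakeQuantize and its frac_bits parameter are
# hypothetical stand-ins for the quantization modules shipped with the
# quantized_grads plugin.
import torch


class FakeQuantize(Module):
    """Round to a fixed-point grid with frac_bits fractional bits."""

    def __init__(self, frac_bits: int = 8) -> None:
        super().__init__()
        self.scale = 2.0**frac_bits

    def forward(self, x: Tensor) -> Tensor:
        # Straight-through estimator: quantized values in the forward pass,
        # identity gradient in the backward pass.
        return x + (torch.round(x * self.scale) / self.scale - x).detach()

    def right_inverse(self, x: Tensor) -> Tensor:
        # Lets assignments to .weight/.bias (e.g. by a quantized optimizer)
        # be written back onto the underlying unparametrized parameter.
        return x


bn = BatchNorm2d(
    math_ops=FakeQuantize(),             # fake quantizes the layer output
    weight_quantization=FakeQuantize(),  # fake quantizes .weight on access
    bias_quantization=FakeQuantize(),    # fake quantizes .bias on access
    num_features=16,
)
y = bn(torch.randn(4, 16, 8, 8))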