"""Quantized-gradient optimizer wrappers (elasticai.creator.nn.quantized_grads.quantized_optim)."""
from typing import Callable, Optional

import torch
from torch import Tensor, optim
from torch.nn import Module

from elasticai.creator.nn.quantized_grads.base_modules import parametrized_modules
class _QOptim(optim.Optimizer):
def __init__(
self,
model,
*args,
buffer_quantizations: dict[str, Callable[[Tensor], Tensor]] = None,
**kwargs,
): ...
# [docs]  (Sphinx cross-reference residue from the rendered source page)
def get_quantized_optimizer(optimizer: type[optim.Optimizer]) -> type[_QOptim]:
    """Create a subclass of *optimizer* that keeps a model quantized after updates.

    The returned class wraps any ``torch.optim.Optimizer``. After every
    ``step()`` it

    1. re-applies each parametrization's ``right_inverse`` to the raw
       parameters of the model's parametrized modules, so the stored values
       stay on the quantization grid, and
    2. optionally quantizes named optimizer state buffers (e.g. ``"exp_avg"``)
       with the callables given via ``buffer_quantizations``.

    Args:
        optimizer: the optimizer class to wrap (e.g. ``torch.optim.SGD``).

    Returns:
        A new optimizer class whose constructor additionally takes ``model``
        (positionally, first) and the keyword-only ``buffer_quantizations``.
    """

    class QuantizedOptim(optimizer):
        def __init__(
            self,
            model: Module,
            *args,
            # PEP 484: default None requires an explicit Optional annotation.
            buffer_quantizations: Optional[dict[str, Module]] = None,
            **kwargs,
        ):
            super().__init__(*args, **kwargs)
            self._model = model
            # Re-quantize model parameters after every optimizer step.
            self.register_step_post_hook(_ensure_tensors_quantization_after_update)
            self.buffer_quantization: dict[str, Callable[[Tensor], Tensor]] = (
                {} if buffer_quantizations is None else buffer_quantizations
            )
            # Registered unconditionally: with an empty mapping the hook is a no-op.
            self.register_step_post_hook(_quantization_for_buffers_in_optimizer)

        def step(self, closure=None):
            # Delegates to the wrapped optimizer; the registered post hooks
            # run automatically after this call.
            super().step(closure)

    def _ensure_tensors_quantization_after_update(
        optimizer: QuantizedOptim, *args, **kwargs
    ):
        # Snap every parametrized parameter back onto its quantization grid by
        # writing right_inverse(current value) into the raw parameter storage.
        with torch.no_grad():
            for m in optimizer._model.modules():
                if m.__class__.__name__ in parametrized_modules:
                    for p in m.parametrizations:
                        m_parametrization_param = m.parametrizations[p]
                        m_param = getattr(m, p)
                        m_param.data = m_parametrization_param[0].right_inverse(m_param)

    def _quantization_for_buffers_in_optimizer(
        optimizer: QuantizedOptim, *args, **kwargs
    ):
        # Quantize configured optimizer state buffers (keyed by state name).
        for group in optimizer.param_groups:
            for p in group["params"]:
                state = optimizer.state[p]
                for key in state:
                    if key in optimizer.buffer_quantization:
                        # NOTE(review): assumes the quantization module exposes a
                        # `.device` attribute (plain nn.Module does not) — confirm
                        # against the quantizer implementations. Module.to moves
                        # the module in place, so the return value is unused.
                        if (
                            state[key].device
                            != optimizer.buffer_quantization[key].device
                        ):
                            optimizer.buffer_quantization[key].to(state[key].device)
                        state[key] = optimizer.buffer_quantization[key](state[key])

    return QuantizedOptim