# Source code for elasticai.creator.nn.fixed_point.round_to_fixed_point
from typing import Any
import torch
from elasticai.creator.nn.fixed_point.two_complement_fixed_point_config import (
FixedPointConfig,
)
class RoundToFixedPoint(torch.autograd.Function):
    """Autograd function that rounds a tensor to the nearest fixed-point value.

    The forward pass quantizes the input onto the fixed-point grid described
    by a ``FixedPointConfig``; the backward pass is a straight-through
    estimator that passes incoming gradients through unchanged.
    """

    @staticmethod
    def forward(ctx: Any, *args: Any, **kwargs: Any) -> torch.Tensor:
        """Round ``args[0]`` onto the fixed-point grid given by ``args[1]``.

        Args:
            ctx: Autograd context (unused by this function).
            *args: Exactly two positional arguments are expected:
                ``x`` (torch.Tensor): values to quantize.
                ``config`` (FixedPointConfig): fixed-point format description.

        Returns:
            A tensor with each value rounded to the nearest representable
            fixed-point rational.

        Raises:
            TypeError: If not called with exactly two positional arguments.
            ValueError: If any quantized value falls outside the
                representable fixed-point range.
        """
        if len(args) != 2:
            raise TypeError(
                "apply() takes exactly two arguments "
                "(x: torch.Tensor, config: FixedPointConfig)"
            )
        x: torch.Tensor = args[0]
        config: FixedPointConfig = args[1]
        fxp_ints = config.as_integer(x)
        # Test the boolean mask itself, not the values selected by it: the
        # previous form, `torch.any(fxp_ints[mask])`, relied on the selected
        # *values* being non-zero and would miss an out-of-bounds entry whose
        # integer value is 0, while also doing an unnecessary gather.
        if torch.any(config.integer_out_of_bounds(fxp_ints)):
            raise ValueError("Cannot quantize tensor. Values out of bounds.")
        return config.as_rational(fxp_ints)

    @staticmethod
    def backward(ctx: Any, *grad_outputs: Any) -> Any:
        """Straight-through estimator for the rounding operation.

        Returns one gradient per forward input: the incoming gradient(s)
        unchanged for ``x``, and ``None`` for the non-differentiable
        ``config`` argument.
        """
        return *grad_outputs, None