Source code for elasticai.creator.nn.fixed_point.precomputed.identity_step_function

from typing import Any

import torch


[docs] class IdentityStepFunction(torch.autograd.Function):
[docs] @staticmethod def forward(ctx: Any, *args: Any, **kwargs: Any) -> torch.Tensor: if len(args) != 2: raise TypeError( "apply() takes exactly two arguments " "(x: torch.Tensor, step_lut: torch.Tensor)" ) x: torch.Tensor = args[0] step_lut: torch.Tensor = args[1] steps = len(step_lut) if steps < 2: raise ValueError( f"Number of steps cannot be less than or equal to 1 (steps == {steps})." ) x = x.to(torch.float32) x = x.clamp(min=step_lut.min(), max=step_lut.max()) for step_idx in range(1, len(step_lut)): prev_step, curr_step = step_lut[step_idx - 1], step_lut[step_idx] x[(x > prev_step) & (x <= curr_step)] = curr_step return x
[docs] @staticmethod def backward(ctx: Any, *grad_outputs: Any) -> Any: return *grad_outputs, None