Source code for elasticai.creator.nn.binary.binary_quantization_function

from typing import Any

import torch


[docs] class Binarize(torch.autograd.Function):
[docs] @staticmethod def forward(ctx: Any, *args: Any, **kwargs: Any) -> Any: if len(args) == 0: raise TypeError x = args[0] out_of_range = torch.logical_or(torch.gt(x, 1.0), torch.lt(x, -1.0)) ctx.save_for_backward(out_of_range) y = torch.where(x >= 0, 1.0, -1.0) return y
[docs] @staticmethod def backward(ctx: Any, *grad_outputs: Any) -> Any: (out_of_range,) = ctx.saved_tensors return grad_outputs[0] * torch.where(out_of_range, 0.0, 1.0)