# Source code for elasticai.creator.nn.binary.math_operations

from typing import cast

import torch

from elasticai.creator.base_modules.conv1d import MathOperations as Conv1dOps
from elasticai.creator.base_modules.linear import MathOperations as LinearOps
from elasticai.creator.base_modules.lstm_cell import MathOperations as LSTMOps

from .binary_quantization_function import Binarize


class MathOperations(LinearOps, Conv1dOps, LSTMOps):
    """Math operations whose results are binarized.

    Combines the ``MathOperations`` interfaces imported from the linear,
    conv1d and LSTM-cell base modules into one implementation.

    Every arithmetic operation is first evaluated in full precision with
    ordinary torch ops. The result is then passed through :meth:`quantize`,
    which delegates to ``Binarize.apply``. The exact output values of
    ``Binarize`` are defined in ``binary_quantization_function``.
    NOTE(review): presumably a sign-style mapping; confirm there.
    """

    def quantize(self, a: torch.Tensor) -> torch.Tensor:
        """Return ``a`` run through ``Binarize.apply``.

        ``Binarize.apply`` follows the ``torch.autograd.Function`` calling
        convention. Its static return type is not ``torch.Tensor``, so the
        result is narrowed with ``cast`` for the type checker only. There is
        no runtime conversion.
        """
        binarized = Binarize.apply(a)
        return cast(torch.Tensor, binarized)

    def add(self, a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
        """Add ``a`` and ``b`` element-wise, then binarize the sum."""
        total = a + b
        return self.quantize(total)

    def mul(self, a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
        """Multiply ``a`` and ``b`` element-wise, then binarize the product."""
        product = a * b
        return self.quantize(product)

    def matmul(self, a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
        """Compute ``torch.matmul(a, b)`` and binarize the result."""
        product = torch.matmul(a, b)
        return self.quantize(product)