Source code for elasticai.creator.base_modules.linear

from typing import Any, Protocol

import torch

from elasticai.creator.base_modules.math_operations import Add, MatMul, Quantize


class MathOperations(Quantize, Add, MatMul, Protocol):
    """Structural type bundling the operations the ``Linear`` layer relies on.

    An object satisfies this protocol when it provides everything declared
    by the ``Quantize``, ``Add`` and ``MatMul`` protocols.
    """
class Linear(torch.nn.Linear):
    """``torch.nn.Linear`` variant whose arithmetic is delegated to an operations object.

    Parameter storage and initialisation are inherited from
    ``torch.nn.Linear``. Only the forward computation differs: the weight
    and the bias are each passed through ``operations.quantize`` before use,
    and the product and sum are computed by ``operations.matmul`` and
    ``operations.add``.
    """

    def __init__(
        self,
        in_features: int,
        out_features: int,
        operations: MathOperations,
        bias: bool,
        device: Any = None,
        dtype: Any = None,
    ) -> None:
        """Create the layer and remember the operations used in ``forward``.

        Args:
            in_features: Forwarded to ``torch.nn.Linear``.
            out_features: Forwarded to ``torch.nn.Linear``.
            operations: Provider of ``quantize``, ``matmul`` and ``add``.
            bias: Forwarded to ``torch.nn.Linear``.
            device: Forwarded to ``torch.nn.Linear``.
            dtype: Forwarded to ``torch.nn.Linear``.
        """
        super().__init__(
            in_features=in_features,
            out_features=out_features,
            bias=bias,
            device=device,
            dtype=dtype,
        )
        # Assigned after the parent initialiser has set up the module.
        self._operations = operations

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the layer to ``x`` using the configured operations.

        Args:
            x: Input tensor; it is the left operand of ``matmul`` with the
                transposed, quantized weight.

        Returns:
            The ``matmul`` result, with the quantized bias added through
            ``add`` when the layer has a bias.
        """
        ops = self._operations
        product = ops.matmul(x, ops.quantize(self.weight).T)
        if self.bias is not None:
            return ops.add(product, ops.quantize(self.bias))
        return product