"""Module ``elasticai.creator.base_modules.linear``: a ``torch.nn.Linear`` whose arithmetic is delegated to pluggable math operations."""

from typing import Any, Protocol

import torch

from elasticai.creator.base_modules.math_operations import Add, MatMul, Quantize


class MathOperations(Quantize, Add, MatMul, Protocol):
    """Structural type combining the ``Quantize``, ``Add`` and ``MatMul`` protocols.

    Any object that structurally satisfies all three protocols can be passed
    as the ``operations`` argument of ``Linear``, which calls its
    ``quantize``, ``add`` and ``matmul`` methods.
    """
class Linear(torch.nn.Linear):
    """A ``torch.nn.Linear`` whose arithmetic is delegated to ``operations``.

    Parameters are created exactly as in :class:`torch.nn.Linear`. In
    ``forward``, the weight and (when present) the bias are passed through
    ``operations.quantize``. The output is then computed with
    ``operations.matmul`` and ``operations.add`` instead of the built-in
    linear kernel.

    Args:
        in_features: Forwarded to :class:`torch.nn.Linear`.
        out_features: Forwarded to :class:`torch.nn.Linear`.
        operations: Object providing ``quantize``, ``add`` and ``matmul``.
        bias: Forwarded to :class:`torch.nn.Linear`; controls whether a
            bias parameter exists.
        device: Forwarded to :class:`torch.nn.Linear`.
        dtype: Forwarded to :class:`torch.nn.Linear`.
    """

    def __init__(
        self,
        in_features: int,
        out_features: int,
        operations: MathOperations,
        bias: bool,
        device: Any = None,
        dtype: Any = None,
    ) -> None:
        super().__init__(
            in_features=in_features,
            out_features=out_features,
            bias=bias,
            device=device,
            dtype=dtype,
        )
        self._operations = operations

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Compute ``matmul(x, quantize(weight).T)``, plus ``quantize(bias)`` via ``add`` if a bias exists."""
        ops = self._operations
        quantized_weight = ops.quantize(self.weight)
        if self.bias is None:
            return ops.matmul(x, quantized_weight.T)
        # Quantize the bias before the matmul so the operation order matches
        # the original: quantize(weight), quantize(bias), matmul, add.
        quantized_bias = ops.quantize(self.bias)
        return ops.add(ops.matmul(x, quantized_weight.T), quantized_bias)