Source code for elasticai.creator.base_modules.conv1d
from typing import Any, Protocol
from torch import Tensor
from torch.nn import Conv1d as _Conv1d
from torch.nn.functional import conv1d
from elasticai.creator.base_modules.math_operations import Quantize
[docs]
class MathOperations(Quantize, Protocol): ...
[docs]
class Conv1d(_Conv1d):
def __init__(
self,
operations: MathOperations,
in_channels: int,
out_channels: int,
kernel_size: int | tuple[int],
stride: int | tuple[int] = 1,
padding: int | tuple[int] | str = 0,
dilation: int | tuple[int] = 1,
groups: int = 1,
bias: bool = True,
device: Any = None,
dtype: Any = None,
) -> None:
super().__init__(
in_channels=in_channels,
out_channels=out_channels,
kernel_size=kernel_size,
stride=stride,
padding=padding,
dilation=dilation,
groups=groups,
bias=bias,
padding_mode="zeros",
device=device,
dtype=dtype,
)
self._operations = operations
[docs]
def forward(self, x: Tensor) -> Tensor:
quantized_weights = self._operations.quantize(self.weight)
quantized_bias = (
self._operations.quantize(self.bias) if self.bias is not None else None
)
convolved = conv1d(
input=x,
weight=quantized_weights,
bias=quantized_bias,
stride=self.stride,
padding=self.padding,
dilation=self.dilation,
groups=self.groups,
)
return self._operations.quantize(convolved)