Source code for elasticai.creator.base_modules.lstm

from typing import Optional, Protocol

import torch

from .lstm_cell import LSTMCell

class LayerFactory(Protocol):
    def lstm(self, input_size: int, hidden_size: int, bias: bool) -> LSTMCell: ...
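
Since LayerFactory is a typing.Protocol, any object with a structurally matching lstm method conforms; no inheritance is required. A minimal sketch of a conforming factory, assuming torch.nn.LSTMCell is duck-type compatible with the LSTMCell interface imported above (the project supplies its own factories; this one is purely illustrative):

class TorchCellFactory:
    def lstm(self, input_size: int, hidden_size: int, bias: bool) -> torch.nn.LSTMCell:
        # torch.nn.LSTMCell exposes input_size and hidden_size and accepts an
        # optional (h, c) state, which is what the LSTM module below relies on.
        return torch.nn.LSTMCell(input_size, hidden_size, bias=bias)
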
class LSTM(torch.nn.Module):
    def __init__(
        self,
        input_size: int,
        hidden_size: int,
        bias: bool,
        batch_first: bool,
        layers: LayerFactory,
    ) -> None:
        super().__init__()
        self.cell = layers.lstm(
            input_size=input_size, hidden_size=hidden_size, bias=bias
        )
        self.batch_first = batch_first

    @property
    def hidden_size(self) -> int:
        return self.cell.hidden_size

    @property
    def input_size(self) -> int:
        return self.cell.input_size

    def forward(
        self,
        x: torch.Tensor,
        state: Optional[tuple[torch.Tensor, torch.Tensor]] = None,
    ) -> tuple[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
        batched = x.dim() == 3
        if batched and self.batch_first:
            # Move the time dimension to the front:
            # (batch, seq, feature) -> (seq, batch, feature).
            x = torch.stack(torch.unbind(x), dim=1)
        if state is not None:
            # Drop the leading num_layers dimension of the incoming state,
            # (1, batch, hidden) -> (batch, hidden), since the cell works on
            # per-layer states.
            state = state[0].squeeze(0), state[1].squeeze(0)
        inputs = torch.unbind(x, dim=0)
        outputs = []
        for step_input in inputs:
            hidden_state, cell_state = self.cell(step_input, state)
            state = (hidden_state, cell_state)
            outputs.append(hidden_state)
        if state is None:
            raise RuntimeError("Input sequence must contain at least one time step.")
        result = torch.stack(outputs, dim=1 if batched and self.batch_first else 0)
        # TODO: check whether the unsqueeze dimension is actually consistent
        # with self.batch_first being true or false.
        # Restore the leading num_layers dimension on the final state.
        hidden_state, cell_state = state[0].unsqueeze(0), state[1].unsqueeze(0)
        return result, (hidden_state, cell_state)
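
A usage sketch with the hypothetical TorchCellFactory shown above; the shapes follow the batch_first=True convention:

lstm = LSTM(
    input_size=4, hidden_size=8, bias=True, batch_first=True, layers=TorchCellFactory()
)
x = torch.randn(2, 5, 4)  # (batch, seq, feature) because batch_first=True
output, (h_n, c_n) = lstm(x)
assert output.shape == (2, 5, 8)  # (batch, seq, hidden)
assert h_n.shape == (1, 2, 8)  # (num_layers, batch, hidden)
assert c_n.shape == (1, 2, 8)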