Source code for elasticai.creator.ir2torch.default_handlers

from typing import cast

from torch import nn

from elasticai.creator.torch2ir import Implementation

handlers = []


def _register(fn):
    handlers.append(fn)
    return fn


@_register
def linear(impl: Implementation) -> nn.Linear:
    """Build an ``nn.Linear`` layer from the hyperparameters in ``impl.data``.

    Reads the ``in_features``, ``out_features`` and ``bias`` entries of
    ``impl.data`` and passes them as keyword arguments to ``nn.Linear``.
    The return type is annotated as ``nn.Linear`` (a subtype of the
    previously declared ``nn.Module``) to match the sibling handlers.
    """
    data = impl.data
    # ``cast`` only informs the type checker; no runtime conversion happens.
    return nn.Linear(
        in_features=cast(int, data["in_features"]),
        out_features=cast(int, data["out_features"]),
        bias=cast(bool, data["bias"]),
    )
@_register
def relu(impl: Implementation) -> nn.ReLU:
    """Build a default ``nn.ReLU`` activation.

    ``impl`` is accepted to satisfy the common handler signature but is not
    read. The return type is annotated as ``nn.ReLU`` (a subtype of the
    previously declared ``nn.Module``) to match the sibling handlers.
    """
    return nn.ReLU()
@_register
def conv1d(impl: Implementation) -> nn.Conv1d:
    """Build an ``nn.Conv1d`` from the hyperparameters stored in ``impl.data``.

    The entries ``in_channels``, ``out_channels``, ``bias``, ``groups``,
    ``kernel_size`` and ``stride`` are read from ``impl.data`` and forwarded
    unchanged as keyword arguments to ``nn.Conv1d``.

    NOTE(review): ``padding``, ``dilation`` and ``padding_mode`` are not
    read, so ``nn.Conv1d`` falls back to its defaults for them. A layer with
    non-default values will not round-trip unless that is handled
    elsewhere — TODO confirm against the torch2ir side.
    """
    keywords = (
        "in_channels",
        "out_channels",
        "bias",
        "groups",
        "kernel_size",
        "stride",
    )
    kwargs = {key: impl.data[key] for key in keywords}
    # Values in impl.data are untyped, so the checker cannot match them to
    # Conv1d's parameter types.
    return nn.Conv1d(**kwargs)  # type: ignore
@_register
def sigmoid(impl: Implementation) -> nn.Sigmoid:
    """Return a fresh ``nn.Sigmoid`` activation module.

    The IR implementation carries nothing this handler needs; ``impl`` is
    only present so the function fits the shared handler signature.
    """
    activation = nn.Sigmoid()
    return activation
@_register
def flatten(imp: Implementation) -> nn.Flatten:
    """Return an ``nn.Flatten`` module constructed with PyTorch's defaults.

    ``imp`` is not read. It is named ``imp`` rather than ``impl`` like the
    other handlers; the name is left as-is so keyword callers keep working.

    NOTE(review): any ``start_dim``/``end_dim`` recorded in the IR would be
    ignored here — confirm whether torch2ir stores them.
    """
    layer = nn.Flatten()
    return layer