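"""Source code for ``elasticai.creator.torch2ir.default_handlers``.

Default handlers that turn supported ``torch.nn`` modules into plain
dictionaries of their configuration attributes.
"""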

import torch.nn as nn

handlers = []


def _register(fn):
    handlers.append(fn)
    return fn


@_register
def conv1d(module: nn.Conv1d) -> dict:
    """Return the configuration of an ``nn.Conv1d`` module as a dict.

    ``bias`` only records whether the module has a bias parameter; the
    parameter values themselves are not included.
    """
    return dict(
        in_channels=module.in_channels,
        out_channels=module.out_channels,
        kernel_size=module.kernel_size,
        stride=module.stride,
        padding=module.padding,
        dilation=module.dilation,
        groups=module.groups,
        bias=module.bias is not None,
        padding_mode=module.padding_mode,
    )
@_register
def maxpool1d(module: nn.MaxPool1d) -> dict:
    """Return the pooling configuration of an ``nn.MaxPool1d`` module as a dict.

    Each key is the module attribute of the same name, copied as-is.
    """
    attribute_names = (
        "kernel_size",
        "stride",
        "padding",
        "dilation",
        "return_indices",
        "ceil_mode",
    )
    return {name: getattr(module, name) for name in attribute_names}
@_register
def linear(module: nn.Linear) -> dict:
    """Return the configuration of an ``nn.Linear`` module as a dict.

    ``bias`` is a flag telling whether the module has a bias parameter.
    """
    has_bias = module.bias is not None
    return dict(
        in_features=module.in_features,
        out_features=module.out_features,
        bias=has_bias,
    )
@_register
def batchnorm1d(module: nn.BatchNorm1d) -> dict:
    """Return the configuration of an ``nn.BatchNorm1d`` module as a dict.

    Besides ``num_features`` and ``affine``, the numerical settings
    ``eps``, ``momentum`` and ``track_running_stats`` are recorded too, so
    that a module built with non-default values is not silently reduced to
    the defaults. ``momentum`` may be ``None`` (cumulative moving average).
    """
    return {
        "num_features": module.num_features,
        "affine": module.affine,
        "eps": module.eps,
        "momentum": module.momentum,
        "track_running_stats": module.track_running_stats,
    }
@_register
def flatten(module: nn.Flatten) -> dict:
    """Return the configuration of an ``nn.Flatten`` module as a dict.

    ``start_dim`` and ``end_dim`` are recorded so that a non-default
    flatten range (anything other than ``start_dim=1, end_dim=-1``) is
    preserved instead of being dropped.
    """
    return {
        "start_dim": module.start_dim,
        "end_dim": module.end_dim,
    }
@_register
def relu(module: nn.ReLU) -> dict:
    """Return an empty configuration dict for an ``nn.ReLU`` module."""
    return dict()
@_register
def sigmoid(module: nn.Sigmoid) -> dict:
    """Return an empty configuration dict for an ``nn.Sigmoid`` module."""
    return dict()