elasticai.creator.ir2torch.ir2torch#

Module Contents#

Classes#

Functions#

API#

class elasticai.creator.ir2torch.ir2torch.Ir2Torch[source]#

Bases: elasticai.creator.ir.LoweringPass[elasticai.creator.torch2ir.Implementation, torch.nn.Module]

convert(ir: collections.abc.Iterable[elasticai.creator.torch2ir.Implementation], state_dict: dict[str, Any] | None = None) torch.nn.Module[source]#

Rebuild the original PyTorch model from a given IR.

Implementation names containing dots are sorted into a corresponding object hierarchy. E.g., for the implementation name 'top.linear' we create a PyTorch container module under the name 'top' and add the linear layer to it under the name 'linear'. Note that this is an implementation detail of Ir2Torch; no semantic meaning is assigned to the '.' character.
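The dotted-name scheme described above can be sketched with plain dictionaries standing in for container modules. The helper below is hypothetical and not part of the library; it only illustrates how each dot-separated prefix becomes a container level.

```python
def place_by_dotted_name(root: dict, name: str, module: object) -> None:
    """Insert `module` into a nested dict, creating one container
    per dot-separated prefix (hypothetical helper mirroring the
    hierarchy Ir2Torch builds from implementation names)."""
    parts = name.split(".")
    node = root
    # every prefix segment ('top' in 'top.linear') becomes a container
    for part in parts[:-1]:
        node = node.setdefault(part, {})
    # the final segment holds the module itself
    node[parts[-1]] = module


hierarchy: dict = {}
place_by_dotted_name(hierarchy, "top.linear", "LinearLayer")
# 'top' is now a container holding the layer under the name 'linear'
assert hierarchy == {"top": {"linear": "LinearLayer"}}
```

In the real pass the containers are PyTorch container modules rather than dicts, but the nesting logic is the same.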

Param:

ir: Make sure that Ir2Torch has a type handler registered for each implementation in ir.

Param:

state_dict: An optional state dict, created from the original model via nn.Module.state_dict. Because the Torch2Ir stage removed all duplicate submodules, all unknown keys are stripped from the state_dict before it is loaded.
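The key-stripping behaviour for `state_dict` can be illustrated with a small stdlib sketch. The helper name and the filtering-by-membership approach are assumptions for illustration; the real pass filters against the keys the rebuilt model actually expects.

```python
def strip_unknown_keys(state_dict: dict, known_keys: set) -> dict:
    """Keep only entries whose keys the rebuilt model expects
    (hypothetical helper; Ir2Torch performs an equivalent filtering
    before loading the state dict)."""
    return {k: v for k, v in state_dict.items() if k in known_keys}


# a state dict from the original model, including a duplicate submodule
# that Torch2Ir deduplicated away
full = {
    "top.linear.weight": [1.0],
    "top.linear.bias": [0.0],
    "duplicate.weight": [2.0],
}
known = {"top.linear.weight", "top.linear.bias"}

# the unknown 'duplicate.weight' entry is dropped before loading
assert strip_unknown_keys(full, known) == {
    "top.linear.weight": [1.0],
    "top.linear.bias": [0.0],
}
```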

register_type_handlers(handlers: collections.abc.Iterable[collections.abc.Callable[[elasticai.creator.torch2ir.Implementation], torch.nn.Module]]) None[source]#
elasticai.creator.ir2torch.ir2torch.get_default_converter() elasticai.creator.ir2torch.ir2torch.Ir2Torch[source]#
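The handler-registration pattern behind `register_type_handlers` can be sketched without the library installed. The registry class, the `__name__`-based dispatch key, and the dict-shaped implementations below are all assumptions made for illustration; the real `Ir2Torch` dispatches on `elasticai.creator.torch2ir.Implementation` objects and builds `torch.nn.Module` instances.

```python
from collections.abc import Callable, Iterable


class HandlerRegistry:
    """Minimal sketch of the type-handler pattern: each handler is a
    callable from an implementation to a built module, registered under
    its function name (hypothetical stand-in for Ir2Torch)."""

    def __init__(self) -> None:
        self._handlers: dict[str, Callable[[dict], str]] = {}

    def register_type_handlers(
        self, handlers: Iterable[Callable[[dict], str]]
    ) -> None:
        # key each handler by its name, mirroring per-type dispatch
        for handler in handlers:
            self._handlers[handler.__name__] = handler

    def convert(self, implementations: Iterable[dict]) -> list[str]:
        # every implementation must have a registered handler
        return [self._handlers[impl["type"]](impl) for impl in implementations]


def linear(impl: dict) -> str:
    """Hypothetical handler: build a module description from an IR node."""
    return f"Linear({impl['in_features']}, {impl['out_features']})"


registry = HandlerRegistry()
registry.register_type_handlers([linear])
built = registry.convert([{"type": "linear", "in_features": 4, "out_features": 2}])
assert built == ["Linear(4, 2)"]
```

A `get_default_converter()`-style factory would simply return such a registry with the standard handlers pre-registered.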