elasticai.creator.ir2torch

Package Contents

Classes

Ir2Torch
IrFactory

Functions

get_default_converter

API

class elasticai.creator.ir2torch.Ir2Torch
register(name: str | None, fn: elasticai.creator.ir2torch.ir2torch.TypeHandler) → elasticai.creator.ir2torch.ir2torch.TypeHandler
override(name: str | None, fn: elasticai.creator.ir2torch.ir2torch.TypeHandler) → elasticai.creator.ir2torch.ir2torch.TypeHandler
__call__(ir_root: elasticai.creator.ir2torch.ir2torch.DataGraph, registry: elasticai.creator.ir.ir_v2.Registry[elasticai.creator.ir2torch.ir2torch.DataGraph], state_dict: dict[str, Any] | None = None) → torch.nn.Module

Rebuild the original PyTorch model from a given IR.

Implementation names containing dots cause the corresponding modules to be sorted into a matching object hierarchy. For example, for the implementation name 'top.linear', Ir2Torch creates a PyTorch container module under the name 'top' and adds the linear layer to it under the name 'linear'. Note that this is an implementation detail of Ir2Torch, not a semantic meaning assigned to the '.' character.
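The grouping of dotted names can be illustrated with a minimal, stdlib-only sketch. The function `nest_by_dotted_name` is hypothetical and not part of elasticai.creator; plain nested dicts stand in for the PyTorch container modules that Ir2Torch creates, and the module values are placeholder strings.

```python
from typing import Any


def nest_by_dotted_name(modules: dict[str, Any]) -> dict[str, Any]:
    """Group flat, dot-separated names into a nested dict.

    Hypothetical illustration: each intermediate dict stands in for a
    container module. Assumes no name is both a leaf and a prefix of
    another name (e.g. "top" and "top.linear" together).
    """
    root: dict[str, Any] = {}
    for dotted_name, module in modules.items():
        *parents, leaf = dotted_name.split(".")
        container = root
        for part in parents:
            # Create the intermediate container on first use, then descend into it.
            container = container.setdefault(part, {})
        container[leaf] = module
    return root


hierarchy = nest_by_dotted_name(
    {"top.linear": "Linear(4, 2)", "top.relu": "ReLU()", "out": "Sigmoid()"}
)
print(hierarchy)
# {'top': {'linear': 'Linear(4, 2)', 'relu': 'ReLU()'}, 'out': 'Sigmoid()'}
```

In a PyTorch version, each intermediate dict would correspond to a container module and each leaf assignment to registering the layer on that container under the last name component.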

Param:

ir_root: Make sure that Ir2Torch has a type handler for each implementation in ir_root.

Param:

state_dict: Optionally, a state dict created from the original model via nn.Module.state_dict. Because the Torch2Ir stage removed all duplicate submodules, Ir2Torch strips all unknown keys from state_dict and then loads it.
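The key stripping can be sketched as a simple filter, assuming that "unknown" means "not present in the rebuilt model's own state dict". The helper `strip_unknown_keys` and the example keys are hypothetical, and placeholder lists replace tensors so the sketch runs without PyTorch. With a real model one would pass `set(rebuilt.state_dict())` as `known_keys` and hand the result to `rebuilt.load_state_dict(...)`.

```python
from typing import Any


def strip_unknown_keys(state_dict: dict[str, Any], known_keys: set[str]) -> dict[str, Any]:
    """Keep only the entries whose key exists in the rebuilt model."""
    return {key: value for key, value in state_dict.items() if key in known_keys}


# Hypothetical original model in which one submodule was registered twice,
# as "encoder" and as "shared_encoder"; after deduplication the rebuilt
# model only contains "encoder".
original_state = {
    "encoder.weight": [[0.1, 0.2]],
    "encoder.bias": [0.0],
    "shared_encoder.weight": [[0.1, 0.2]],
    "shared_encoder.bias": [0.0],
}
rebuilt_keys = {"encoder.weight", "encoder.bias"}

filtered = strip_unknown_keys(original_state, rebuilt_keys)
print(sorted(filtered))  # ['encoder.bias', 'encoder.weight']
```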

elasticai.creator.ir2torch.get_default_converter() → elasticai.creator.ir2torch.ir2torch.Ir2Torch
class elasticai.creator.ir2torch.IrFactory

Bases: elasticai.creator.ir.ir_v2.IrFactory[elasticai.creator.ir.ir_v2.Node, elasticai.creator.ir.ir_v2.Edge, elasticai.creator.ir.ir_v2.DataGraph]

node(name: str, attributes: elasticai.creator.ir.ir_v2.AttributeMapping = ir.AttributeMapping()) → elasticai.creator.ir.ir_v2.Node
edge(src: str, dst: str, attributes: elasticai.creator.ir.ir_v2.AttributeMapping = ir.AttributeMapping()) → elasticai.creator.ir.ir_v2.Edge
graph(attributes: elasticai.creator.ir.ir_v2.AttributeMapping = ir.AttributeMapping(), /, *, graph: elasticai.creator.ir.ir_v2.DataGraph | None = None) → elasticai.creator.ir2torch.ir2torch.DataGraph