Source code for elasticai.creator.ir2torch.ir2torch

from collections.abc import Callable, Iterable
from typing import Any

import torch.nn as nn
from torch import fx

from elasticai.creator.ir import LoweringPass
from elasticai.creator.torch2ir import Implementation

from .default_handlers import handlers


class Ir2Torch(LoweringPass[Implementation, nn.Module]):
    def convert(
        self, ir: Iterable[Implementation], state_dict: dict[str, Any] | None = None
    ) -> nn.Module:
        """Rebuild the original PyTorch model from a given IR.

        Implementation names containing dots will result in the corresponding
        modules being sorted into a matching object hierarchy. E.g., for the
        implementation name `'top.linear'` we will create a PyTorch container
        module under the name `'top'` and add the linear layer to it under the
        name `'linear'`. Note that this is an implementation detail of Ir2Torch
        and not a semantic meaning assigned to the `'.'` character.

        :param ir: You need to make sure that `Ir2Torch` has a type handler
            registered for each implementation in `ir`.
        :param state_dict: You can optionally pass a state dict. This should be
            a state dict created from the original model via
            `nn.Module.state_dict`. As the `Torch2Ir` stage got rid of all
            duplicate submodules, we will strip all unknown keys from the
            `state_dict` and then load it.
        """
        root_module = nn.Module()
        root = ""
        _ir = {impl.name: impl for impl in ir}
        for impl in _ir.values():
            if impl.type != "module":
                # Lower the implementation to a torch module via its
                # registered type handler.
                modules = list(self([impl]))
                assert len(modules) == 1
                name = impl.name
                last_parent = root_module
                # Peel off one hierarchy level at a time, creating container
                # modules as needed, so that a name like 'a.b.c' ends up as
                # 'c' inside 'b' inside 'a'.
                while "." in name:
                    parent_name, name = name.split(".", 1)
                    last_children = dict(last_parent.named_children())
                    if parent_name not in last_children:
                        current_parent = nn.Module()
                        last_parent.add_module(parent_name, current_parent)
                    else:
                        current_parent = last_children[parent_name]
                    last_parent = current_parent
                last_parent.add_module(name, modules[0])
            else:
                root = impl.name
        # Rebuild the call graph of the root implementation: every node that
        # is neither input nor output becomes a call_module node wired to its
        # predecessors.
        graph = fx.Graph()
        ir_root = _ir[root]
        for ir_node in ir_root.nodes.values():
            if ir_node.type not in ("input", "output"):
                graph.call_module(
                    ir_node.implementation, tuple(ir_root.predecessors(ir_node))
                )
        module = fx.GraphModule(root_module, graph)
        if state_dict is not None:
            # Keep only keys whose owning submodule is known to the IR;
            # everything else was deduplicated away by the Torch2Ir stage.
            filtered_state_dict = {}
            for key, value in state_dict.items():
                submodule = ".".join(key.split(".")[:-1])
                if submodule in _ir:
                    filtered_state_dict[key] = value
            module.load_state_dict(filtered_state_dict)
        return module
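
    # Illustrative sketch (hypothetical names, not part of the module): for an
    # IR containing a root implementation 'top' of type 'module' and a child
    # implementation named 'top.linear', `convert` builds a hierarchy
    # equivalent to
    #
    #   root_module = nn.Module()
    #   top = nn.Module()                  # container created on the fly
    #   top.add_module("linear", lowered)  # module built by the type handler
    #   root_module.add_module("top", top)
    #
    # before wrapping `root_module` in an `fx.GraphModule`.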
    def register_type_handlers(
        self, handlers: Iterable[Callable[[Implementation], nn.Module]]
    ) -> None:
        # Each handler is registered under its own function name, which is
        # expected to match the IR type it converts.
        for handler in handlers:
            self.register(handler.__name__)(handler)
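
# Illustrative sketch of a custom type handler (hypothetical: the
# `in_features`/`out_features` lookups assume such fields are stored on the
# Implementation; the actual attribute access depends on the IR schema):
#
#   def my_linear(impl: Implementation) -> nn.Module:
#       return nn.Linear(impl.data["in_features"], impl.data["out_features"])
#
#   converter = Ir2Torch()
#   converter.register_type_handlers([my_linear])
#
# The handler is registered under the name 'my_linear', so it is applied to
# implementations whose type is 'my_linear'.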

def get_default_converter() -> Ir2Torch:
    # Return an Ir2Torch instance with the default handlers pre-registered.
    converter = Ir2Torch()
    converter.register_type_handlers(handlers)
    return converter
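
# Usage sketch, assuming `ir` was produced by the matching Torch2Ir pass and
# `original_state_dict` by the original model's `state_dict()` (both names
# are placeholders):
#
#   converter = get_default_converter()
#   model = converter.convert(ir, state_dict=original_state_dict)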