Source code for elasticai.creator.ir2torch.ir2torch

from collections.abc import Callable
from itertools import starmap
from typing import Any, Protocol

import torch.nn as nn
from torch import fx

import elasticai.creator.function_dispatch as FD
import elasticai.creator.ir.ir_v2 as ir

from .default_handlers import handlers


class DataGraph(ir.DataGraph[ir.Node, ir.Edge], Protocol):
    @property
    def type(self) -> str: ...


class _DataGraph(ir.DataGraphImpl[ir.Node, ir.Edge]):
    @property
    def type(self) -> str:
        result = self.attributes.get("type", "<none>")
        if not isinstance(result, str):
            raise TypeError("expected 'type' field to be of type str")
        return result


def _new_graph(
    factory: ir.NodeEdgeFactory[ir.Node, ir.Edge], attributes: ir.AttributeMapping
) -> DataGraph:
    return _DataGraph(
        factory,
        attributes,
        ir.GraphImpl(lambda: ir.AttributeMapping()),
        ir.AttributeMapping(),
    )


class IrFactory(ir.IrFactory[ir.Node, ir.Edge, ir.DataGraph]):
    def __init__(self):
        self._node_edge = ir.DefaultNodeEdgeFactory()

        def graph(attributes):
            return _new_graph(self, attributes)

        self._graph = graph

    def node(
        self, name: str, attributes: ir.AttributeMapping = ir.AttributeMapping()
    ) -> ir.Node:
        return self._node_edge.node(name, attributes)

    def edge(
        self,
        src: str,
        dst: str,
        attributes: ir.AttributeMapping = ir.AttributeMapping(),
    ) -> ir.Edge:
        return self._node_edge.edge(src, dst, attributes)

    def graph(
        self,
        attributes: ir.AttributeMapping = ir.AttributeMapping(),
        /,
        *,
        graph: ir.DataGraph | None = None,
    ) -> DataGraph:
        if graph is not None:
            return _DataGraph(
                factory=graph.factory,
                attributes=graph.attributes,
                graph=graph.graph,
                node_attributes=graph.node_attributes,
            )
        return self._graph(attributes)
type TypeHandler = Callable[[DataGraph], nn.Module]
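
# Illustrative sketch (not part of this module): a TypeHandler turns the DataGraph
# describing one implementation into a torch module. The attribute names used below
# ("in_features", "out_features") are assumptions about how a linear layer might be
# encoded in the IR, not guaranteed by this package.
#
#     def linear(dgraph: DataGraph) -> nn.Module:
#         attrs = dgraph.attributes
#         return nn.Linear(int(attrs["in_features"]), int(attrs["out_features"]))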
class Ir2Torch:
    # Submodule construction is dispatched on the graph's `type` attribute:
    # the handler registered for `dgraph.type` is looked up and called.
    @FD.dispatch_method(str)
    def _build_submodule(
        self,
        fn: TypeHandler,
        dgraph: DataGraph,
        /,
    ) -> nn.Module:
        return fn(dgraph)

    @_build_submodule.key_from_args
    def _get_key_from_data_graph(self, dgraph: DataGraph) -> str:
        return dgraph.type

    @staticmethod
    def _check_and_get_name_from_fn(name, fn) -> str:
        if name is None:
            if hasattr(fn, "__name__") and isinstance(fn.__name__, str):
                return fn.__name__
            else:
                raise TypeError(
                    "If the registered type handler is not a function,"
                    " you need to specify the type name explicitly"
                )
        return name

    @FD.registrar_method
    def register(
        self,
        name: str | None,
        fn: TypeHandler,
    ) -> TypeHandler:
        return self._build_submodule.register(
            self._check_and_get_name_from_fn(name, fn), fn
        )

    @FD.registrar_method
    def override(
        self,
        name: str | None,
        fn: TypeHandler,
    ) -> TypeHandler:
        return self._build_submodule.override(
            self._check_and_get_name_from_fn(name, fn), fn
        )
    def __call__(
        self,
        ir_root: DataGraph,
        registry: ir.Registry[DataGraph],
        state_dict: dict[str, Any] | None = None,
    ) -> nn.Module:
        """Rebuild the original pytorch model from a given IR.

        Implementation names containing dots will result in the corresponding
        modules being sorted into a corresponding object hierarchy. E.g., for
        the implementation name `'top.linear'` we will create a pytorch
        container module under the name `'top'` and add the linear layer to it
        under the name `'linear'`. Note that this is an implementation detail
        of Ir2Torch and not a semantic meaning assigned to the `'.'` character.

        :param ir_root: You need to make sure that `Ir2Torch` has a type
            handler for each implementation in `ir_root`.
        :param state_dict: You can optionally pass a state dict. This should be
            a state dict created from the original model via
            `nn.Module.state_dict`. As the `Torch2Ir` stage got rid of all
            duplicate submodules, we will strip all unknown keys from the
            `state_dict` and then load it.
        """
        factory = IrFactory()
        root_module = nn.Module()

        def to_new_graph(name, graph):
            return name, factory.graph(graph=graph)

        # Build each implementation as a torch module and insert it into the
        # module hierarchy implied by its (possibly dotted) name.
        for name, impl in starmap(to_new_graph, registry.items()):
            layer = self._build_submodule(impl)
            last_parent = root_module
            while "." in name:
                parent_name, name = name.rsplit(".", 1)
                last_children = dict(last_parent.named_children())
                if parent_name not in last_children:
                    current_parent = nn.Module()
                    last_parent.add_module(parent_name, current_parent)
                else:
                    current_parent = last_children[parent_name]
                last_parent = current_parent
            last_parent.add_module(name, layer)

        # Recreate the call graph as an fx.Graph by visiting the IR nodes in
        # dependency order, starting from the output node.
        graph = fx.Graph()
        nodes: dict[str, fx.Node] = {}

        def _add_node(ir_node: str) -> None:
            node_type = ir_root.nodes[ir_node].type
            node_impl: str = ir_root.nodes[ir_node].attributes["implementation"]
            if ir_root.nodes[ir_node].type not in ("input", "output"):
                predecessors = tuple(
                    nodes[node] for node in ir_root.predecessors[ir_node]
                )
                node = graph.create_node(
                    op="call_module",
                    target=node_impl,
                    args=predecessors,
                    name=ir_node,
                )
                nodes[ir_node] = node
            elif node_type == "input":
                n = graph.create_node(op="placeholder", name=ir_node, target=ir_node)
                nodes[ir_node] = n
            elif node_type == "output":
                predecessors = tuple(
                    nodes[node] for node in ir_root.predecessors[ir_node]
                )
                n = graph.create_node(
                    op="output",
                    target=ir_node,
                    args=predecessors,
                )
                nodes[ir_node] = n

        def visit_node(ir_node: str):
            if ir_node not in nodes:
                for pred in ir_root.predecessors[ir_node]:
                    visit_node(pred)
                if ir_node not in nodes:
                    _add_node(ir_node)

        visit_node("output")
        module = fx.GraphModule(root_module, graph)
        if state_dict is not None:
            # Keep only state dict entries that belong to a known
            # implementation before loading.
            filtered_state_dict = {}
            for key, value in state_dict.items():
                submodule = ".".join(key.split(".")[:-1])
                if submodule in registry:
                    filtered_state_dict[key] = value
            module.load_state_dict(filtered_state_dict)
        return module
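
# Registration sketch (illustrative, reusing the hypothetical `linear` handler from
# the comment near the `TypeHandler` alias above; `ir_root` and `registry` stand in
# for the outputs of an earlier Torch2Ir stage):
#
#     converter = Ir2Torch()
#     converter.register()(linear)  # type name is taken from linear.__name__
#     rebuilt = converter(ir_root, registry)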
def get_default_converter() -> Ir2Torch:
    converter = Ir2Torch()
    for handler in handlers:
        converter.register()(handler)
    return converter
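
# Usage sketch (illustrative; `ir_root`, `registry`, and `model` are placeholders,
# assumed to come from lowering the original model with the Torch2Ir stage):
#
#     converter = get_default_converter()
#     rebuilt = converter(ir_root, registry, state_dict=model.state_dict())
#     # state dict keys that do not belong to a registered implementation are
#     # stripped before loading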