Source code for elasticai.creator.torch2ir.torch2ir

from collections.abc import Callable
from typing import cast

from torch.fx import Node as FxNode
from torch.fx import Tracer
from torch.nn import Module

import elasticai.creator.function_dispatch as FD
import elasticai.creator.ir.ir_v2 as ir

from .default_handlers import handlers as default_handlers


class _DefaultTracer(Tracer):
    """Tracer that does not trace into modules defined in ``elasticai`` packages.

    Such modules are reported as leaf modules, so they show up as a single
    ``call_module`` node in the traced graph. Every other module is
    classified by the stock ``torch.fx.Tracer`` rules.
    """

    def is_leaf_module(self, m, module_qualified_name):
        """Decide whether ``m`` is kept as an opaque ``call_module`` node.

        Args:
            m: The module instance encountered during tracing.
            module_qualified_name: Dotted path of ``m`` inside the traced
                model; only forwarded to the base implementation.

        Returns:
            ``True`` if the class of ``m`` is defined in a module whose name
            starts with ``"elasticai"``, otherwise whatever the base tracer
            decides.
        """
        # The defining package is found in ``__module__``. ``__qualname__``
        # only holds the dotted path of the class *inside* its module
        # (e.g. ``"Linear"``), so testing it for the package prefix would
        # never match.
        if type(m).__module__.startswith("elasticai"):
            return True
        return super().is_leaf_module(m, module_qualified_name)


class LoweringError(Exception):
    """Signals that a traced model cannot be translated.

    The single ``message`` argument is mandatory and becomes the text of
    the exception.
    """

    def __init__(self, message: str):
        """Create the error with the human readable ``message``."""
        super().__init__(message)
# Graph type returned by the conversion: an IR data graph parameterized with
# the IR node and edge types.
type DataGraph = ir.DataGraph[ir.Node, ir.Edge]
# Registry whose entries are data graphs.
type Registry = ir.Registry[DataGraph]
# Callable that receives a torch module and returns a dict; the dict is passed
# as keyword attributes when the module's registry entry is created.
type TypeHandler = Callable[[Module], dict]
[docs] class Torch2Ir: def __init__(self, tracer=_DefaultTracer()) -> None: self._tracer = tracer self._ir_factory = ir.DefaultIrFactory() self._registry: Registry = ir.Registry() self._root = self._ir_factory.graph(ir.attribute(type="module")) @FD.dispatch_method(str) def _extractors( self, fn: TypeHandler, module: Module, ) -> dict: return fn(module) @_extractors.key_from_args def _get_type_from_module(self, module: Module) -> str: return module.__class__.__name__.lower() @staticmethod def _check_and_get_name(name: str | None, fn: TypeHandler) -> str: if name is None: if hasattr(fn, "__name__") and isinstance(fn.__name__, str): return fn.__name__ else: raise TypeError( "specify the type handler's type explicitly if you want to register a non-function callable" ) return name
[docs] @FD.registrar_method def register(self, key: str | None, fn: TypeHandler) -> TypeHandler: key = self._check_and_get_name(key, fn) self._extractors.register(key, fn) return fn
def _handle_fx_node(self, node: FxNode) -> None: self._root = self._root.add_node( node.name, ir.attribute( type=self._get_type(node), implementation=self._get_implementation(node) ), ) ir_node = self._root.nodes[node.name] impl = ir_node.attributes["implementation"] if impl not in self._registry and impl not in ("input", "output"): attributes = self._extract_attributes(node) self._registry = self._registry | { impl: self._ir_factory.graph( ir.attribute(type=ir_node.type, **attributes), ) } for successor in self._get_successors(node): self._root = self._root.add_edge(node.name, successor.name)
[docs] def convert(self, model: Module) -> tuple[DataGraph, Registry]: self.model = model torch_graph = self._tracer.trace(model) registry: Registry = ir.Registry() for node in torch_graph.nodes: self._handle_fx_node(node) registry = self._registry self._registry = ir.Registry() root = self._root self._root = self._ir_factory.graph(ir.attribute(type="module")) return root, registry
def _get_successors(self, node: FxNode) -> list[FxNode]: return list(node.users) def _get_type(self, node: FxNode) -> str: error = LoweringError(""" Cannot handle function calls or getting attributes during translation. Please use supported Modules to design your model and change every function call to a module call.""") match node.op: case "call_module": return type( self.model.get_submodule(cast(str, node.target)) ).__name__.lower() case "call_function": raise error case "placeholder": return "input" case "output": return "output" case "get_attr": raise error case _: raise Exception(f"Unknown node type: {node.op}") def _get_implementation(self, node: FxNode) -> str: match self._get_type(node): case "input": return "input" case "output": return "output" case _: return cast(str, node.target) def _extract_attributes(self, node: FxNode) -> dict: if self._get_type(node) in ("input", "output"): return {} module = self.model.get_submodule(cast(str, node.target)) return self._extractors(module)
[docs] def __call__(self, model: Module) -> tuple[DataGraph, Registry]: return self.convert(model)
def get_default_converter() -> Torch2Ir:
    """Build a ``Torch2Ir`` with every default type handler registered.

    Each handler is registered without an explicit key, so it is stored
    under its own function name.
    """
    torch2ir = Torch2Ir()
    for type_handler in default_handlers:
        add_handler = torch2ir.register()
        add_handler(type_handler)
    return torch2ir