Source code for elasticai.creator.torch2ir.torch2ir
from collections.abc import Callable, Iterable, Iterator
from typing import cast
from torch.fx import Node as FxNode
from torch.fx import Tracer
from torch.nn import Module
from elasticai.creator.function_utils import KeyedFunctionDispatcher
from elasticai.creator.graph import BaseGraph
from .core import Edge, Implementation, new_node
from .default_handlers import handlers as default_handlers
class _DefaultTracer(Tracer):
    """fx tracer that keeps modules from the ``elasticai`` package opaque.

    Modules whose class is defined in the ``elasticai`` package are reported
    as leaves, so tracing records one ``call_module`` node for them instead of
    descending into their ``forward``. Every other module is decided by the
    default :class:`torch.fx.Tracer` rule.
    """

    def is_leaf_module(self, m, module_qualified_name):
        cls = type(m)
        # ``__qualname__`` is only the (possibly nested) class name, such as
        # "Linear" or "Outer.Inner"; it never contains the package path. The
        # defining package therefore has to be read from ``__module__``.
        module_path = getattr(cls, "__module__", None) or ""
        top_level_package = module_path.partition(".")[0]
        # The qualname check is kept so previous behaviour is a strict subset.
        if top_level_package == "elasticai" or cls.__qualname__.startswith(
            "elasticai"
        ):
            return True
        return super().is_leaf_module(m, module_qualified_name)
[docs]
class LoweringError(Exception):
    """Signals that a traced torch graph cannot be translated into the IR."""

    def __init__(self, message: str):
        # Hand the single message straight to ``Exception`` so that
        # ``str(err)`` is the message and ``err.args == (message,)``.
        super().__init__(message)
[docs]
class Torch2Ir:
def __init__(self, tracer: Tracer = _DefaultTracer()):
super().__init__()
self._tracer = tracer
self._registry: dict[str, Implementation] = {}
self._root = Implementation(graph=BaseGraph())
self._root.type = "module"
self._root.name = ""
self._registry[""] = self._root
self._extractors: KeyedFunctionDispatcher[Module, dict] = (
KeyedFunctionDispatcher(self._get_module_key)
)
[docs]
def register(
self, module_type: str, handler: Callable[[Module], dict]
) -> Callable[[Module], dict]:
"""The handlers are used to extract the attributes of the module"""
self._extractors.register(module_type, handler)
return handler
[docs]
def register_handlers(
self, handlers: Iterable[Callable[[Module], dict]]
) -> "Torch2Ir":
for handler in handlers:
self.register(handler.__name__, handler)
return self
@staticmethod
def _get_module_key(module: Module) -> str:
return module.__class__.__name__.lower()
[docs]
def convert(self, model: Module) -> Iterator[Implementation]:
self.model = model
torch_graph = self._tracer.trace(model)
for node in torch_graph.nodes:
self._handle_fx_node(node)
registry = self._registry
self._registry = {}
yield from registry.values()
[docs]
def __call__(self, model: Module) -> Iterator[Implementation]:
yield from self.convert(model)
def _get_successors(self, node: FxNode) -> list[FxNode]:
return list(node.users)
def _get_type(self, node: FxNode) -> str:
error = LoweringError("""
Cannot handle function calls or getting
attributes during translation. Please use
supported Modules to design your model
and change every function call to a
module call.""")
match node.op:
case "call_module":
return type(
self.model.get_submodule(cast(str, node.target))
).__name__.lower()
case "call_function":
raise error
case "placeholder":
return "input"
case "output":
return "output"
case "get_attr":
raise error
case _:
raise Exception(f"Unknown node type: {node.op}")
def _get_implementation(self, node: FxNode) -> str:
match self._get_type(node):
case "input":
return "input"
case "output":
return "output"
case _:
return cast(str, node.target)
def _handle_fx_node(self, node: FxNode) -> None:
ir_node = new_node(
name=node.name,
type=self._get_type(node),
implementation=self._get_implementation(node),
attributes={},
)
self._root.add_node(ir_node)
impl = ir_node.implementation
if impl not in self._registry and impl not in ("input", "output"):
self._registry[impl] = Implementation(
graph=BaseGraph(),
data=dict(
name=impl, type=ir_node.type, **self._extract_attributes(node)
),
)
for successor in self._get_successors(node):
self._root.add_edge(Edge(src=node.name, dst=successor.name, data={}))
def _extract_attributes(self, node: FxNode) -> dict:
if self._get_type(node) in ("input", "output"):
return {}
module = self.model.get_submodule(cast(str, node.target))
return self._extractors(module)
[docs]
def get_default_converter() -> Torch2Ir:
    """Build a fresh ``Torch2Ir`` with the package's default handlers.

    ``register_handlers`` returns the converter itself, so the configured
    instance can be returned directly from the chained call.
    """
    return Torch2Ir().register_handlers(default_handlers)