# Source code for elasticai.creator.ir.core.lowering
import warnings
from abc import abstractmethod
from collections.abc import Callable, Iterable, Iterator
from functools import wraps
from typing import Generic, ParamSpec, Protocol, TypeVar
from elasticai.creator.function_utils import KeyedFunctionDispatcher as _Registry
from elasticai.creator.function_utils import RegisterDescriptor
[docs]
class Lowerable(Protocol):
    """Structural interface for objects that can be fed to a lowering pass.

    An object satisfies this protocol when it exposes a read-only string
    attribute named ``type``. ``LoweringPass`` reads that attribute to pick
    the handler registered for the object.
    """

    @property
    @abstractmethod
    def type(self) -> str:
        """Return the type name under which handlers for this object are looked up."""
        ...
# Input type of a lowering pass: any object satisfying the Lowerable protocol.
Tin = TypeVar("Tin", bound="Lowerable")
# Output type produced by a lowering pass; unconstrained.
Tout = TypeVar("Tout")
[docs]
class LoweringPass(Generic[Tin, Tout]):
    """Lower a stream of ``Tin`` objects into a stream of ``Tout`` objects.

    Handlers are kept in a registry keyed by a type name; the key for an
    incoming object is read from its ``type`` attribute. Handlers that return
    a single result are wrapped with ``return_as_iterable`` so that every
    stored handler produces an iterable of results.

    Four registration flavours exist: single-result or iterable-result
    handlers, each either as a first definition (``ValueError`` if the name is
    already taken) or as an override (a warning if there is nothing to
    override).
    """

    # NOTE(review): RegisterDescriptor is defined elsewhere. Each descriptor
    # presumably forwards to the ``_<attribute name>_callback`` method of this
    # class -- confirm against its implementation before renaming anything.
    register: RegisterDescriptor[Tin, Tout] = RegisterDescriptor()
    register_override: RegisterDescriptor[Tin, Tout] = RegisterDescriptor()
    register_iterable: RegisterDescriptor[Tin, Iterable[Tout]] = RegisterDescriptor()
    register_iterable_override: RegisterDescriptor[Tin, Iterable[Tout]] = (
        RegisterDescriptor()
    )

    def __init__(self) -> None:
        """Create an empty handler registry keyed by the ``type`` of its input."""

        def key_lookup_fn(x: Tin) -> str:
            return x.type

        self._fns: _Registry[Tin, Iterable[Tout]] = _Registry(key_lookup_fn)

    def _register_callback(self, name: str, fn: Callable[[Tin], Tout]) -> None:
        """Register a single-result handler under ``name``.

        Raises:
            ValueError: if a handler for ``name`` is already registered.
        """
        self._check_for_redefinition(name)
        wrapper = return_as_iterable(fn)
        self._fns.register(name)(wrapper)

    def _register_override_callback(
        self, name: str, fn: Callable[[Tin], Tout]
    ) -> None:
        """Register a single-result handler under ``name`` as an override.

        Emits a warning when no handler for ``name`` exists yet; the handler
        is handed to the registry either way.
        """
        self._check_for_override(name)
        wrapper = return_as_iterable(fn)
        self._fns.register(name)(wrapper)

    def _register_iterable_callback(
        self, name: str, fn: Callable[[Tin], Iterable[Tout]]
    ) -> None:
        """Register a handler that already returns an iterable under ``name``.

        Raises:
            ValueError: if a handler for ``name`` is already registered.
        """
        self._check_for_redefinition(name)
        self._fns.register(name)(fn)

    def _register_iterable_override_callback(
        self, name: str, fn: Callable[[Tin], Iterable[Tout]]
    ) -> None:
        """Register an iterable-returning handler under ``name`` as an override.

        Emits a warning when no handler for ``name`` exists yet; the handler
        is handed to the registry either way.
        """
        self._check_for_override(name)
        self._fns.register(name)(fn)

    def __call__(self, args: Iterable[Tin]) -> Iterator[Tout]:
        """Lazily yield the results of the registry call for each item of ``args``.

        The results of each item are yielded in order before the next item is
        processed.
        """
        for arg in args:
            yield from self._fns(arg)

    def _check_for_redefinition(self, arg: str) -> None:
        """Raise ``ValueError`` if a handler for ``arg`` is already registered."""
        if arg in self._fns:
            raise ValueError(f"function for {arg} already defined in lowering pass")

    def _check_for_override(self, arg: str) -> None:
        """Warn if an override is requested for ``arg`` but nothing is registered."""
        if arg not in self._fns:
            # stacklevel=3 skips this method and the ``_register_*_override_callback``
            # method that called it, attributing the warning to that method's caller.
            warnings.warn(
                f"expected to override registered function for {arg}, "
                "but no function for that type was defined",
                stacklevel=3,
            )
# Captures the parameter list of a wrapped callable so that
# return_as_iterable can declare the same parameters on its wrapper.
P = ParamSpec("P")
[docs]
def return_as_iterable(fn: Callable[P, Tout]) -> Callable[P, Iterable[Tout]]:
    """Adapt ``fn`` so that its single result is delivered as a one-item iterable.

    Args:
        fn: callable producing one value per call.

    Returns:
        A callable with the same parameters as ``fn`` (metadata copied via
        ``functools.wraps``). Calling it returns a generator that yields
        exactly one item: the result of ``fn`` for the given arguments.
    """

    @wraps(fn)
    def wrapper(*args: P.args, **kwargs: P.kwargs) -> Iterable[Tout]:
        # The ``yield`` makes this a generator function: ``fn`` is not invoked
        # when ``wrapper`` is called, only once the returned generator is
        # iterated.
        yield fn(*args, **kwargs)

    return wrapper