Source code for elasticai.creator.ir2verilog.ir2verilog

import warnings
from collections.abc import Callable, Iterable, Sequence
from dataclasses import dataclass
from itertools import starmap
from typing import Any, Protocol, TypeAlias, override

import elasticai.creator.function_dispatch as FD
from elasticai.creator import ir
from elasticai.creator import plugin as _pl
from elasticai.creator.hdl_ir import (
    DataGraph,
    IrFactory,
    NonIterableTypeHandler,
    Registry,
    TypeHandler,
    _check_and_get_name_fn,
)
from elasticai.creator.plugin import PluginLoaderBase, StaticFileBase

factory = IrFactory()

Code: TypeAlias = tuple[str, Sequence[str]]


[docs] @dataclass class PluginSpec(_pl.PluginSpec): generated: tuple[str, ...] static_files: tuple[str, ...]
[docs] class Ir2Verilog: def __init__(self) -> None: self.__static_files: dict[str, Callable[[], str]] = {}
[docs] def __call__( self, root: DataGraph, registry: Registry, default_root_name="root" ) -> Iterable[Code]: registry = self._give_names_to_registry_graphs(registry) if "name" not in root.attributes: root = root.with_attributes(root.attributes | dict(name=default_root_name)) for name, code in self._handle_type(root, registry): yield f"{name}.v", code for g in registry.values(): for name, code in self._handle_type(g, registry): yield f"{name}.v", code for name, fn in self.__static_files.items(): yield name, (fn(),)
def _give_names_to_registry_graphs(self, registry: Registry) -> Registry: def give_name(name: str, g: DataGraph) -> tuple[str, DataGraph]: if "name" not in g.attributes: return name, g.with_attributes(g.attributes | dict(name=name)) return name, g return ir.Registry(starmap(give_name, registry.items())) @FD.dispatch_method(str) def _handle_type( self, fn: TypeHandler, graph: DataGraph, registry: Registry ) -> Iterable[Code]: return fn(graph, registry) @_handle_type.key_from_args def _get_key(self, graph: DataGraph, registry: Registry) -> str: return graph.type @staticmethod def _check_and_get_name(name: str | None, fn: Callable) -> str: return _check_and_get_name_fn(name, fn)
[docs] @FD.registrar_method def register_static( self, name: str | None, fn: Callable[[], str] ) -> Callable[[], str]: self.__static_files[self._check_and_get_name(name, fn)] = fn return fn
[docs] @FD.registrar_method def register(self, name: str | None, fn: TypeHandler) -> TypeHandler: name = self._check_and_get_name(name, fn) self._handle_type.register(name, fn) return fn
[docs] @FD.registrar_method def override(self, name: str | None, fn: TypeHandler) -> TypeHandler: name = self._check_and_get_name(name, fn) self._handle_type.override(name, fn) return fn
[docs] class PluginSymbol(Protocol):
[docs] def load_verilog(self, receiver: Ir2Verilog) -> None: ...
class _StaticFileSymbol: def __init__(self, file: StaticFileBase): self._file = file def load_verilog(self, receiver: Ir2Verilog) -> None: receiver.register_static(self._file.name, self._file.get_content)
[docs] class PluginLoader(PluginLoaderBase): """PluginLoader for Ir2Verilog passes.""" def __init__(self, lowering: Ir2Verilog): self._receiver = lowering super().__init__(PluginSpec)
[docs] @override def filter_plugin_dicts( self, plugins: Iterable[dict[str, Any]] ) -> Iterable[dict[str, Any]]: for p in plugins: if p["target_runtime"] == "verilog": yield p
[docs] @override def load_symbol(self, symbol: PluginSymbol) -> None: if hasattr(symbol, "load_verilog"): symbol.load_verilog(self._receiver) elif hasattr(symbol, "load_into"): warnings.warn( "Loading legacy plugin symbol, this behaviour will be removed in the future ensure your plugin symbol provides a load_verilog method", stacklevel=2, category=DeprecationWarning, ) symbol.load_into(self._receiver) else: raise TypeError("Failed to load plugin symbol")
[docs] @override def get_symbols(self, specs: Iterable[PluginSpec]) -> Iterable[PluginSymbol]: for spec in specs: yield from _pl.import_symbols(spec.package, spec.generated) for static_name in spec.static_files: yield _StaticFileSymbol( _pl.StaticFileBase( name=static_name, package=spec.package, subfolder="verilog" ) )
[docs] @FD.registrar def type_handler( name: str | None, fn: NonIterableTypeHandler ) -> NonIterableTypeHandler: name = _check_and_get_name_fn(name, fn) def load_into(lower: Ir2Verilog) -> None: def wrapper(*args, **kwargs): yield fn(*args, **kwargs) lower.register(name, wrapper) setattr(fn, "load_verilog", load_into) return fn
[docs] @FD.registrar def type_handler_iterable(name: str | None, fn: TypeHandler) -> TypeHandler: name = _check_and_get_name_fn(name, fn) if name is None: name = fn.__name__ def load_into(lower: Ir2Verilog) -> None: lower.register(name, fn) setattr(fn, "load_verilog", load_into) return fn