Source code for elasticai.creator.ir2verilog.ir2verilog

import warnings
from collections.abc import Callable, Iterable, Iterator, Sequence
from dataclasses import dataclass
from importlib import resources as res
from itertools import starmap
from typing import Any, Protocol, TypeAlias, override

import elasticai.creator.function_dispatch as FD
from elasticai.creator import plugin as _pl
from elasticai.creator.hdl_ir import (
    DataGraph,
    IrFactory,
    NonIterableTypeHandler,
    Registry,
    TypeHandler,
    _check_and_get_name_fn,
)
from elasticai.creator.ir import ir_v2 as ir
from elasticai.creator.plugin import PluginLoaderBase, StaticFileBase

# Module-level IR factory instance. It is not referenced elsewhere in this
# file. NOTE(review): presumably exported for plugins/handlers; confirm usage.
factory = IrFactory()

# One generated artifact: (file name, code). NOTE(review): the Sequence[str]
# part is presumably the code split into lines; confirm against handlers.
Code: TypeAlias = tuple[str, Sequence[str]]


@dataclass
class PluginSpec(_pl.PluginSpec):
    """Plugin specification extended with the Verilog-specific entries.

    The two fields below are appended after the fields inherited from
    ``_pl.PluginSpec``.
    """

    # Symbol names to import from the plugin package (handed to
    # ``_pl.import_symbols`` together with the package name).
    generated: tuple[str, ...]
    # File names of static Verilog sources shipped with the plugin package.
    static_files: tuple[str, ...]
class Ir2Verilog:
    """Translate a root ``DataGraph`` plus a registry of sub-graphs to Verilog.

    Code generation is delegated to type handlers chosen by each graph's
    ``type`` attribute. Handlers are added with :meth:`register` or
    :meth:`override`. Files whose content does not depend on the graphs are
    added with :meth:`register_static`.
    """

    def __init__(self) -> None:
        # static file name -> zero-argument callable returning its content
        self.__static_files: dict[str, Callable[[], str]] = {}

    def __call__(
        self, root: DataGraph, registry: Registry, default_root_name="root"
    ) -> Iterable[Code]:
        """Lazily generate ``(file_name, code)`` pairs.

        Args:
            root: Top-level graph. If it has no ``name`` attribute, a copy
                named ``default_root_name`` is used instead.
            registry: Sub-graphs keyed by name. Graphs without a ``name``
                attribute are given their registry key as name.
            default_root_name: Fallback name for ``root``.

        Yields:
            First, the pairs produced by the handler for ``root``, with names
            unchanged. Next, the pairs for every registry graph, with ``".v"``
            appended to each name. Last, one pair per registered static file,
            whose content comes from calling its registered callable.
        """
        named_registry = self._give_names_to_registry_graphs(registry)
        if "name" not in root.attributes:
            root = root.with_attributes(root.attributes | dict(name=default_root_name))

        # NOTE(review): names from the root handler are yielded as-is, while
        # registry graph names get a ".v" suffix below. Confirm this
        # asymmetry is intended.
        yield from self._handle_type(root, named_registry)

        for sub_graph in named_registry.values():
            yield from (
                (f"{file_stem}.v", code)
                for file_stem, code in self._handle_type(sub_graph, named_registry)
            )

        for file_name, produce_content in self.__static_files.items():
            yield file_name, produce_content()

    def _give_names_to_registry_graphs(self, registry: Registry) -> Registry:
        """Return a new registry in which every graph has a ``name`` attribute.

        Graphs that already carry a name are kept as they are. Any other
        graph is replaced by a copy named after its registry key.
        """

        def with_default_name(key: str, graph: DataGraph) -> tuple[str, DataGraph]:
            if "name" in graph.attributes:
                return key, graph
            return key, graph.with_attributes(graph.attributes | dict(name=key))

        return ir.Registry(starmap(with_default_name, registry.items()))

    @FD.dispatch_method(str)
    def _handle_type(
        self, fn: TypeHandler, graph: DataGraph, registry: Registry
    ) -> Iterable[Code]:
        # ``fn`` is the handler the dispatcher selected for ``graph``; the
        # selection key is computed by ``_get_key`` below.
        return fn(graph, registry)

    @_handle_type.key_from_args
    def _get_key(self, graph: DataGraph, registry: Registry) -> str:
        # Handlers are looked up by the graph's ``type`` string.
        return graph.type

    @staticmethod
    def _check_and_get_name(name: str | None, fn: Callable) -> str:
        # Thin wrapper around the shared hdl_ir name resolver.
        return _check_and_get_name_fn(name, fn)

    @FD.registrar_method
    def register_static(
        self, name: str | None, fn: Callable[[], str]
    ) -> Callable[[], str]:
        """Register ``fn`` as the content producer of a static output file.

        The file name is resolved by ``_check_and_get_name`` (presumably
        derived from ``fn`` when ``name`` is ``None``). Returns ``fn``.
        """
        file_name = self._check_and_get_name(name, fn)
        self.__static_files[file_name] = fn
        return fn

    @FD.registrar_method
    def register(self, name: str | None, fn: TypeHandler) -> TypeHandler:
        """Register ``fn`` as the handler for graphs of type ``name``.

        Returns ``fn``.
        """
        graph_type = self._check_and_get_name(name, fn)
        self._handle_type.register(graph_type, fn)
        return fn

    @FD.registrar_method
    def override(self, name: str | None, fn: TypeHandler) -> TypeHandler:
        """Install ``fn`` for graphs of type ``name`` via the dispatcher's override.

        NOTE(review): presumably this replaces an already registered handler;
        confirm against ``function_dispatch``. Returns ``fn``.
        """
        graph_type = self._check_and_get_name(name, fn)
        self._handle_type.override(graph_type, fn)
        return fn
class PluginSymbol(Protocol):
    """Structural type of the objects a plugin exposes to ``Ir2Verilog``."""

    def load_verilog(self, receiver: Ir2Verilog) -> None:
        """Register this symbol's handlers or static files on ``receiver``."""
        ...
class _StaticFileSymbol:
    """Adapter that turns a ``StaticFileBase`` into a loadable plugin symbol."""

    def __init__(self, file: StaticFileBase):
        self._file = file

    def load_verilog(self, receiver: Ir2Verilog) -> None:
        """Register the wrapped file's name and content getter on ``receiver``."""
        static_file = self._file
        receiver.register_static(static_file.name, static_file.get_content)
class PluginLoader(PluginLoaderBase):
    """PluginLoader for Ir2Verilog passes.

    Selects plugins whose ``target_runtime`` is ``"verilog"`` and loads their
    symbols into the ``Ir2Verilog`` instance given at construction.
    """

    def __init__(self, lowering: Ir2Verilog):
        # Set before the base initializer runs, as in the original ordering.
        self._receiver = lowering
        super().__init__(PluginSpec)

    @override
    def filter_plugin_dicts(
        self, plugins: Iterable[dict[str, Any]]
    ) -> Iterable[dict[str, Any]]:
        """Yield only the plugin dicts that target the Verilog runtime."""
        yield from (
            plugin for plugin in plugins if plugin["target_runtime"] == "verilog"
        )

    @override
    def load_symbol(self, symbol: PluginSymbol) -> None:
        """Load ``symbol`` into the receiver.

        Prefers ``load_verilog``. It falls back to the legacy ``load_into``
        with a ``DeprecationWarning``.

        Raises:
            TypeError: if the symbol provides neither method.
        """
        if hasattr(symbol, "load_verilog"):
            symbol.load_verilog(self._receiver)
            return
        if not hasattr(symbol, "load_into"):
            raise TypeError("Failed to load plugin symbol")
        warnings.warn(
            "Loading legacy plugin symbol, this behaviour will be removed in the future ensure your plugin symbol provides a load_verilog method",
            stacklevel=2,
            category=DeprecationWarning,
        )
        symbol.load_into(self._receiver)

    @override
    def get_symbols(self, specs: Iterable[PluginSpec]) -> Iterable[PluginSymbol]:
        """Yield the symbols of every spec.

        For each spec, the symbols named in ``generated`` are yielded first,
        followed by one static-file symbol per entry of ``static_files``.
        """
        for spec in specs:
            yield from _pl.import_symbols(spec.package, spec.generated)
            yield from (
                _StaticFileSymbol(
                    _pl.StaticFileBase(
                        name=file_name, package=spec.package, subfolder="verilog"
                    )
                )
                for file_name in spec.static_files
            )
class _StaticFile:
    """A static Verilog file stored inside a Python package.

    The file is looked up at ``<package>/verilog/<name>`` through
    ``importlib.resources``. Instances are callable and return the file's
    text, so an instance can itself be registered as a static-file producer.

    NOTE(review): this class is not referenced in the rest of this module;
    ``PluginLoader`` uses ``_StaticFileSymbol`` instead. Confirm whether it is
    still needed.
    """

    _subfolder = "verilog"

    def __init__(self, name: str, package: str):
        self._name = name
        self._package = package

    @property
    def name(self) -> str:
        return self._name

    def load_verilog(self, receiver: Ir2Verilog) -> None:
        """Register this file on ``receiver``, using ``self`` as content producer."""
        receiver.register_static(self.name, self)

    @classmethod
    def make_symbols(cls, p: PluginSpec) -> Iterator[PluginSymbol]:
        """Yield one instance per static file of ``p``.

        Yields nothing unless ``p`` targets the Verilog runtime.
        """
        if p.target_runtime == cls._subfolder:
            for name in p.static_files:
                yield cls(name=name, package=p.package)

    def __call__(self) -> str:
        """Return the file content decoded as UTF-8."""
        file = res.files(self._package).joinpath(f"{self._subfolder}/{self.name}")
        # Decode explicitly. Without an encoding, read_text() uses the locale's
        # preferred encoding, so non-ASCII content would be platform dependent.
        return file.read_text(encoding="utf-8")
@FD.registrar
def type_handler(
    name: str | None, fn: NonIterableTypeHandler
) -> NonIterableTypeHandler:
    """Mark ``fn`` as a Verilog type handler that returns a single ``Code`` item.

    ``fn`` is returned unchanged except for a new ``load_verilog`` attribute.
    Calling ``fn.load_verilog(lower)`` registers on ``lower`` a generator that
    yields ``fn``'s single result, under the resolved handler name.

    Args:
        name: Graph type handled by ``fn``. When ``None``, the name is
            resolved from ``fn`` by ``_check_and_get_name_fn``.
        fn: Handler producing exactly one ``Code`` item.
    """
    resolved_name = _check_and_get_name_fn(name, fn)

    def register_on(lower: Ir2Verilog) -> None:
        def wrapper(*args, **kwargs):
            # Ir2Verilog consumes handler results as iterables of Code items.
            yield fn(*args, **kwargs)

        lower.register(resolved_name, wrapper)

    setattr(fn, "load_verilog", register_on)
    return fn
@FD.registrar
def type_handler_iterable(name: str | None, fn: TypeHandler) -> TypeHandler:
    """Mark ``fn`` as a Verilog type handler that yields ``Code`` items.

    ``fn`` is returned unchanged except for a new ``load_verilog`` attribute.
    Calling ``fn.load_verilog(lower)`` registers ``fn`` itself on ``lower``
    under the resolved handler name.

    Args:
        name: Graph type handled by ``fn``. When ``None``, the name is
            resolved from ``fn`` by ``_check_and_get_name_fn``.
        fn: Handler returning an iterable of ``Code`` items.
    """
    # _check_and_get_name_fn returns a str (its wrapper
    # Ir2Verilog._check_and_get_name is annotated ``-> str``), so the former
    # ``if name is None`` fallback was unreachable. It is dropped here, which
    # matches type_handler.
    name = _check_and_get_name_fn(name, fn)

    def load_into(lower: Ir2Verilog) -> None:
        lower.register(name, fn)

    setattr(fn, "load_verilog", load_into)
    return fn