Source code for elasticai.creator.ir2verilog.ir2verilog

from collections.abc import Callable, Iterable, Iterator, Sequence
from dataclasses import dataclass
from importlib import resources as res
from typing import Protocol, TypeAlias

import elasticai.creator.function_utils as F
from elasticai.creator import graph as g
from elasticai.creator import ir
from elasticai.creator import plugin as pl
from elasticai.creator.ir import Edge as Edge


class Node(ir.Node):
    """IR node extended with an ``implementation`` string attribute.

    NOTE(review): the meaning of ``implementation`` (presumably the name of
    the design/handler this node refers to) is not visible in this module —
    confirm against the lowering handlers.
    """

    implementation: str
Code: TypeAlias = tuple[str, Sequence[str]]
@dataclass
class PluginSpec(pl.PluginSpec):
    """Plugin specification with the extra fields read by the Verilog loader.

    ``generated`` lists symbol names imported from the plugin package, and
    ``static_files`` lists file names shipped in the package's ``verilog``
    resource folder.
    """

    generated: tuple[str, ...]
    static_files: tuple[str, ...]
class Implementation(ir.Implementation[Node, ir.Edge]):
    """IR implementation typed as ``ir.Implementation[Node, ir.Edge]``.

    The constructor only accepts keyword arguments.
    """

    def __init__(self, *, graph: g.Graph[str], data: dict[str, ir.Attribute]):
        # Both arguments are forwarded unchanged to the base class.
        super().__init__(graph=graph, data=data)
class Ir2Verilog(ir.LoweringPass[Implementation, Code]):
    """Lowering pass from IR implementations to Verilog output files.

    Each ``(name, content)`` pair produced by the base lowering pass is
    emitted as ``"<name>.v"``. Once the base pass is exhausted, every
    registered static file is emitted under its registered name.
    """

    def __init__(self) -> None:
        super().__init__()
        # Output file name -> zero-argument callable that returns its text.
        self.__static_files: dict[str, Callable[[], str]] = {}
        self._loader = PluginLoader(self)

    def register_static(self, name: str, fn: Callable[[], str]) -> None:
        """Register ``fn`` as the content provider for static file ``name``.

        Registering an already known ``name`` replaces the earlier provider.
        """
        self.__static_files[name] = fn

    def __call__(self, args: Iterable[Implementation]) -> Iterator[Code]:
        """Lower ``args``, then append all registered static files.

        Static file providers are only invoked when iteration reaches them.
        """
        lowered = super().__call__(args)
        for module_name, module_text in lowered:
            yield f"{module_name}.v", module_text
        # NOTE(review): static providers return a plain ``str`` while lowered
        # modules carry ``Sequence[str]``; a ``str`` satisfies that type, but
        # consumers that join the sequence would treat it character by
        # character — confirm how callers consume ``content``.
        for file_name, provide in self.__static_files.items():
            yield file_name, provide()

    def load_from_package(self, package: str) -> None:
        """Load the plugins found in ``package`` into this pass."""
        self._loader.load_from_package(package)
class PluginSymbol(pl.PluginSymbol[Ir2Verilog], Protocol):
    """Structural type for plugin symbols that load into an ``Ir2Verilog``."""
class PluginLoader(pl.PluginLoader[Ir2Verilog]):
    """Plugin loader that feeds Verilog plugins into an ``Ir2Verilog`` pass.

    From every plugin spec it fetches the symbols listed in
    ``PluginSpec.generated`` and the files listed in
    ``PluginSpec.static_files`` (both only for the ``"verilog"`` runtime).
    """

    def __init__(self, lowering: Ir2Verilog):
        spec_builder: pl.SymbolFetcherBuilder[PluginSpec, Ir2Verilog] = (
            pl.SymbolFetcherBuilder(PluginSpec)
        )
        with_generated = spec_builder.add_fn(self.__get_generated)
        with_static_files = with_generated.add_fn(_StaticFile.make_symbols)
        fetch: pl.SymbolFetcher[Ir2Verilog] = with_static_files.build()
        super().__init__(fetch=fetch, plugin_receiver=lowering)

    def load_from_package(self, package: str) -> None:
        """Load plugins from ``package``.

        A name without a dot is treated as a sub-package of
        ``elasticai.creator_plugins``.
        """
        qualified = (
            package if "." in package else f"elasticai.creator_plugins.{package}"
        )
        super().load_from_package(qualified)

    @staticmethod
    def __get_generated(plugin: PluginSpec) -> Iterator[PluginSymbol]:
        """Import the ``generated`` symbols of a Verilog-targeted plugin."""
        if plugin.target_runtime != "verilog":
            return
        yield from pl.import_symbols(plugin.package, plugin.generated)
class _StaticFile(PluginSymbol):
    """Plugin symbol for a verbatim file shipped inside a plugin package.

    Loading the symbol registers the instance itself as the content provider
    for the static file ``name``. Calling the instance reads
    ``<package>/verilog/<name>`` from the package resources.
    """

    _subfolder = "verilog"

    def __init__(self, name: str, package: str):
        self._name = name
        self._package = package

    @property
    def name(self) -> str:
        return self._name

    def load_into(self, receiver: Ir2Verilog):
        receiver.register_static(self.name, self)

    @classmethod
    def make_symbols(cls, p: PluginSpec) -> Iterator[pl.PluginSymbol[Ir2Verilog]]:
        """Yield one ``_StaticFile`` per entry in ``p.static_files``.

        Only specs whose ``target_runtime`` equals ``"verilog"`` contribute.
        """
        if p.target_runtime == cls._subfolder:
            for name in p.static_files:
                yield cls(name=name, package=p.package)

    def __call__(self) -> str:
        """Return the file's text, decoded as UTF-8."""
        # Join one segment at a time: ``Traversable.joinpath`` is only
        # guaranteed to take a single segment on older Pythons, and a
        # namespace package's ``MultiplexedPath`` resolves a combined "a/b"
        # segment only against its first portion.
        file = res.files(self._package).joinpath(self._subfolder).joinpath(self.name)
        # Explicit encoding: without it, pathlib/zip-backed resources decode
        # with the platform locale encoding.
        return file.read_text(encoding="utf-8")


# Signature of a handler that lowers one implementation to one output file.
TypeHandlerFn: TypeAlias = Callable[[Implementation], Code]


def _type_handler(
    name: str, fn: TypeHandlerFn
) -> pl.PluginSymbolFn[Ir2Verilog, [Implementation], Code]:
    """Wrap ``fn`` as a plugin symbol that registers it under ``name``.

    When loaded into an ``Ir2Verilog`` pass, ``fn`` is registered through
    ``Ir2Verilog.register(name)``.
    """

    def load_into(lower: Ir2Verilog) -> None:
        lower.register(name)(fn)

    return pl.make_plugin_symbol(load_into, fn)


def _type_handler_for_iterable(
    name: str, fn: Callable[[Implementation], Iterable[Code]]
) -> pl.PluginSymbolFn[Ir2Verilog, [Implementation], Iterable[Code]]:
    """Like ``_type_handler``, but for handlers producing several files.

    When loaded, ``fn`` is registered through
    ``Ir2Verilog.register_iterable(name)``.
    """

    def load_into(lower: Ir2Verilog) -> None:
        lower.register_iterable(name)(fn)

    return pl.make_plugin_symbol(load_into, fn)


# Public decorator forms of the factories above.
# NOTE(review): the accepted call forms (bare ``@type_handler`` vs.
# ``@type_handler("name")``) are defined by ``F.FunctionDecorator`` — confirm
# there.
type_handler = F.FunctionDecorator(_type_handler)
type_handler_iterable = F.FunctionDecorator(_type_handler_for_iterable)