import importlib.resources as res
import operator
import re
import warnings
from abc import ABC, abstractmethod
from collections.abc import Callable, Iterable, Sequence
from dataclasses import dataclass
from functools import reduce
from typing import Any, Iterator, TypeAlias, TypeGuard, TypeVar, overload
import elasticai.creator.function_utils as F
import elasticai.creator.plugin as _pl
from elasticai.creator.function_utils import KeyedFunctionDispatcher
from elasticai.creator.graph import BaseGraph
from elasticai.creator.ir import (
Attribute,
LoweringPass,
RequiredField,
)
from elasticai.creator.ir import Edge as _Edge
from elasticai.creator.ir import (
Implementation as _Implementation,
)
from elasticai.creator.ir import Node as _Node
from elasticai.creator.plugin import PluginLoader as _Loader
from elasticai.creator.plugin import PluginSpec as _PluginSpec
from elasticai.creator.plugin import PluginSymbol as _PluginSymbol
@dataclass
class PluginSpec(_PluginSpec):
    """Plugin specification extended by the fields used for vhdl plugins.

    Attributes:
        generated: Names of the symbols that `PluginLoader` imports from the
            plugin package (only for plugins whose `target_runtime` is
            ``"vhdl"``).
        static_files: File names that are read from the ``vhdl/`` folder of
            the plugin package and registered as static files.
    """

    generated: tuple[str, ...]
    static_files: tuple[str, ...]
ShapeTuple: TypeAlias = tuple[int] | tuple[int, int] | tuple[int, int, int]
[docs]
def is_shape_tuple(values) -> TypeGuard[ShapeTuple]:
max_num_values = 3
return len(values) <= max_num_values
[docs]
class Shape:
@overload
def __init__(self, width: int, /) -> None: ...
@overload
def __init__(self, depth: int, width: int, /) -> None: ...
@overload
def __init__(self, depth: int, width: int, height: int, /) -> None: ...
def __init__(self, *values: int) -> None:
"""values are interpreted as one of the following:
- width
- depth, width
- depth, width, height
Usually width is kernel_size, depth is channels
"""
if is_shape_tuple(values):
self._data = values
else:
raise TypeError(f"taking at most three ints, given {values}")
[docs]
@classmethod
def from_tuple(cls, values: ShapeTuple | list[int]) -> "Shape":
return cls(*values) # type ignore
[docs]
def to_tuple(self) -> ShapeTuple:
return self._data
[docs]
def to_list(self) -> list[int]:
return list(self.to_tuple())
[docs]
def __getitem__(self, item):
return self._data[item]
[docs]
def size(self) -> int:
return reduce(operator.mul, self._data, 1)
[docs]
def ndim(self) -> int:
return len(self._data)
@property
def depth(self) -> int:
return self._data[0]
[docs]
def __eq__(self, other):
if isinstance(other, tuple):
return self._data == other
if isinstance(other, Shape):
return self._data == other._data
return False
@property
def width(self) -> int:
if len(self._data) > 1:
return self._data[1]
else:
return 1
@property
def height(self) -> int:
if len(self._data) > 2:
return self._data[2]
return 1
[docs]
def __repr__(self) -> str:
match self._data:
case (width,):
return f"Shape({width=})"
case (depth, width):
return f"Shape({depth=}, {width=})"
case (depth, width, height):
return f"Shape({depth=}, {width=}, {height=})"
case _:
return f"Shape({self._data})"
class ShapeField(RequiredField[list[int], Shape]):
    """Required node field that stores a `Shape` as a plain list of ints.

    `set_convert` turns the `Shape` into a list via `Shape.to_list`;
    `get_convert` rebuilds a `Shape` from the stored list via
    `Shape.from_tuple`.
    """

    def __init__(self):
        def shape_to_list(shape):
            return shape.to_list()

        super().__init__(set_convert=shape_to_list, get_convert=Shape.from_tuple)
class VhdlNode(_Node):
    """Extends `elasticai.creator.ir.Node` to a vhdl specific node.

    `VhdlNode` contains all knowledge that we need to create
    and use an instance of a vhdl entity. However, this becomes
    a little bit complicated because vhdl differentiates between
    the *entity* and the *architecture* of a component.
    The entity is similar to an *interface* while the architecture
    is similar to the *implementation*. However, to instantiate
    components, we need to know both names.

    Attributes:
        implementation: The name of the implementation will be used to
            derive the architecture name. E.g., if the implementation is
            `"adder"`, we will instantiate the entity `work.adder(rtl)`.
            CAUTION: This behaviour is subject to change. Future versions
            might require the full entity name.
        input_shape: Shape of the node's input; stored as a list of ints
            and read back as a `Shape` (see `ShapeField`).
        output_shape: Shape of the node's output; stored as a list of ints
            and read back as a `Shape` (see `ShapeField`).
    """

    implementation: str
    input_shape: RequiredField[list[int], Shape] = ShapeField()
    output_shape: RequiredField[list[int], Shape] = ShapeField()
def vhdl_node(
    name: str,
    type: str,
    implementation: str,
    input_shape: Shape | ShapeTuple,
    output_shape: Shape | ShapeTuple,
    attributes: dict | None = None,
) -> VhdlNode:
    """Convenience method for creating new vhdl nodes.

    Shapes may be passed as `Shape` objects or as plain tuples. `Shape`
    objects are stored in their tuple form. The entries of `attributes` are
    merged into the node data last, so they take precedence over the named
    arguments.
    """
    extra = {} if attributes is None else attributes

    def plain(shape: Shape | ShapeTuple) -> ShapeTuple:
        return shape.to_tuple() if isinstance(shape, Shape) else shape

    node_data = {
        "type": type,
        "implementation": implementation,
        "input_shape": plain(input_shape),
        "output_shape": plain(output_shape),
    }
    return VhdlNode(name=name, data=node_data | extra)
class Edge(_Edge):
    """Vhdl IR edge that additionally carries `src_dst_indices`.

    The indices take part in hashing. Two edges between the same pair of
    nodes but with different index pairs therefore hash differently.
    """

    src_dst_indices: tuple[tuple[int, int] | tuple[str, str], ...]

    def __hash__(self) -> int:
        identity = (self.src, self.dst, self.src_dst_indices)
        return hash(identity)
def edge(
    src: str, dst: str, src_dst_indices: Iterable[tuple[int, int]] | tuple[str, str]
) -> Edge:
    """Create an `Edge` from `src` to `dst`.

    `src_dst_indices` is materialized into a tuple and stored under the
    ``"src_dst_indices"`` key of the edge data.

    NOTE(review): for the `tuple[str, str]` variant the stored value is that
    pair itself (a tuple of two strings), not a tuple of pairs as
    `Edge.src_dst_indices` is annotated. Confirm the intended input type.
    """
    indices = tuple(src_dst_indices)
    return Edge(src=src, dst=dst, data={"src_dst_indices": indices})
N = TypeVar("N", bound=VhdlNode)
E = TypeVar("E", bound=Edge)


class Implementation(_Implementation[N, E]):
    """Vhdl specific implementation.

    Nodes are created with `node_fn` (default `VhdlNode`) and edges with
    `edge_fn` (default `Edge`). When given, `name` and `type` are written
    into `data`. `attributes` is a deprecated alias for `data`.

    Raises:
        TypeError: if both `data` and `attributes` are passed.
    """

    name: str
    type: str

    # Every overload accepts `graph`, because the implementation below
    # supports it regardless of which factories are passed.
    @overload
    def __init__(
        self: "Implementation[VhdlNode, Edge]",
        *,
        name: str | None = None,
        type: str | None = None,
        data: dict[str, Attribute] | None = None,
        attributes: dict[str, Attribute] | None = None,
        graph: BaseGraph | None = None,
    ) -> None: ...

    @overload
    def __init__(
        self: "Implementation[N, Edge]",
        *,
        node_fn: Callable[[dict], N],
        name: str | None = None,
        type: str | None = None,
        data: dict[str, Attribute] | None = None,
        attributes: dict[str, Attribute] | None = None,
        graph: BaseGraph | None = None,
    ) -> None: ...

    @overload
    def __init__(
        self: "Implementation[VhdlNode, E]",
        *,
        edge_fn: Callable[[dict], E],
        name: str | None = None,
        type: str | None = None,
        data: dict[str, Attribute] | None = None,
        attributes: dict[str, Attribute] | None = None,
        graph: BaseGraph | None = None,
    ) -> None: ...

    @overload
    def __init__(
        self: "Implementation[N, E]",
        *,
        node_fn: Callable[[dict], N],
        edge_fn: Callable[[dict], E],
        name: str | None = None,
        type: str | None = None,
        data: dict[str, Attribute] | None = None,
        attributes: dict[str, Attribute] | None = None,
        graph: BaseGraph | None = None,
    ) -> None: ...

    def __init__(
        self,
        *,
        name: str | None = None,
        type: str | None = None,
        node_fn=VhdlNode,
        edge_fn=Edge,
        data: dict[str, Any] | None = None,
        attributes: dict[str, Any] | None = None,
        graph: BaseGraph | None = None,
    ) -> None:
        if attributes is not None and data is not None:
            raise TypeError("pass either attributes or data argument")
        if attributes is not None:
            warnings.warn(
                "the argument `attributes` is deprecated, use `data` instead.",
                DeprecationWarning,
                stacklevel=2,
            )
            data = attributes
        if graph is None:
            # Create a new graph per instance instead of sharing a default.
            graph = BaseGraph()
        super().__init__(
            node_fn=node_fn,
            edge_fn=edge_fn,
            data=data,
            graph=graph,
        )
        # NOTE(review): if the base class keeps the passed `data` dict without
        # copying it, the assignments below also modify the caller's dict.
        if name is not None:
            self.data["name"] = name
        if type is not None:
            self.data["type"] = type
# A generated file: (file name, lines of content).
Code: TypeAlias = tuple[str, Sequence[str]]


class Ir2Vhdl(LoweringPass[Implementation, Code]):
    """Lowering pass translating `Implementation`s into vhdl files.

    Every `(name, lines)` pair produced by the registered handlers is
    emitted as ``"<name>.vhd"``. After that, each static file registered
    via `register_static` is emitted under its registered name. Its content
    is a single-element list holding the string returned by its factory.
    """

    def __init__(self) -> None:
        super().__init__()
        self.__static_files: dict[str, Callable[[], str]] = {}

    def register_static(self, name: str, fn: Callable[[], str]) -> None:
        """Register `fn` as content factory of the static file `name`.

        Registering the same name again replaces the previous factory.
        """
        self.__static_files[name] = fn

    def __call__(self, args: Iterable[Implementation]) -> Iterator[Code]:
        for stem, lines in super().__call__(args):
            yield f"{stem}.vhd", lines
        # Factories run lazily, only when the corresponding item is consumed.
        yield from (
            (file_name, [produce()])
            for file_name, produce in self.__static_files.items()
        )
class Signal(ABC):
    """Abstract vhdl signal that can be declared, named and parsed from code.

    Concrete signal types are registered via `register_type`. `from_code`
    asks the registered types, in registration order, whether they accept
    a piece of code. It falls back to `NullDefinedLogicSignal` if none does.
    """

    # Registry of signal types consulted by `from_code`. Kept as a set for
    # backwards compatibility with code that accesses it directly.
    types: set[type["Signal"]] = set()
    # Registration order of `types`. Iterating the set itself would make the
    # lookup order depend on object ids, so the result of `from_code` could
    # differ between runs whenever several types accept the same code.
    _type_order: list[type["Signal"]] = []

    @abstractmethod
    def define(self) -> Iterator[str]:
        """Yield the vhdl lines declaring this signal."""

    @property
    @abstractmethod
    def name(self) -> str: ...

    @classmethod
    def from_code(cls, code: str) -> "Signal":
        """Parse `code` with the first registered type that accepts it."""
        for signal_type in cls._registered_types():
            if signal_type.can_create_from_code(code):
                return signal_type.from_code(code)
        return NullDefinedLogicSignal.from_code(code)

    @classmethod
    def _registered_types(cls) -> list[type["Signal"]]:
        """Return the types of `types`; those from `register_type` first, in order."""
        ordered = [t for t in cls._type_order if t in cls.types]
        # Types added to `types` directly (bypassing `register_type`) go last.
        ordered.extend(t for t in cls.types if t not in cls._type_order)
        return ordered

    @classmethod
    @abstractmethod
    def can_create_from_code(cls, code: str) -> bool: ...

    @classmethod
    def register_type(cls, t: type["Signal"]) -> None:
        """Make `t` available to `from_code`; registering twice has no effect."""
        if t not in cls._type_order:
            cls._type_order.append(t)
        cls.types.add(t)

    @abstractmethod
    def make_instance_specific(self, instance: str) -> "Signal":
        """Return the signal to use within the instance named `instance`."""
class LogicSignal(Signal):
    """Single-bit ``std_logic`` signal, initialized to ``'0'``."""

    def __init__(self, name: str):
        self._name = name

    def define(self) -> Iterator[str]:
        yield f"signal {self._name} : std_logic := '0';"

    @property
    def name(self) -> str:
        return self._name

    @classmethod
    def can_create_from_code(cls, code: str) -> bool:
        return cls._search(code) is not None

    @classmethod
    def _search(cls, code: str) -> re.Match[str] | None:
        # `std_logic` must be followed by whitespace or ";". This rejects
        # `std_logic_vector` declarations.
        match = re.search(
            r"signal ([a-zA-Z_][a-zA-Z0-9_]*)\s*:\s*std_logic(?:\s+|;)", code
        )
        return match

    @classmethod
    def from_code(cls, code: str) -> "Signal":
        """Parse a ``std_logic`` declaration.

        Raises:
            ValueError: if `code` contains no such declaration.
        """
        match = cls._search(code)
        if match is None:
            raise ValueError(f"Cannot create signal from code: {code}")
        (name,) = match.groups()
        return cls(name)

    def make_instance_specific(self, instance: str) -> Signal:
        return self.__class__(f"{self.name}_{instance}")

    def __eq__(self, other: object) -> bool:
        if other is self:
            return True
        if isinstance(other, LogicSignal):
            return self._name == other._name
        return NotImplemented

    def __hash__(self) -> int:
        # Defining __eq__ alone sets __hash__ to None; hash on the same field
        # that __eq__ compares, so equal signals hash alike.
        return hash(self._name)
class LogicVectorSignal(Signal):
    """``std_logic_vector`` signal of a fixed width, initialized to zeros."""

    def __init__(self, name: str, width: int):
        self._name = name
        self._width = width

    def define(self) -> Iterator[str]:
        yield f"signal {self._name} : std_logic_vector({self._width} - 1 downto 0) := (others => '0');"

    @property
    def name(self) -> str:
        return self._name

    @property
    def width(self) -> int:
        return self._width

    @classmethod
    def can_create_from_code(cls, code: str) -> bool:
        return cls._search(code) is not None

    @classmethod
    def _search(cls, code: str) -> re.Match[str] | None:
        # Matches code starting with e.g.
        #   "signal x : std_logic_vector(7 downto 0)" or
        #   "signal x : std_logic_vector(8 - 1 downto 0)" (the form `define` emits).
        # Whitespace around ":" and "-" is optional, in line with LogicSignal.
        return re.match(
            r"signal ([a-zA-Z_][a-zA-Z0-9_]*)\s*:\s*std_logic_vector\("
            r"\s*(\d+(?:\s*-\s*\d+)?)\s+downto\s+0\s*\)",
            code,
        )

    @classmethod
    def from_code(cls, code: str) -> "Signal":
        """Parse a ``std_logic_vector(<upper> downto 0)`` declaration.

        Raises:
            ValueError: if `code` does not start with such a declaration.
        """
        match = cls._search(code)
        if match is None:
            raise ValueError(f"Cannot create signal from code: {code}")
        name, upper = match.groups()
        # `upper` is either "<a>" or "<a> - <b>"; int() ignores surrounding
        # whitespace. The range "<upper> downto 0" spans upper + 1 bits.
        minuend, _, subtrahend = upper.partition("-")
        high = int(minuend) - int(subtrahend or 0)
        return cls(name, high + 1)

    def make_instance_specific(self, instance: str) -> Signal:
        return self.__class__(f"{self.name}_{instance}", self.width)

    def __eq__(self, other: object) -> bool:
        if other is self:
            return True
        if isinstance(other, LogicVectorSignal):
            return self._name == other._name and self._width == other._width
        return NotImplemented

    def __hash__(self) -> int:
        # Keep hashing consistent with __eq__ (defining __eq__ alone would
        # make instances unhashable).
        return hash((self._name, self._width))
class NullDefinedLogicSignal(Signal):
    """Placeholder used by `Signal.from_code` when no registered type matches.

    It declares nothing. It is named ``<unknown>`` when created from code,
    and it is returned unchanged by `make_instance_specific`.
    """

    def __init__(self, name):
        self._name = name

    def define(self) -> Iterator[str]:
        # Nothing to declare: an empty generator.
        yield from ()

    @property
    def name(self) -> str:
        return self._name

    @classmethod
    def can_create_from_code(cls, code: str) -> bool:
        # Never claims any code itself.
        return False

    @classmethod
    def from_code(cls, code: str) -> "Signal":
        # `code` is not inspected.
        return cls("<unknown>")

    def make_instance_specific(self, instance: str) -> Signal:
        return self
# Make the built-in signal types available to `Signal.from_code`. The
# registrations are written out individually, so that no loop variable
# leaks into the module namespace (and into `from ... import *`).
Signal.register_type(LogicSignal)
Signal.register_type(LogicVectorSignal)
Signal.register_type(NullDefinedLogicSignal)
class PortMap:
    """Maps port names to the `Signal`s connected to them."""

    def __init__(self, map: "dict[str, Signal]"):
        self._signals: dict[str, Signal] = map

    def as_dict(self) -> dict[str, str]:
        """Serialize every signal to the first line of its definition.

        Signals whose `define()` yields nothing (e.g. `NullDefinedLogicSignal`)
        are serialized as the empty string instead of raising an IndexError.
        The built-in signal types do not accept ``""``, so `from_dict` maps
        such entries back to a `NullDefinedLogicSignal`.
        """
        return {
            port: next(iter(signal.define()), "")
            for port, signal in self._signals.items()
        }

    @classmethod
    def from_dict(cls, data: dict[str, str]) -> "PortMap":
        """Parse every value with `Signal.from_code`."""
        return cls({k: Signal.from_code(v) for k, v in data.items()})

    def __eq__(self, other: object) -> bool:
        if other is self:
            return True
        if isinstance(other, PortMap):
            return self._signals == other._signals
        return NotImplemented
class Instance:
    """Represents an entity that we can/want to instantiate.

    This aggregates all the knowledge that is necessary to instantiate
    and use the corresponding entity programmatically when generating
    vhdl code.

    Args:
        node: Node providing name, implementation and shapes.
        generic_map: Generic names to values. The names are stored in lower
            case and emitted in upper case.
        port_map: Port names to signals. Each signal is replaced by
            `signal.make_instance_specific(node.name)`.
    """

    def __init__(
        self,
        node: "VhdlNode",
        generic_map: dict[str, str],
        port_map: "dict[str, Signal]",
    ):
        self._node = node
        self._generics: dict[str, str] = {k.lower(): v for k, v in generic_map.items()}
        self.port_map = {
            k: v.make_instance_specific(self._node.name) for k, v in port_map.items()
        }

    @property
    def input_shape(self) -> "Shape":
        return self._node.input_shape

    @property
    def output_shape(self) -> "Shape":
        return self._node.output_shape

    @property
    def name(self) -> str:
        return self._node.name

    @property
    def implementation(self) -> str:
        return self._node.implementation

    def define_signals(self) -> Iterator[str]:
        """Yield the declarations of all (instance specific) port signals."""
        for s in self.port_map.values():
            yield from s.define()

    def instantiate(self) -> Iterator[str]:
        """Yield the lines of the vhdl statement instantiating this entity.

        The entity is referenced as ``work.<implementation>(rtl)``. The
        generic map is omitted when there are no generics. When there are no
        ports, the port map is omitted and the statement is terminated
        directly, because vhdl does not allow an empty ``port map ( )``.
        """
        yield f"{self.name}: entity work.{self.implementation}(rtl) "
        generics = tuple(self._generics.items())
        ports = tuple(self.port_map.items())
        if generics:
            yield "generic map ("
            for key, value in generics[:-1]:
                yield f" {key.upper()} => {value},"
            last_key, last_value = generics[-1]
            yield f" {last_key.upper()} => {last_value}"
            # Without a port map, the generic map ends the statement.
            yield " )" if ports else " );"
        if ports:
            yield " port map ("
            for k, v in ports[:-1]:
                yield f" {k} => {v.name},"
            last_port, last_signal = ports[-1]
            yield f" {last_port} => {last_signal.name}"
            yield " );"
        elif not generics:
            yield ";"
class InstanceFactory(KeyedFunctionDispatcher[VhdlNode, Instance]):
    """Automatically creates Instances from VhdlNodes based on their `type` field.

    The dispatch key of a node is its `type` attribute.
    """

    def __init__(self):
        super().__init__(dispatch_key_fn=lambda node: node.type)
PluginSymbol: TypeAlias = _PluginSymbol[Ir2Vhdl]


class PluginLoader(_Loader[Ir2Vhdl]):
    """Plugin loader for ir2vhdl translation.

    Symbols are fetched only from plugin specs whose `target_runtime` is
    ``"vhdl"``:

    - the names listed in `generated` are imported from the plugin package;
    - every entry of `static_files` becomes a `_StaticFile` symbol.

    `lowering` is passed to the base loader as the plugin receiver.
    """

    def __init__(self, lowering: Ir2Vhdl):
        builder: _pl.SymbolFetcherBuilder[PluginSpec, Ir2Vhdl] = (
            _pl.SymbolFetcherBuilder(PluginSpec)
        )
        with_generated = builder.add_fn(self.__get_generated)
        with_static = with_generated.add_fn(_StaticFile.make_symbols)
        fetcher: _pl.SymbolFetcher[Ir2Vhdl] = with_static.build()
        super().__init__(fetch=fetcher, plugin_receiver=lowering)

    def load_from_package(self, package: str) -> None:
        """Load plugins from `package`.

        A name without any dot is taken relative to
        ``elasticai.creator_plugins``.
        """
        qualified = package if "." in package else f"elasticai.creator_plugins.{package}"
        super().load_from_package(qualified)

    @staticmethod
    def __get_generated(plugin: PluginSpec) -> Iterator[PluginSymbol]:
        if plugin.target_runtime != "vhdl":
            return
        yield from _pl.import_symbols(plugin.package, plugin.generated)
class _StaticFile(_PluginSymbol[Ir2Vhdl]):
    """Plugin symbol for a vhdl file shipped as-is inside a plugin package.

    The file is located at ``vhdl/<name>`` within the package. It is read
    each time the symbol is called, not when the symbol is created.
    """

    def __init__(self, name: str, package: str):
        self._name = name
        self._package = package

    @property
    def name(self) -> str:
        return self._name

    def load_into(self, receiver: Ir2Vhdl):
        # The symbol itself serves as the content factory (see `__call__`).
        receiver.register_static(self.name, self)

    @classmethod
    def make_symbols(cls, p: PluginSpec) -> Iterator[PluginSymbol]:
        """Yield one symbol per static file of a ``"vhdl"`` plugin."""
        if p.target_runtime == "vhdl":
            for name in p.static_files:
                yield cls(name=name, package=p.package)

    def __call__(self) -> str:
        # Join one segment at a time. Before Python 3.11, `Traversable`
        # implementations are not required to split "/"-separated children.
        file = res.files(self._package).joinpath("vhdl").joinpath(self.name)
        # Decode explicitly. Without an encoding the locale default is used,
        # which is not UTF-8 on every platform.
        return file.read_text(encoding="utf-8")
_Tcontra = TypeVar("_Tcontra", contravariant=True)
# Handler that translates one implementation into a single `Code` item.
TypeHandlerFn: TypeAlias = Callable[[Implementation], Code]


def _type_handler(name: str, fn: TypeHandlerFn) -> PluginSymbol:
    """Wrap `fn` into a plugin symbol that registers it under `name`.

    Loading the symbol calls ``lower.register(name)(fn)`` on the receiving
    `Ir2Vhdl`.
    """

    def load_into(lower: Ir2Vhdl) -> None:
        register = lower.register(name)
        register(fn)

    return _pl.make_plugin_symbol(load_into, fn)
def _type_handler_for_iterable(
    name: str, fn: Callable[[Implementation], Iterable[Code]]
) -> PluginSymbol:
    """Wrap `fn`, which may yield several `Code` items, into a plugin symbol.

    Loading the symbol calls ``lower.register_iterable(name)(fn)`` on the
    receiving `Ir2Vhdl`.
    """

    def load_into(lower: Ir2Vhdl) -> None:
        register = lower.register_iterable(name)
        register(fn)

    return _pl.make_plugin_symbol(load_into, fn)


# Public decorator forms of `_type_handler` and `_type_handler_for_iterable`,
# built with `F.FunctionDecorator`.
type_handler = F.FunctionDecorator(_type_handler)
type_handler_iterable = F.FunctionDecorator(_type_handler_for_iterable)