Source code for elasticai.creator.hdl_ir

import numbers
import operator
from collections.abc import Callable, Iterable, Sequence
from functools import reduce
from typing import Any, Protocol, TypeGuard, cast, overload

import elasticai.creator.ir.ir_v2 as ir

# A shape as a plain tuple of one to three ints, read positionally as
# (width,), (depth, width) or (depth, width, height) -- see Shape.
type ShapeTuple = tuple[int] | tuple[int, int] | tuple[int, int, int]


[docs] def is_shape_tuple(values) -> TypeGuard[ShapeTuple]: max_num_values = 3 return len(values) <= max_num_values
[docs] class Shape: @overload def __init__(self, width: int, /) -> None: ... @overload def __init__(self, depth: int, width: int, /) -> None: ... @overload def __init__(self, depth: int, width: int, height: int, /) -> None: ... def __init__(self, *values: int) -> None: """values are interpreted as one of the following: - width - depth, width - depth, width, height Usually width is kernel_size, depth is channels """ if is_shape_tuple(values): self._data = values else: raise TypeError(f"taking at most three ints, given {values}")
[docs] @classmethod def from_tuple(cls, values: ShapeTuple | list[int]) -> "Shape": return cls(*values)
[docs] def to_tuple(self) -> ShapeTuple: return self._data
[docs] def to_list(self) -> list[int]: return list(self.to_tuple())
[docs] def __getitem__(self, item): return self._data[item]
[docs] def size(self) -> int: return reduce(operator.mul, self._data, 1)
[docs] def ndim(self) -> int: return len(self._data)
@property def depth(self) -> int: if len(self._data) > 1: return self._data[0] return 1
[docs] def __eq__(self, other): if isinstance(other, tuple): return self._data == other if isinstance(other, Shape): return self._data == other._data return False
@property def width(self) -> int: if len(self._data) > 1: return self._data[1] # ty: ignore else: return self._data[0] @property def height(self) -> int: if len(self._data) > 2: return self._data[2] # ty: ignore return 1
[docs] def __repr__(self) -> str: match self._data: case (width,): return f"Shape({width=})" case (depth, width): return f"Shape({depth=}, {width=})" case (depth, width, height): return f"Shape({depth=}, {width=}, {height=})" case _: return f"Shape({self._data})"
class Node(ir.Node, Protocol):
    """hdl-specific node: ``ir.Node`` extended with what an hdl component needs.

    Such a node carries all knowledge required to create and use an instance
    of an hdl component.

    NOTE(review): the previous docstring broke off mid-sentence ("vhdl
    differentiates between ..."), presumably about entity vs. architecture
    naming -- confirm and complete.

    Attributes:
        implementation: Name of the implementation, e.g. the entity name in
            vhdl or the module name in verilog. It is used to derive the
            architecture name: for ``"adder"`` the entity
            ``work.adder(rtl)`` is instantiated. CAUTION: this behaviour is
            subject to change; future versions might require the full
            entity name.
        input_shape: Shape of the component's input.
        output_shape: Shape of the component's output.
    """

    @property
    def implementation(self) -> str:
        """Name of the hdl implementation (entity/module)."""
        ...

    @property
    def input_shape(self) -> Shape:
        """Shape of the component's input."""
        ...

    @property
    def output_shape(self) -> Shape:
        """Shape of the component's output."""
        ...
def _type_check[T](item: Any, t: type[T]) -> T: if isinstance(item, t): return item else: raise TypeError(f"Expected type {t} but found {type(item)}")
class NodeImpl(ir.NodeImpl):
    """Concrete hdl ``Node`` whose properties are read from ``self.attributes``.

    Missing attributes fall back to defaults: ``"<none>"`` for
    ``implementation`` and an empty tuple (i.e. an empty ``Shape``) for
    ``input_shape`` and ``output_shape``.
    """

    @staticmethod
    def _shape_check(item: Any, attribute: str = "input_shape") -> tuple[int, ...]:
        """Return ``item`` unchanged if ``is_shape_tuple`` accepts it.

        Args:
            item: The raw attribute value.
            attribute: Name of the attribute being checked. It is used only in
                the error message and defaults to ``"input_shape"``, so
                existing single-argument calls keep working.

        Raises:
            TypeError: if ``item`` is not a valid shape tuple.
        """
        if is_shape_tuple(item):
            return item
        # Name the offending attribute and show its value. Previously an
        # invalid output_shape was reported as "input_shape", and a
        # wrong-length tuple only as "<class 'tuple'>".
        raise TypeError(
            f"expected {attribute} to be of type tuple[int, ...] "
            f"but found {type(item)}: {item!r}"
        )

    @property
    def implementation(self) -> str:
        """Name of the hdl implementation; ``"<none>"`` if not set.

        Raises:
            TypeError: if the stored value is not a ``str``.
        """
        return _type_check(self.attributes.get("implementation", "<none>"), str)

    @property
    def input_shape(self) -> Shape:
        """Shape built from the ``input_shape`` attribute; empty if not set."""
        shape = self.attributes.get("input_shape", tuple())
        return Shape(*self._shape_check(shape, "input_shape"))

    @property
    def output_shape(self) -> Shape:
        """Shape built from the ``output_shape`` attribute; empty if not set."""
        shape = self.attributes.get("output_shape", tuple())
        return Shape(*self._shape_check(shape, "output_shape"))
class Edge(ir.Edge, Protocol):
    """hdl-specific edge: ``ir.Edge`` plus the index mapping it connects."""

    @property
    def src_dst_indices(self) -> tuple[tuple[int | str, int | str], ...]:
        """Tuple of ``(src, dst)`` index pairs, as suggested by name and type."""
        ...
class EdgeImpl(ir.EdgeImpl):
    """Concrete hdl ``Edge`` reading ``src_dst_indices`` from its attributes."""

    @property
    def src_dst_indices(self) -> tuple[tuple[int | str, int | str], ...]:
        """The ``src_dst_indices`` attribute; an empty tuple if not set.

        Only the outer type is checked (TypeError unless it is a tuple); the
        elements are returned as stored.

        NOTE(review): ``IrFactory.edge`` may store a flat ``(str, str)`` tuple
        here, which does not match the annotated tuple-of-pairs type.
        """
        return _type_check(self.attributes.get("src_dst_indices", tuple()), tuple)
class DataGraph(ir.DataGraph[Node, Edge], Protocol):
    """``ir.DataGraph`` over hdl ``Node``s and ``Edge``s, with a name and a type."""

    @property
    def name(self) -> str:
        """Name of the graph."""
        ...

    @property
    def type(self) -> str:
        """Type of the graph."""
        ...
class DataGraphImpl(ir.DataGraphImpl[Node, Edge]):
    """Concrete hdl ``DataGraph`` reading ``name`` and ``type`` from its attributes."""

    @property
    def name(self) -> str:
        """The ``name`` attribute.

        Raises:
            TypeError: if ``name`` is missing (the lookup default is None) or
                is not a ``str``.
        """
        stored = self.attributes.get("name", None)
        return _type_check(stored, str)

    @property
    def type(self) -> str:
        """The ``type`` attribute; ``"<undefined>"`` if not set.

        Raises:
            TypeError: if the stored value is not a ``str``.
        """
        stored = self.attributes.get("type", "<undefined>")
        return _type_check(stored, str)
class IrFactory:
    """Creates hdl IR objects: ``NodeImpl``, ``EdgeImpl`` and ``DataGraphImpl``.

    NOTE(review): the ``ir.AttributeMapping()`` defaults below are single
    instances shared across calls. This is only safe if AttributeMapping is
    immutable (its ``|`` / ``new_with`` usage suggests so) -- confirm.
    """

    def node(
        self,
        name: str,
        attributes: ir.AttributeMapping = ir.AttributeMapping(),
        /,
        type: str | None = None,
        input_shape: Shape | None = None,
        output_shape: Shape | None = None,
        implementation: str | None = None,
    ) -> Node:
        """Create a node called ``name`` from ``attributes``.

        Every keyword argument that is not None is merged over
        ``attributes``; shapes are stored as tuples via ``Shape.to_tuple``.
        When no keyword argument is given, ``attributes`` is passed on
        unchanged.
        """
        shapes = {"input_shape": input_shape, "output_shape": output_shape}
        extras: dict[str, ShapeTuple | str] = {
            key: shape.to_tuple() for key, shape in shapes.items() if shape is not None
        }
        for key, text in (("implementation", implementation), ("type", type)):
            if text is not None:
                extras[key] = text
        if not extras:
            return NodeImpl(name, attributes)
        return NodeImpl(name, attributes | extras)

    def edge(
        self,
        src: str,
        dst: str,
        attributes: ir.AttributeMapping = ir.AttributeMapping(),
        /,
        src_dst_indices: Iterable[tuple[int, int]] | tuple[str, str] = tuple(),
    ) -> Edge:
        """Create an edge from ``src`` to ``dst``.

        ``src_dst_indices`` is either an iterable of integer index pairs,
        which is materialized into a tuple, or a tuple whose first element is
        a str, which is stored as given. It is added to ``attributes`` only
        when non-empty.

        NOTE(review): a ``(str, str)`` tuple is stored flat, not as a tuple of
        pairs as ``Edge.src_dst_indices`` is annotated -- confirm intended.
        """
        given_as_names = (
            isinstance(src_dst_indices, tuple)
            and len(src_dst_indices) > 0
            and isinstance(src_dst_indices[0], str)
        )
        indices = (
            src_dst_indices
            if given_as_names
            else tuple(cast(Iterable[tuple[int, int]], src_dst_indices))
        )
        if indices:
            attributes = attributes | {"src_dst_indices": indices}
        return EdgeImpl(src, dst, attributes)

    def graph(
        self,
        attributes: ir.AttributeMapping = ir.AttributeMapping(),
        *,
        type: str | None = None,
        name: str | None = None,
        other: ir.DataGraph[ir.Node, ir.Edge] | None = None,
    ) -> DataGraph:
        """Create a data graph, optionally based on ``other``.

        With ``other``, its graph structure and node attributes are reused,
        and its attributes are merged with ``attributes`` via
        ``other.attributes | attributes`` (presumably ``attributes`` wins, as
        with dict ``|``). ``type`` and ``name``, when given, override the
        corresponding attributes.

        NOTE(review): ``other.graph`` is passed on as-is, not copied, so the
        new graph presumably shares it with ``other`` -- confirm.
        """
        if other is None:
            node_attributes = ir.AttributeMapping()
            _graph = ir.GraphImpl(lambda: ir.AttributeMapping())
        else:
            _graph = other.graph
            attributes = other.attributes | attributes
            node_attributes = other.node_attributes
        for key, value in (("type", type), ("name", name)):
            if value is not None:
                attributes = attributes.new_with(**{key: value})
        return DataGraphImpl(self, attributes, _graph, node_attributes)
def _check_and_get_name_fn(name: str | None, fn: Callable) -> str: if name is None: if hasattr(fn, "__name__") and isinstance(fn.__name__, str): return fn.__name__ else: raise Exception(f"you need to specify name explicitly for {fn}") return name type Code = tuple[str, Sequence[str]] type Registry = ir.Registry[DataGraph] type TypeHandler = Callable[[DataGraph, Registry], Iterable[Code]] type NonIterableTypeHandler = Callable[[DataGraph, Registry], Code]