from collections.abc import Callable
from itertools import chain, starmap
from typing import Protocol, cast
import elasticai.creator.ir.ir_v2 as ir
from elasticai.creator.graph import dfs_iter
from elasticai.creator.ir2vhdl import (
DataGraph,
IrFactory,
Node,
Registry,
Shape,
type_handler,
)
from elasticai.creator_plugins.grouped_filter import FilterParameters
class _HasAttr(Protocol):
    """Structural type for anything that exposes an ``attributes`` mapping."""

    @property
    def attributes(self) -> ir.AttributeMapping:
        """The object's attribute mapping."""
        ...
class _FilterDecorator[T: _HasAttr]:
def __init__(self, decorated: T) -> None:
self._decorated = decorated
@property
def filter_parameters(self) -> FilterParameters:
if "filter_parameters" not in self.attributes:
raise Exception()
return FilterParameters.from_dict(self.attributes["filter_parameters"])
def __getattr__(self, key):
return getattr(self._decorated, key)
[docs]
class FilterNode(Node, Protocol):
    """A :class:`Node` that additionally exposes its filter parameters."""

    @property
    def filter_parameters(self) -> FilterParameters:
        """The node's :class:`FilterParameters`."""
        ...
[docs]
class FilterGraph(DataGraph, Protocol):
    """A :class:`DataGraph` that additionally exposes its filter parameters."""

    @property
    def filter_parameters(self) -> FilterParameters:
        """The graph's :class:`FilterParameters`."""
        ...
[docs]
def filter_node(n: Node) -> FilterNode:
    """View ``n`` as a :class:`FilterNode`.

    The returned wrapper forwards attribute access to ``n`` and adds
    ``filter_parameters``, parsed from ``n.attributes["filter_parameters"]``
    on each access.
    """
    wrapped = _FilterDecorator(n)
    return cast(FilterNode, wrapped)
[docs]
def filter_graph(g: DataGraph) -> FilterGraph:
    """View ``g`` as a :class:`FilterGraph`.

    The returned wrapper forwards attribute access to ``g`` and adds
    ``filter_parameters``, parsed from ``g.attributes["filter_parameters"]``
    on each access.
    """
    wrapped = _FilterDecorator(g)
    return cast(FilterGraph, wrapped)
[docs]
def append_counter_suffix_before_construction[**P](
fn: Callable[P, Node],
) -> Callable[P, Node]:
counters: dict[str, int] = {}
def construct(*args: P.args, **kwargs: P.kwargs) -> Node:
nonlocal counters
if "name" in kwargs:
name = kwargs["name"]
kwargs.pop("name")
if not isinstance(name, str):
raise TypeError("expected `name` to be of type str")
elif isinstance(args[0], str):
name = args[0]
else:
raise TypeError("missing positional arg `name`")
count = counters.get(name, 0)
counters[name] = count + 1
new_name = f"{name}_i{count}"
args = tuple(chain((new_name,), args[1:])) # type: ignore
node = fn(*args, **kwargs)
return node
return construct
class _Sequential:
def __init__(self):
self._factory = IrFactory()
self._impl = self._factory.graph(type="clocked_combinatorial")
self._last_node: Node | None = None
self._last_filter_parameters: None | FilterParameters = None
self._counting_node_constructor = append_counter_suffix_before_construction(
self._factory.node
)
self._num_registers: int = 0
self._last_stride = 1
def add_input(self, shape: Shape):
self._impl = self._impl.add_node(
self._factory.node(
"input",
type="input",
implementation="",
output_shape=shape,
input_shape=shape,
)
)
self._last_node = self._impl.nodes["input"]
self._update_last_filter_params()
def _need_shift_register(self, params: FilterParameters) -> bool:
if self._last_node is None:
return False
consuming_more_than_last_node_produces = (
self._last_node.output_shape.width < params.kernel_size
)
return consuming_more_than_last_node_produces
def _need_sliding_window(self, params: FilterParameters) -> bool:
if self._last_node is None:
return False
consuming_less_than_last_node_produces = (
self._last_node.output_shape.width > params.kernel_size
)
return consuming_less_than_last_node_produces
def filter(self, n: Node):
node = filter_node(n)
attributes = n.attributes
params = node.filter_parameters
if "top_stride" not in self._impl.attributes:
self._impl = self._impl.with_attributes(
self._impl.attributes.update_path(
("top_stride",), params.in_channels * params.stride
)
)
if self._need_shift_register(params):
old_params = self._last_filter_parameters
assert old_params is not None
self.strided_shift_register(
output_shape=(
params.in_channels,
params.kernel_size,
),
stride=old_params.stride,
)
elif self._need_sliding_window(params):
self._sliding_window(params)
elif self._last_node is not None:
pass
else:
raise ValueError("expected last node to be not None")
self._append_node(
name=n.name,
type="unclocked_combinatorial",
implementation=n.implementation,
output_shape=Shape(attributes["filter_parameters"]["out_channels"], 1),
attributes=attributes,
node_fn=self._factory.node,
)
def _sliding_window(self, params: FilterParameters) -> None:
self._append_static(
"sliding_window",
"sliding_window",
output_shape=Shape(
params.in_channels,
params.kernel_size,
),
)
def _update_last_filter_params(self):
new_node = self._last_node
assert new_node is not None
if (
"filter_parameters" in new_node.attributes
and self._last_filter_parameters is not None
):
params = filter_node(new_node).filter_parameters
if params.kernel_size == 1:
self._last_filter_parameters = FilterParameters(
kernel_size=params.kernel_size,
in_channels=params.in_channels,
out_channels=params.out_channels,
groups=params.groups,
stride=self._last_filter_parameters.stride * params.stride,
input_size=params.input_size,
output_size=params.output_size,
)
else:
self._last_filter_parameters = params
elif self._last_filter_parameters is None:
self._last_filter_parameters = FilterParameters(
kernel_size=1,
in_channels=new_node.input_shape.depth,
out_channels=new_node.output_shape.depth,
)
def _append_node(
self,
name: str,
output_shape: Shape,
type: str,
implementation: str,
node_fn,
attributes: ir.AttributeMapping | None = None,
):
old_node = self._last_node
if old_node is None:
raise Exception("no input node")
input_shape = old_node.output_shape
if attributes is not None:
new_node = node_fn(
name,
attributes,
input_shape=input_shape,
output_shape=output_shape,
type=type,
implementation=implementation,
)
else:
new_node = node_fn(
name,
input_shape=input_shape,
output_shape=output_shape,
type=type,
implementation=implementation,
)
self._impl = self._impl.add_node(new_node)
self._last_node = new_node
self._update_last_filter_params()
if old_node is not None:
self._impl = self._impl.add_edge(
self._factory.edge(
old_node.name,
new_node.name,
src_dst_indices=tuple(),
)
)
def _append_static(
self, name: str, implementation: str, output_shape: Shape, **kwargs
):
self._append_node(
name=name,
output_shape=output_shape,
implementation=implementation,
type=implementation,
attributes=ir.attribute(kwargs),
node_fn=self._counting_node_constructor,
)
def shift_register(self, name: str, output_shape: Shape):
self._append_static(
name=name,
output_shape=output_shape,
implementation="shift_register",
)
def strided_shift_register(self, output_shape: tuple[int, int], stride: int):
self._num_registers += 1
if stride > 1:
self._append_static(
name="striding_shift_register",
output_shape=Shape(*output_shape),
implementation="striding_shift_register",
generic_map=dict(stride=stride),
)
else:
self._append_static(
name="shift_register",
output_shape=Shape(*output_shape),
implementation="shift_register",
generic_map=dict(),
)
def input(self, impl: DataGraph, input_node: str) -> None:
input_shape = self._determine_required_input_shape(impl, input_node)
self._impl = self._impl.with_attributes(
self._impl.attributes.new_with(top_kernel_size=input_shape.size())
)
self.add_input(input_shape)
def _determine_required_input_shape(
self, impl: DataGraph, input_node: str
) -> Shape:
first_node_after_input = impl.nodes[tuple(impl.successors[input_node])[0]]
match first_node_after_input.type:
case "filter":
n = filter_node(first_node_after_input)
return Shape(
n.filter_parameters.in_channels, n.filter_parameters.kernel_size
)
case _:
return impl.nodes[input_node].output_shape
def set_runtime_input_shape(self, s: Shape) -> None:
self._impl = self._impl.with_attributes(
self._impl.attributes.new_with(runtime_input_shape=s.to_tuple())
)
def set_runtime_output_shape(self, s: Shape) -> None:
self._impl = self._impl.with_attributes(
self._impl.attributes.new_with(runtime_output_shape=s.to_tuple())
)
def get_impl(self) -> DataGraph:
return self._impl
def output(self, n):
assert self._last_node is not None
if self._last_node.output_shape.width < n.input_shape.width:
self.shift_register("shift_register", n.output_shape)
self._append_node(
name="output",
output_shape=self._last_node.output_shape,
type="output",
implementation="",
node_fn=self._factory.node,
)
# Module-level factory shared by the `sequential` and `network` type handlers.
_factory: IrFactory = IrFactory()
[docs]
@type_handler()
def sequential(
impl: ir.DataGraph[ir.Node, ir.Edge],
registry: ir.Registry[ir.DataGraph[ir.Node, ir.Edge]],
) -> tuple[DataGraph, Registry]:
seq = _Sequential()
impl = _factory.graph(other=impl)
def iter_nodes():
input_node = None
for input_node in impl.nodes.values():
if input_node.type == "input":
break
if input_node is None:
raise Exception(
f"passed graph has {len(impl.nodes)} nodes none of them of type 'input_node'"
)
def succ(node):
return impl.successors[node]
yield from dfs_iter(succ, input_node.name)
for n in map(lambda n: impl.nodes[n], iter_nodes()):
match n.type:
case "filter":
seq.filter(n)
case "input":
seq.set_runtime_input_shape(n.input_shape)
seq.input(impl, n.name)
case "output":
seq.set_runtime_output_shape(n.output_shape)
seq.output(n)
case _:
raise Exception(
f"Can't handle unknown type {n.type} during generation of time multiplexed sequential"
)
def to_our_graph(k, g):
return k, _factory.graph(other=g)
new_registry = ir.Registry(starmap(to_our_graph, registry.items()))
return seq.get_impl(), new_registry
[docs]
@type_handler()
def network(
    impl: ir.DataGraph[ir.Node, ir.Edge],
    registry: ir.Registry[ir.DataGraph[ir.Node, ir.Edge]],
) -> tuple[DataGraph, Registry]:
    """Build the top-level ``network`` design and its wrapper entries.

    The graph is lowered with :func:`sequential` and renamed to
    ``"network"``. The registry returned by it is extended with a
    ``skeleton`` graph, whose generic map is derived from the shapes of the
    ``input`` and ``output`` nodes, and a ``buffered_network_wrapper`` graph.
    """
    net, new_registry = sequential(_factory.graph(other=impl), registry)
    net = net.with_attributes(net.attributes | {"name": "network"})

    in_shape = net.nodes["input"].input_shape
    out_shape = net.nodes["output"].output_shape
    # NOTE(review): elsewhere in this module Shape is constructed as
    # Shape(depth, width); confirm that iterating a Shape yields
    # (width, depth) as assumed by this unpacking.
    in_width, in_depth = in_shape
    out_width, out_depth = out_shape

    skeleton_attrs = {
        "generic_map": {
            "DATA_IN_WIDTH": str(in_width),
            "DATA_IN_DEPTH": str(in_depth),
            "DATA_OUT_WIDTH": str(out_width),
            "DATA_OUT_DEPTH": str(out_depth),
        }
    }
    extra_entries = {
        "skeleton": _factory.graph(ir.attribute(skeleton_attrs), type="skeleton"),
        "buffered_network_wrapper": _factory.graph(type="buffered_network_wrapper"),
    }
    return net, new_registry | extra_entries