Source code for elasticai.creator_plugins.time_multiplexed_sequential.src

from collections.abc import Callable
from itertools import chain, starmap
from typing import Protocol, cast

import elasticai.creator.ir.ir_v2 as ir
from elasticai.creator.graph import dfs_iter
from elasticai.creator.ir2vhdl import (
    DataGraph,
    IrFactory,
    Node,
    Registry,
    Shape,
    type_handler,
)
from elasticai.creator_plugins.grouped_filter import FilterParameters


class _HasAttr(Protocol):
    """Structural type for any object that exposes an ``attributes`` mapping."""

    @property
    def attributes(self) -> ir.AttributeMapping:
        """Return the attribute mapping attached to the object."""
        ...


class _FilterDecorator[T: _HasAttr]:
    def __init__(self, decorated: T) -> None:
        self._decorated = decorated

    @property
    def filter_parameters(self) -> FilterParameters:
        if "filter_parameters" not in self.attributes:
            raise Exception()
        return FilterParameters.from_dict(self.attributes["filter_parameters"])

    def __getattr__(self, key):
        return getattr(self._decorated, key)


class FilterNode(Node, Protocol):
    """A ``Node`` that additionally exposes parsed ``FilterParameters``."""

    @property
    def filter_parameters(self) -> FilterParameters:
        """Return the filter parameters associated with this node."""
        ...
class FilterGraph(DataGraph, Protocol):
    """A ``DataGraph`` that additionally exposes parsed ``FilterParameters``."""

    @property
    def filter_parameters(self) -> FilterParameters:
        """Return the filter parameters associated with this graph."""
        ...
def filter_node(n: Node) -> FilterNode:
    """Wrap ``n`` in a ``_FilterDecorator`` and type it as a ``FilterNode``."""
    wrapped = _FilterDecorator(n)
    return cast(FilterNode, wrapped)
def filter_graph(g: DataGraph) -> FilterGraph:
    """Wrap ``g`` in a ``_FilterDecorator`` and type it as a ``FilterGraph``."""
    wrapped = _FilterDecorator(g)
    return cast(FilterGraph, wrapped)
[docs] def append_counter_suffix_before_construction[**P]( fn: Callable[P, Node], ) -> Callable[P, Node]: counters: dict[str, int] = {} def construct(*args: P.args, **kwargs: P.kwargs) -> Node: nonlocal counters if "name" in kwargs: name = kwargs["name"] kwargs.pop("name") if not isinstance(name, str): raise TypeError("expected `name` to be of type str") elif isinstance(args[0], str): name = args[0] else: raise TypeError("missing positional arg `name`") count = counters.get(name, 0) counters[name] = count + 1 new_name = f"{name}_i{count}" args = tuple(chain((new_name,), args[1:])) # type: ignore node = fn(*args, **kwargs) return node return construct
class _Sequential:
    """Incrementally build a linear graph of type ``clocked_combinatorial``.

    Nodes are appended one after another; every appended node is connected by
    an edge from the node that was appended before it. The graph object is
    replaced (not mutated) on every change and can be fetched with
    ``get_impl``.
    """

    def __init__(self):
        self._factory = IrFactory()
        self._impl = self._factory.graph(type="clocked_combinatorial")
        # Most recently appended node; new nodes take its output shape as
        # their input shape.
        self._last_node: Node | None = None
        # Running filter parameters, maintained by `_update_last_filter_params`.
        self._last_filter_parameters: None | FilterParameters = None
        # Node constructor that appends `_i<count>` to the node name.
        self._counting_node_constructor = append_counter_suffix_before_construction(
            self._factory.node
        )
        # Incremented by `strided_shift_register`; not read inside this class.
        self._num_registers: int = 0
        # Not read or updated anywhere else inside this class.
        self._last_stride = 1

    def add_input(self, shape: Shape):
        """Add the node ``"input"`` using ``shape`` as input and output shape.

        The new node becomes the last node. No edge is added.
        """
        self._impl = self._impl.add_node(
            self._factory.node(
                "input",
                type="input",
                implementation="",
                output_shape=shape,
                input_shape=shape,
            )
        )
        self._last_node = self._impl.nodes["input"]
        self._update_last_filter_params()

    def _need_shift_register(self, params: FilterParameters) -> bool:
        """Tell whether the last node's output width is below the kernel size.

        Returns ``False`` when no node has been added yet.
        """
        if self._last_node is None:
            return False
        consuming_more_than_last_node_produces = (
            self._last_node.output_shape.width < params.kernel_size
        )
        return consuming_more_than_last_node_produces

    def _need_sliding_window(self, params: FilterParameters) -> bool:
        """Tell whether the last node's output width is above the kernel size.

        Returns ``False`` when no node has been added yet.
        """
        if self._last_node is None:
            return False
        consuming_less_than_last_node_produces = (
            self._last_node.output_shape.width > params.kernel_size
        )
        return consuming_less_than_last_node_produces

    def filter(self, n: Node):
        """Append filter node ``n``, adapting the preceding output if needed.

        If the last node's output width is smaller than the kernel size, a
        (strided) shift register is inserted first; if it is larger, a sliding
        window is inserted first. The filter itself is appended as an
        ``unclocked_combinatorial`` node with output shape
        ``Shape(out_channels, 1)``.

        Raises:
            ValueError: if no node has been added before.
        """
        node = filter_node(n)
        attributes = n.attributes
        params = node.filter_parameters
        # Only set once, i.e. by the first filter passed to this method.
        if "top_stride" not in self._impl.attributes:
            self._impl = self._impl.with_attributes(
                self._impl.attributes.update_path(
                    ("top_stride",), params.in_channels * params.stride
                )
            )
        if self._need_shift_register(params):
            # The register uses the stride recorded for the preceding nodes,
            # not the stride of the filter being added.
            old_params = self._last_filter_parameters
            assert old_params is not None
            self.strided_shift_register(
                output_shape=(
                    params.in_channels,
                    params.kernel_size,
                ),
                stride=old_params.stride,
            )
        elif self._need_sliding_window(params):
            self._sliding_window(params)
        elif self._last_node is not None:
            # Widths already match; nothing to insert.
            pass
        else:
            # Both `_need_*` helpers return False when there is no last node,
            # so a missing last node always ends up here.
            raise ValueError("expected last node to be not None")

        self._append_node(
            name=n.name,
            type="unclocked_combinatorial",
            implementation=n.implementation,
            output_shape=Shape(attributes["filter_parameters"]["out_channels"], 1),
            attributes=attributes,
            node_fn=self._factory.node,
        )

    def _sliding_window(self, params: FilterParameters) -> None:
        """Append a ``sliding_window`` node producing ``(in_channels, kernel_size)``."""
        self._append_static(
            "sliding_window",
            "sliding_window",
            output_shape=Shape(
                params.in_channels,
                params.kernel_size,
            ),
        )

    def _update_last_filter_params(self):
        """Refresh ``_last_filter_parameters`` from the current last node.

        * Last node has ``"filter_parameters"`` and parameters are already
          recorded: for kernel size 1 the node's parameters are recorded with
          the stride multiplied by the previously recorded stride; otherwise
          the node's parameters replace the recorded ones.
        * Nothing recorded yet: parameters are initialised with kernel size 1
          and the channel counts taken from the node's input/output depth
          (this branch is also taken when the node has
          ``"filter_parameters"``).
        * Otherwise the recorded parameters stay unchanged.
        """
        new_node = self._last_node
        assert new_node is not None
        if (
            "filter_parameters" in new_node.attributes
            and self._last_filter_parameters is not None
        ):
            params = filter_node(new_node).filter_parameters
            if params.kernel_size == 1:
                self._last_filter_parameters = FilterParameters(
                    kernel_size=params.kernel_size,
                    in_channels=params.in_channels,
                    out_channels=params.out_channels,
                    groups=params.groups,
                    stride=self._last_filter_parameters.stride * params.stride,
                    input_size=params.input_size,
                    output_size=params.output_size,
                )
            else:
                self._last_filter_parameters = params
        elif self._last_filter_parameters is None:
            self._last_filter_parameters = FilterParameters(
                kernel_size=1,
                in_channels=new_node.input_shape.depth,
                out_channels=new_node.output_shape.depth,
            )

    def _append_node(
        self,
        name: str,
        output_shape: Shape,
        type: str,
        implementation: str,
        node_fn,
        attributes: ir.AttributeMapping | None = None,
    ):
        """Create a node, add it after the last node and connect the two.

        Args:
            name: name handed to ``node_fn``.
            output_shape: output shape of the new node; its input shape is
                the output shape of the current last node.
            type: node type.
            implementation: node implementation.
            node_fn: callable used to construct the node.
            attributes: passed as second positional argument to ``node_fn``
                when not ``None``.

        Raises:
            Exception: if no node has been added before.
        """
        old_node = self._last_node
        if old_node is None:
            raise Exception("no input node")
        input_shape = old_node.output_shape
        if attributes is not None:
            new_node = node_fn(
                name,
                attributes,
                input_shape=input_shape,
                output_shape=output_shape,
                type=type,
                implementation=implementation,
            )
        else:
            new_node = node_fn(
                name,
                input_shape=input_shape,
                output_shape=output_shape,
                type=type,
                implementation=implementation,
            )
        self._impl = self._impl.add_node(new_node)
        # The last node must be updated before refreshing the filter
        # parameters, because the refresh reads `self._last_node`.
        self._last_node = new_node
        self._update_last_filter_params()
        # Always true at this point: the None case raised above.
        if old_node is not None:
            self._impl = self._impl.add_edge(
                self._factory.edge(
                    old_node.name,
                    new_node.name,
                    src_dst_indices=tuple(),
                )
            )

    def _append_static(
        self, name: str, implementation: str, output_shape: Shape, **kwargs
    ):
        """Append a node whose type equals its implementation name.

        The node is built with the counting constructor, so its name gets an
        ``_i<count>`` suffix. Extra keyword arguments become the node's
        attributes via ``ir.attribute``.
        """
        self._append_node(
            name=name,
            output_shape=output_shape,
            implementation=implementation,
            type=implementation,
            attributes=ir.attribute(kwargs),
            node_fn=self._counting_node_constructor,
        )

    def shift_register(self, name: str, output_shape: Shape):
        """Append a ``shift_register`` node without extra attributes."""
        self._append_static(
            name=name,
            output_shape=output_shape,
            implementation="shift_register",
        )

    def strided_shift_register(self, output_shape: tuple[int, int], stride: int):
        """Append a shift register, using the striding variant for ``stride > 1``.

        For ``stride > 1`` a ``striding_shift_register`` with
        ``generic_map={"stride": stride}`` is appended, otherwise a plain
        ``shift_register`` with an empty ``generic_map``.
        """
        self._num_registers += 1
        if stride > 1:
            self._append_static(
                name="striding_shift_register",
                output_shape=Shape(*output_shape),
                implementation="striding_shift_register",
                generic_map=dict(stride=stride),
            )
        else:
            self._append_static(
                name="shift_register",
                output_shape=Shape(*output_shape),
                implementation="shift_register",
                generic_map=dict(),
            )

    def input(self, impl: DataGraph, input_node: str) -> None:
        """Add the input node sized for what follows ``input_node`` in ``impl``.

        Also stores the size of the chosen input shape as graph attribute
        ``top_kernel_size``.
        """
        input_shape = self._determine_required_input_shape(impl, input_node)
        self._impl = self._impl.with_attributes(
            self._impl.attributes.new_with(top_kernel_size=input_shape.size())
        )
        self.add_input(input_shape)

    def _determine_required_input_shape(
        self, impl: DataGraph, input_node: str
    ) -> Shape:
        """Return the shape the first successor of ``input_node`` consumes.

        For a successor of type ``"filter"`` this is
        ``Shape(in_channels, kernel_size)``; for any other type the input
        node's own output shape is returned.
        """
        # NOTE(review): indexing with [0] fails if the input node has no
        # successors -- confirm callers never pass such a graph.
        first_node_after_input = impl.nodes[tuple(impl.successors[input_node])[0]]
        match first_node_after_input.type:
            case "filter":
                n = filter_node(first_node_after_input)
                return Shape(
                    n.filter_parameters.in_channels, n.filter_parameters.kernel_size
                )
            case _:
                return impl.nodes[input_node].output_shape

    def set_runtime_input_shape(self, s: Shape) -> None:
        """Store ``s`` as graph attribute ``runtime_input_shape`` (as a tuple)."""
        self._impl = self._impl.with_attributes(
            self._impl.attributes.new_with(runtime_input_shape=s.to_tuple())
        )

    def set_runtime_output_shape(self, s: Shape) -> None:
        """Store ``s`` as graph attribute ``runtime_output_shape`` (as a tuple)."""
        self._impl = self._impl.with_attributes(
            self._impl.attributes.new_with(runtime_output_shape=s.to_tuple())
        )

    def get_impl(self) -> DataGraph:
        """Return the graph built so far."""
        return self._impl

    def output(self, n):
        """Append the ``"output"`` node, preceded by a shift register if needed.

        A shift register producing ``n.output_shape`` is inserted when the
        last node's output width is smaller than ``n.input_shape.width``. The
        output node's shape is the output shape of whatever node is last at
        that point (i.e. the shift register, if one was inserted).
        """
        assert self._last_node is not None
        if self._last_node.output_shape.width < n.input_shape.width:
            self.shift_register("shift_register", n.output_shape)
        self._append_node(
            name="output",
            output_shape=self._last_node.output_shape,
            type="output",
            implementation="",
            node_fn=self._factory.node,
        )


# Module-level factory used by the `sequential` and `network` type handlers.
_factory = IrFactory()
[docs] @type_handler() def sequential( impl: ir.DataGraph[ir.Node, ir.Edge], registry: ir.Registry[ir.DataGraph[ir.Node, ir.Edge]], ) -> tuple[DataGraph, Registry]: seq = _Sequential() impl = _factory.graph(other=impl) def iter_nodes(): input_node = None for input_node in impl.nodes.values(): if input_node.type == "input": break if input_node is None: raise Exception( f"passed graph has {len(impl.nodes)} nodes none of them of type 'input_node'" ) def succ(node): return impl.successors[node] yield from dfs_iter(succ, input_node.name) for n in map(lambda n: impl.nodes[n], iter_nodes()): match n.type: case "filter": seq.filter(n) case "input": seq.set_runtime_input_shape(n.input_shape) seq.input(impl, n.name) case "output": seq.set_runtime_output_shape(n.output_shape) seq.output(n) case _: raise Exception( f"Can't handle unknown type {n.type} during generation of time multiplexed sequential" ) def to_our_graph(k, g): return k, _factory.graph(other=g) new_registry = ir.Registry(starmap(to_our_graph, registry.items())) return seq.get_impl(), new_registry
@type_handler()
def network(
    impl: ir.DataGraph[ir.Node, ir.Edge],
    registry: ir.Registry[ir.DataGraph[ir.Node, ir.Edge]],
) -> tuple[DataGraph, Registry]:
    """Build the ``network`` graph plus its skeleton and wrapper entries.

    ``impl`` is translated with ``sequential`` and the result is given the
    attribute ``name="network"``. The returned registry additionally holds a
    ``skeleton`` graph, whose ``generic_map`` carries the data widths and
    depths read from the network's ``"input"`` and ``"output"`` nodes, and a
    ``buffered_network_wrapper`` graph.
    """
    graph, registry = sequential(_factory.graph(other=impl), registry)
    graph = graph.with_attributes(graph.attributes | dict(name="network"))

    # NOTE(review): shapes are unpacked as (width, depth) here, while the
    # builder in this file constructs them as Shape(in_channels, kernel_size)
    # and reads channels via `.depth` -- confirm the iteration order of Shape.
    input_width, input_depth = graph.nodes["input"].input_shape
    output_width, output_depth = graph.nodes["output"].output_shape

    skeleton_attrs = {
        "generic_map": {
            "DATA_IN_WIDTH": str(input_width),
            "DATA_IN_DEPTH": str(input_depth),
            "DATA_OUT_WIDTH": str(output_width),
            "DATA_OUT_DEPTH": str(output_depth),
        }
    }
    extra_graphs = {
        "skeleton": _factory.graph(ir.attribute(skeleton_attrs), type="skeleton"),
        "buffered_network_wrapper": _factory.graph(type="buffered_network_wrapper"),
    }
    registry = registry | extra_graphs
    return graph, registry