Source code for elasticai.creator_plugins.combinatorial.clocked_combinatorial

from collections.abc import Sequence
from itertools import chain

from elasticai.creator.ir2vhdl import Implementation, type_handler

from .combinatorial import (
    build_data_signal_connections_for_combinatorial,
    build_declarations_for_combinatorial,
    build_instantiations_for_combinatorial,
    wrap_in_architecture,
)
from .language import Port, VHDLEntity


[docs] @type_handler def clocked_combinatorial(impl: Implementation) -> tuple[str, Sequence[str]]: def _iter(): input_size = impl.nodes["input"].input_shape.size() output_size = impl.nodes["output"].output_shape.size() entity = VHDLEntity( name=impl.name, port=Port( inputs=dict( clk="std_logic", rst="std_logic", d_in=f"std_logic_vector({input_size} - 1 downto 0)", valid_in="std_logic", ), outputs=dict( d_out=f"std_logic_vector({output_size} - 1 downto 0)", valid_out="std_logic", ), ), generics=dict(), ) yield from entity.generate_entity() yield "" declarations = build_declarations_for_combinatorial(impl) declarations = chain( declarations, ( "signal valid_out_input : std_logic := '0';", "signal valid_in_output : std_logic := '0';", ), ) definitions = build_data_signal_connections_for_combinatorial(impl) connected_valid_signals = [] valid_in_out_pairs = tuple(_get_valid_in_out_pairs(impl).items()) connected_valid_signals.extend(("valid_out_input <= valid_in;",)) for dst, src in valid_in_out_pairs: connected_valid_signals.append(f"valid_in_{dst} <= valid_out_{src};") connected_valid_signals.append("valid_out <= valid_in_output;") definitions = chain( definitions, connected_valid_signals, ("", ""), build_instantiations_for_combinatorial(impl), ) definitions = chain( definitions, ) yield from wrap_in_architecture(impl.name, declarations, definitions) return impl.name, tuple(_iter())
def _get_valid_in_out_pairs(impl: Implementation) -> dict[str, str]: def is_clocked(node): return node.type in [ "striding_shift_register", "sliding_window", "shift_register", ] adjacency: dict[str, str] = {} last_clocked_node = "input" for node in filter(is_clocked, impl.nodes.values()): last_clocked_node = node.name if len(adjacency) == 0: adjacency[node.name] = "input" else: for pred in filter( is_clocked, impl.iterate_bfs_up_from(node.name), ): adjacency[node.name] = pred.name break adjacency["output"] = last_clocked_node return adjacency