# Source code for elasticai.creator_plugins.combinatorial.clocked_combinatorial

from collections.abc import Sequence
from itertools import chain

from elasticai.creator import graph as gr
from elasticai.creator import ir
from elasticai.creator.ir2vhdl import DataGraph, type_handler

from .combinatorial import (
    build_data_signal_connections_for_combinatorial,
    build_declarations_for_combinatorial,
    build_instantiations_for_combinatorial,
    wrap_in_architecture,
)
from .language import Port, VHDLEntity


def _is_clocked_node(node):
    return node.type in [
        "sliding_window",
        "shift_register",
    ]


@type_handler()
def clocked_combinatorial(
    impl: DataGraph, registry: ir.Registry
) -> tuple[str, Sequence[str]]:
    """Render ``impl`` as a clocked VHDL design with a valid/ready handshake.

    The generated text consists of an entity declaration (clock, enable,
    reset, data and handshake ports), one empty line, and an architecture
    produced by ``wrap_in_architecture``. The architecture combines the
    combinatorial declarations and data connections with four control
    signals and the handshake assignments derived from
    ``_get_valid_in_out_pairs``.

    Args:
        impl: Graph to translate. The ``input_shape`` of its ``"input"`` node
            and the ``output_shape`` of its ``"output"`` node set the widths
            of ``d_in`` and ``d_out``.
        registry: Not used by this handler.

    Returns:
        A pair of ``impl.name`` and the generated lines collected in a tuple.
    """

    def _std_logic_vector(width) -> str:
        return f"std_logic_vector({width} - 1 downto 0)"

    def _ctrl_signal(name) -> str:
        return f"signal {name} : std_logic := '0';"

    def _lines():
        in_width = impl.nodes["input"].input_shape.size()
        out_width = impl.nodes["output"].output_shape.size()
        in_ports = {
            "clk": "std_logic",
            "en": "std_logic",
            "rst": "std_logic",
            "d_in": _std_logic_vector(in_width),
            "src_valid": "std_logic",
            "dst_ready": "std_logic",
        }
        out_ports = {
            "d_out": _std_logic_vector(out_width),
            "valid": "std_logic",
            "ready": "std_logic",
        }
        entity = VHDLEntity(
            name=impl.name,
            port=Port(inputs=in_ports, outputs=out_ports),
            generics={},
        )
        yield from entity.generate_entity()
        yield ""

        # Control signals bridging the entity ports and the clocked nodes.
        ctrl_names = (
            "dst_ready_input",
            "src_valid_output",
            "valid_input",
            "ready_output",
        )
        declarations = chain(
            build_declarations_for_combinatorial(impl),
            (_ctrl_signal(name) for name in ctrl_names),
        )

        data_connections = build_data_signal_connections_for_combinatorial(impl)

        # Handshake wiring: entity inputs first, then every sink/source pair,
        # finally the entity outputs.
        handshake = ["valid_input <= src_valid;", "ready <= dst_ready_input;"]
        for sink, source in _get_valid_in_out_pairs(impl).items():
            handshake.append(f"src_valid_{sink} <= valid_{source};")
            handshake.append(f"dst_ready_{source} <= ready_{sink};")
        handshake += ["valid <= src_valid_output;", "ready_output <= dst_ready;"]

        definitions = chain(
            data_connections,
            handshake,
            ("", ""),
            build_instantiations_for_combinatorial(impl),
        )
        yield from wrap_in_architecture(impl.name, declarations, definitions)

    return impl.name, tuple(_lines())
def _get_valid_in_out_pairs(impl: DataGraph) -> dict[str, str]:
    """Map each handshake sink to the source feeding it.

    Clocked nodes are visited in the order of ``impl.nodes.values()``:

    * the first clocked node is paired with ``"input"``;
    * every later clocked node is paired with the first clocked node yielded
      by ``gr.bfs_iter_up`` when started from that node; if the walk yields
      no clocked node, no entry is added for it;
    * ``"output"`` is paired with the last clocked node visited, or with
      ``"input"`` when the graph has no clocked node.

    Args:
        impl: Graph providing ``nodes``, ``predecessors`` and ``successors``.

    Returns:
        Dictionary from sink name to source name, in insertion order.
    """

    def _walk_up(start: str):
        # Translate the node names produced by the upward walk into nodes.
        names = gr.bfs_iter_up(
            successors=lambda name: impl.successors[name],
            predecessors=lambda name: impl.predecessors[name],
            start=start,
        )
        for name in names:
            yield impl.nodes[name]

    pairs: dict[str, str] = {}
    tail = "input"
    for clocked in filter(_is_clocked_node, impl.nodes.values()):
        tail = clocked.name
        if not pairs:
            pairs[clocked.name] = "input"
            continue
        upstream = next(filter(_is_clocked_node, _walk_up(clocked.name)), None)
        if upstream is not None:
            pairs[clocked.name] = upstream.name
    pairs["output"] = tail
    return pairs