Source code for elasticai.creator_plugins.combinatorial.tests.test_vhdl_nodes

import re
from collections.abc import Iterable

import pytest

from elasticai.creator import ir
from elasticai.creator.ir2vhdl import Node as VhdlNode
from elasticai.creator.ir2vhdl import Shape, factory
from elasticai.creator.ir2vhdl.language import Instance

from ..vhdl_nodes.node_factory import (
    InstanceFactoryForCombinatorial,
)


@pytest.fixture(scope="class")
def node(raw_node):
    """Wrap the ``raw_node`` fixture in an ``InstanceFactoryForCombinatorial``.

    ``raw_node`` is supplied by the concrete test classes in this module;
    the wrapped object is created once per test class.
    """
    instance = InstanceFactoryForCombinatorial(raw_node)
    return instance
def new_node(
    name: str,
    type: str,
    implementation: str,
    input_shape: Shape,
    output_shape: Shape,
    attributes: dict | None = None,
) -> VhdlNode:
    """Build a VHDL IR node for the tests via ``factory.node``.

    Args:
        name: Name of the node.
        type: Node type, forwarded unchanged to the factory.
        implementation: Implementation name, forwarded unchanged.
        input_shape: Shape of the node input.
        output_shape: Shape of the node output.
        attributes: Optional extra attributes. ``None`` is treated as an
            empty mapping. The entries are copied into a fresh dict and
            wrapped with ``ir.attribute`` before being handed to the factory.

    Returns:
        The node created by ``factory.node``.
    """
    extra = {} if attributes is None else attributes
    return factory.node(
        name,
        ir.attribute(dict(**extra)),
        type=type,
        implementation=implementation,
        input_shape=input_shape,
        output_shape=output_shape,
    )
def _extract_assignments(starts_with: str, code: Iterable[str]) -> set[str]: assignments = set() extract = False for line in code: if line.startswith(starts_with): extract = True elif extract and "=>" in line: assignments.add(line.strip(",")) elif extract: extract = False return assignments
class BaseVhdlNodeTest:
    """Checks shared by the node-specific test classes in this module.

    The tests request the fixtures ``node``, ``raw_node``,
    ``expected_assignments``, ``expected_generics`` and ``first_line``;
    the concrete subclasses provide the node-specific ones.
    """

    def test_contains_all_signal_assignments(
        self, node: Instance, raw_node: VhdlNode, expected_assignments: set[str]
    ):
        """The ``port map (`` section holds exactly the expected assignments."""
        stripped = [text.strip() for text in node.instantiate()]
        found = _extract_assignments("port map (", stripped)
        assert expected_assignments == found

    def test_contains_all_generics(self, node: Instance, expected_generics: set[str]):
        """The section starting with ``generic`` holds exactly the expected generics."""
        stripped = (text.strip() for text in node.instantiate())
        found = _extract_assignments("generic", stripped)
        assert expected_generics == found

    def test_first_line_is_correct(self, node: Instance, first_line: str):
        """The first emitted line, stripped of whitespace, equals ``first_line``."""
        emitted = node.instantiate()
        assert first_line == next(emitted).strip()

    def test_can_define_signals(self, node, raw_node):
        """``define_signals`` yields the data and handshake signals for the node."""
        signals = set(node.define_signals())
        name = raw_node.name
        in_width = raw_node.input_shape.size()
        out_width = raw_node.output_shape.size()
        expected = {
            f"signal d_in_{name} : std_logic_vector({in_width} - 1 downto 0) := (others => '0');",
            f"signal d_out_{name} : std_logic_vector({out_width} - 1 downto 0) := (others => '0');",
            f"signal src_valid_{name} : std_logic := '0';",
            f"signal valid_{name} : std_logic := '0';",
            f"signal dst_ready_{name} : std_logic := '0';",
            f"signal ready_{name} : std_logic := '0';",
        }
        assert signals == expected
@pytest.mark.parametrize(
    ["in_shape", "out_shape", "stride", "expected"],
    [
        (Shape(1, 1), Shape(1, 4), 1, {"DATA_WIDTH": 1, "NUM_POINTS": 4, "SKIP": 1}),
        (Shape(1, 2), Shape(1, 4), 1, {"DATA_WIDTH": 2, "NUM_POINTS": 2, "SKIP": 1}),
        (Shape(2, 3), Shape(2, 6), 1, {"DATA_WIDTH": 6, "NUM_POINTS": 2, "SKIP": 1}),
        (Shape(2, 3), Shape(2, 6), 3, {"DATA_WIDTH": 6, "NUM_POINTS": 2, "SKIP": 3}),
    ],
)
def test_compute_correct_generic_map_for_shift_register(
    in_shape, out_shape, stride, expected
):
    """The shift register's ``generic map`` matches the parametrized table.

    Each case builds a ``shift_register`` node with the given shapes and a
    ``skip`` attribute set to ``stride``, then compares the generics found
    in its instantiation with ``expected`` rendered as ``NAME => value``.
    """
    shift_register = new_node(
        name="a",
        type="shift_register",
        implementation="shift_register",
        input_shape=in_shape,
        output_shape=out_shape,
        attributes={"skip": stride},
    )
    instantiation = InstanceFactoryForCombinatorial(shift_register).instantiate()
    actual = {
        entry.strip()
        for entry in _extract_assignments("generic map", instantiation)
    }
    wanted = {f"{generic} => {value}" for generic, value in expected.items()}
    assert actual == wanted
class TestShiftRegister(BaseVhdlNodeTest):
    """Runs the shared node checks against a ``shift_register`` node."""

    @pytest.fixture(scope="class")
    def raw_node(self):
        """Shift register between a (2, 2) producer output and a (2, 4) consumer input."""
        channels = 2
        kernel_size = 4
        return new_node(
            name="a",
            type="shift_register",
            implementation="impl",
            input_shape=Shape(channels, 2),
            output_shape=Shape(channels, kernel_size),
        )

    @pytest.fixture(scope="class")
    def expected_generics(self) -> set[str]:
        """Generics expected for the node from ``raw_node``.

        The input shape is (2, 2), giving a data width of 4; the output
        shape is (2, 4), giving 2 points.
        """
        return {
            "DATA_WIDTH => 4",
            "NUM_POINTS => 2",
            "SKIP => 1",
        }

    @pytest.fixture(scope="class")
    def expected_assignments(self) -> set[str]:
        """Port map entries expected for node ``a``."""
        control = {"clk => clk", "rst => rst", "en => en"}
        data = {"d_in => d_in_a", "d_out => d_out_a"}
        handshake = {
            "src_valid => src_valid_a",
            "valid => valid_a",
            "ready => ready_a",
            "dst_ready => dst_ready_a",
        }
        return control | data | handshake

    @pytest.fixture(scope="class")
    def first_line(self) -> str:
        """First line expected from the instantiation of node ``a``."""
        return "a: entity work.impl(rtl)"
class TestSlidingWindow(BaseVhdlNodeTest):
    """Runs the shared node checks against a ``sliding_window`` node.

    Also covers the shape validation: one case expects a ``ValueError``,
    the other expects a warning.
    """

    @pytest.fixture(scope="class")
    def raw_node(self):
        """Sliding window from shape (4, 2) to shape (4, 1) with ``stride`` 2."""
        window = new_node(
            name="a",
            type="sliding_window",
            input_shape=Shape(4, 2),
            output_shape=Shape(4, 1),
            implementation="impl",
            attributes={"stride": 2},
        )
        return window

    @pytest.fixture
    def expected_generics(self) -> set[str]:
        """Generics expected for the node from ``raw_node``."""
        return {
            "INPUT_WIDTH => 8",
            "OUTPUT_WIDTH => 4",
            "STRIDE => 8",
        }

    @pytest.fixture
    def expected_assignments(self) -> set[str]:
        """Port map entries expected for node ``a``."""
        control = {"clk => clk", "rst => rst", "en => en"}
        data = {"d_in => d_in_a", "d_out => d_out_a"}
        handshake = {
            "src_valid => src_valid_a",
            "valid => valid_a",
            "ready => ready_a",
            "dst_ready => dst_ready_a",
        }
        return control | data | handshake

    @pytest.fixture
    def first_line(self) -> str:
        """First line expected from the instantiation of node ``a``."""
        return "a: entity work.impl(rtl)"

    def test_raise_error_for_incompatible_shapes(self):
        """Input (5, 2) with output (3, 2) is rejected with a ``ValueError``."""
        incompatible = new_node(
            name="a",
            type="sliding_window",
            input_shape=Shape(5, 2),
            output_shape=Shape(3, 2),
            implementation="impl",
            attributes={"stride": 2},
        )
        message = "Found incompatible input output shapes for sliding_window. Total input size has to be an integer multiple of total output size, but found output=Shape(depth=3, width=2) and input=Shape(depth=5, width=2)."
        with pytest.raises(ValueError, match=re.escape(message)):
            InstanceFactoryForCombinatorial(incompatible).instantiate()

    def test_warn_about_technically_compatible_but_semantically_wrong_shapes(self):
        """Input (2, 4) with output (1, 2) is accepted but triggers a warning."""
        depth_mismatch = new_node(
            name="a",
            type="sliding_window",
            input_shape=Shape(2, 4),
            output_shape=Shape(1, 2),
            implementation="impl",
            attributes={"stride": 2},
        )
        message = 'Detected mismatching input output shapes for sliding_window for node "a". Depth of output and input shape should usually be equal, but found output=Shape(depth=1, width=2) and input=Shape(depth=2, width=4).'
        with pytest.warns(match=re.escape(message)):
            InstanceFactoryForCombinatorial(depth_mismatch).instantiate()