# Source code for elasticai.creator_plugins.combinatorial.tests.shift_register_test
from collections.abc import Iterable
from elasticai.creator.ir2vhdl import Shape, factory
from elasticai.creator_plugins.combinatorial.vhdl_nodes import shift_register
def test_shift_register_converts_depth_and_width_to_correct_generics():
    """Check the generic map emitted when instantiating a shift register node.

    A ``shift_register`` node is built with input shape
    ``Shape(conv0_channels, 1)`` and output shape
    ``Shape(conv0_channels, conv1_kernel_size)``. The generic-map section of
    its instantiation code must contain exactly two assignments:
    ``DATA_WIDTH`` set to the channel count, and ``NUM_POINTS`` set to the
    output shape's width.
    """
    conv0_channels = 2
    conv1_kernel_size = 3
    conv0_out_shape = Shape(conv0_channels, 1)
    conv1_in_shape = Shape(conv0_channels, conv1_kernel_size)
    n = factory.node(
        "sr0",
        type="shift_register",
        implementation="",
        input_shape=conv0_out_shape,
        output_shape=conv1_in_shape,
    )
    sr = shift_register(n)
    # NOTE(review): ``extract_code_section`` is not defined or imported in the
    # visible part of this module; presumably it is defined elsewhere in this
    # file -- confirm.
    section = extract_code_section(sr.instantiate(), start="generic", end=")")
    # Strip surrounding whitespace as well as leading/trailing commas, so the
    # comparison does not depend on how the generic map is indented or
    # comma-separated.
    generics = [line.strip().strip(",").strip() for line in section]
    expected = [
        f"DATA_WIDTH => {conv0_channels}",
        f"NUM_POINTS => {conv1_in_shape.width}",
    ]
    # Compare sorted lists rather than sets: order is irrelevant, but a
    # duplicated generic assignment must still make the test fail.
    assert sorted(generics) == sorted(expected)