# Source code for elasticai.creator_plugins.combinatorial.vhdl_nodes.sliding_window

import warnings

from elasticai.creator.ir2vhdl import VhdlNode

from .clocked_combinatorial import ClockedInstance
from .node_factory import (
    InstanceFactoryForCombinatorial,
)


@InstanceFactoryForCombinatorial.register
def sliding_window(node):
    """Create the clocked VHDL instance for a ``sliding_window`` node.

    Registered with ``InstanceFactoryForCombinatorial`` through its
    ``register`` decorator. Presumably the factory selects this function by
    its name (``sliding_window``); TODO confirm in ``node_factory``.

    Args:
        node: The IR node describing the sliding window.

    Returns:
        A ``SlidingWindowNode`` built from ``node``.
    """
    window_instance = SlidingWindowNode(node)
    return window_instance
def _check_input_output_shape_compatibility(node):
    """Validate the input/output shapes of a ``sliding_window`` node.

    Args:
        node: Node whose ``name``, ``input_shape`` and ``output_shape`` are
            read. Both shapes must provide a ``size()`` method and a
            ``depth`` attribute.

    Raises:
        ValueError: If the total output size is zero, or if the total input
            size is not an integer multiple of the total output size.

    Warns:
        UserWarning: If the output ``depth`` differs from the input ``depth``.
            This is tolerated rather than rejected, because the depths are
            only "usually" expected to match.
    """
    input_shape = node.input_shape
    output_shape = node.output_shape
    output_size = output_shape.size()
    # Without this guard a zero-sized output would crash the modulo check
    # below with a ZeroDivisionError instead of a descriptive shape error.
    if output_size == 0:
        raise ValueError(
            "Found incompatible input output shapes for sliding_window. Total output size has to be non-zero, but found output={output} and input={input}.".format(
                output=output_shape, input=input_shape
            )
        )
    if input_shape.size() % output_size != 0:
        raise ValueError(
            "Found incompatible input output shapes for sliding_window. Total input size has to be an integer multiple of total output size, but found output={output} and input={input}.".format(
                output=output_shape, input=input_shape
            )
        )
    if output_shape.depth != input_shape.depth:
        warnings.warn(
            'Detected mismatching input output shapes for sliding_window for node "{}". Depth of output and input shape should usually be equal, but found output={} and input={}.'.format(
                node.name, output_shape, input_shape
            ),
            # Skip this helper and its caller (SlidingWindowNode.__init__) so
            # the warning is attributed to the code constructing the node.
            stacklevel=3,
        )
class SlidingWindowNode(ClockedInstance):
    """Clocked VHDL instance implementing a ``sliding_window`` node.

    The instance is configured from the node's input/output shapes and its
    ``stride`` attribute. The resulting ``stride``, ``input_width`` and
    ``output_width`` values are passed to ``ClockedInstance`` as its
    ``generic_map``.
    """

    # Consumed by ClockedInstance; presumably these logic signals receive the
    # default name suffix there (TODO confirm in clocked_combinatorial).
    _logic_signals_with_default_suffix = ("valid_in", "valid_out")

    def __init__(
        self,
        node: VhdlNode,
    ):
        """Validate ``node`` and initialize the clocked instance.

        Args:
            node: The ``sliding_window`` IR node. The stride is read from
                ``node.attributes["generic_map"]["stride"]`` when that entry
                exists, otherwise from ``node.attributes["stride"]``.

        Raises:
            ValueError: If the node's input/output shapes are incompatible
                (raised by ``_check_input_output_shape_compatibility``).
            KeyError: Presumably, when ``node.attributes`` is dict-like and no
                stride is given in either location.
        """
        _check_input_output_shape_compatibility(node)

        input_width = node.input_shape.size()
        output_width = node.output_shape.size()

        attributes = node.attributes  # pyright: ignore
        has_generic_stride = (
            "generic_map" in attributes
            and "stride" in attributes["generic_map"]  # pyright: ignore
        )
        raw_stride = (
            attributes["generic_map"]["stride"]  # pyright: ignore
            if has_generic_stride
            else attributes["stride"]  # pyright: ignore
        )
        # The configured stride is scaled by the output depth before it
        # reaches the VHDL generic. Presumably the hardware steps over
        # flattened (depth * width) data; TODO confirm.
        scaled_stride = raw_stride * node.output_shape.depth

        # NOTE(review): only these three generics are built here; whether
        # ClockedInstance merges further entries from
        # attributes["generic_map"] is not visible from this file.
        super().__init__(
            node,
            input_width=input_width,
            output_width=output_width,
            generic_map={
                "stride": scaled_stride,
                "input_width": input_width,
                "output_width": output_width,
            },
        )