Source code for elasticai.creator.ir.rewriting.rewriter

from abc import abstractmethod
from collections.abc import Callable, Collection, Hashable, Iterable, Iterator, Mapping
from typing import Any, Protocol, Self

from elasticai.creator.graph import (
    Graph,
    find_all_subgraphs,
    get_rewriteable_matches,
    rewrite,
)


class Node(Protocol):
    """Structural interface of a graph node: a name plus a data dictionary."""

    @property
    @abstractmethod
    def name(self) -> str:
        """Name identifying this node inside its graph."""

    @property
    @abstractmethod
    def data(self) -> dict[str, Any]:
        """Attribute dictionary attached to this node."""
class Edge(Protocol):
    """Structural interface of a directed edge ``src -> dst`` with a data dictionary."""

    @property
    @abstractmethod
    def src(self) -> str:
        """Name of the node the edge starts at."""

    @property
    @abstractmethod
    def dst(self) -> str:
        """Name of the node the edge points to."""

    @property
    @abstractmethod
    def data(self) -> dict[str, Any]:
        """Attribute dictionary attached to this edge."""
[docs] class ReadOnlyDataGraph[N: Node, E: Edge](Protocol): @property def graph(self) -> Graph[str]: ... @property @abstractmethod def nodes(self) -> Mapping[str, N]: ... @property @abstractmethod def edges(self) -> Mapping[tuple[str, str], E]: ...
[docs] @abstractmethod def successors(self, node: str) -> Mapping[str, N]: ...
[docs] @abstractmethod def predecessors(self, node: str) -> Mapping[str, N]: ...
@property @abstractmethod def data(self) -> dict[str, Any]: ...
[docs] class DataGraph[N: Node, E: Edge](ReadOnlyDataGraph[N, E], Protocol):
[docs] @abstractmethod def add_node(self, node: N) -> Self: ...
[docs] @abstractmethod def add_edge(self, edge: E) -> Self: ...
@property @abstractmethod def data(self) -> dict[str, Any]: ... @data.setter @abstractmethod def data(self, data: dict[str, Any]) -> None: ...
[docs] class RemappedSubImplementation[N: Node, E: Edge](ReadOnlyDataGraph[N, E]): """Use a given mapping to make the data dictionary accessible via node names from graph. This allows to access the data dictionary using the names from the graph. The provided mapping is used to map the node names of the graph to the node names in the data dictionary. This is used e.g., to initialize new nodes, that were replaced during rewriting. The mapping is assumed to be one-to-one, i.e. each node name in the graph maps to exactly one node name in the data dictionary. """ def __init__( self, mapping: dict[str, str], graph: Graph[str], data: dict[str, Any], node_fn: Callable[[str, dict[str, Any]], N], edge_fn: Callable[[str, str, dict[str, Any]], E], ): self._mapping = mapping self._inverted = dict(((v, k) for k, v in mapping.items())) self._data = data self._graph = graph self._node_constr = node_fn self._edge_constr = lambda src_dst, d: edge_fn(src_dst[0], src_dst[1], d) @property def graph(self) -> Graph[str]: return self._graph @property def data(self) -> dict[str, Any]: return self._data def _original_name(self, node: str): return self._mapping[node] def _mapped_name(self, node: str): return self._inverted[node] @property def nodes(self) -> Mapping[str, N]: return self._create_remapped_nodes(self._graph.nodes) @property def edges(self) -> Mapping[tuple[str, str], E]: remapping = { (self._original_name(src), self._original_name(dst)): (src, dst) for src, dst in self._graph.iter_edges() } return _RemappedData( remapping=remapping, data=self._data["edges"], iterable=self._graph.iter_edges(), factory_fn=self._edge_constr, )
[docs] def successors(self, node: str) -> Mapping[str, N]: return self._create_remapped_nodes( self._graph.successors[self._inverted[node]], )
def _create_remapped_nodes(self, iterable: Iterable[str]) -> Mapping[str, N]: return _RemappedData( self._mapping, self._data, iterable, factory_fn=self._node_constr )
[docs] def predecessors(self, node: str) -> Mapping[str, N]: return self._create_remapped_nodes( self._graph.predecessors[self._inverted[node]] )
class _RemappedData[K: Hashable, D](Mapping[K, D]): def __init__( self, remapping: dict[K, K], data: dict[K, Any], iterable: Iterable[K], factory_fn: Callable[[K, dict[str, Any]], D], ) -> None: self._remapping = remapping self._data = data self._iter = iterable self._create_item = factory_fn def __getitem__(self, key) -> D: return self._create_item(key, self._data[self._remapping[key]]) def __len__(self) -> int: return len(self._remapping) def __iter__(self) -> Iterator[Any]: for key in iter(self._iter): if key in self: yield key def __contains__(self, key: Any) -> bool: return key in self._remapping type Rule = Callable[[DataGraph], DataGraph]
class CompositeRule:
    """Chain several rules into one.

    Calling the composite applies the rules in the given order. Each rule
    receives the graph produced by the previous one.
    """

    def __init__(self, rules: Iterable[Rule]) -> None:
        self._rules = [*rules]

    def __call__(self, impl: DataGraph) -> DataGraph:
        result = impl
        for apply_rule in self._rules:
            result = apply_rule(result)
        return result
class Pattern(Protocol):
    """Something that can be located inside a `ReadOnlyDataGraph`."""

    @property
    @abstractmethod
    def graph(self) -> DataGraph:
        """DataGraph describing the pattern's structure and node data."""

    @property
    def interface(self) -> Collection[str]:
        """Pattern nodes that are kept and connect pattern and replacement."""

    @abstractmethod
    def match(self, g: ReadOnlyDataGraph) -> list[dict[str, str]]:
        """Return one name mapping per match of the pattern found in ``g``.

        `PatternRule` uses each mapping to translate pattern node names into
        node names of ``g``.
        """
class StdPattern(Pattern):
    """Pattern defined by a DataGraph, a node predicate and interface nodes.

    Args:
        graph: the pattern graph.
        node_constraint: ``(pattern_node, graph_node) -> bool``; decides whether
            a graph node may stand in for a pattern node.
        interface: names of the pattern nodes that act as the interface.
    """

    def __init__(
        self,
        graph: DataGraph,
        node_constraint: Callable[[Node, Node], bool],
        interface: Collection[str],
    ):
        self._pattern = graph
        self._node_constraint = node_constraint
        self._interface = interface

    @property
    def graph(self) -> DataGraph:
        return self._pattern

    @property
    def interface(self) -> Collection[str]:
        return self._interface

    def match(self, g: ReadOnlyDataGraph) -> list[dict[str, str]]:
        """Return every subgraph match in ``g`` that can be safely rewritten."""

        def accepts(p_node, g_node):
            # Compare node objects (with their data), not just names.
            return self._node_constraint(self._pattern.nodes[p_node], g.nodes[g_node])

        candidates = find_all_subgraphs(
            graph=g.graph,
            pattern=self.graph.graph,
            node_constraint=accepts,
        )
        rewriteable = get_rewriteable_matches(
            matches=candidates,
            original=g.graph,
            interface_nodes=set(self._interface),
        )
        return list(rewriteable)
[docs] class IrFactory[N: Node, E: Edge, G: DataGraph](Protocol):
[docs] @abstractmethod def node(self, name: str, data: Mapping[str, Any]) -> N: ...
[docs] @abstractmethod def edge(self, src: str, dst: str, data: Mapping[str, Any]) -> E: ...
[docs] @abstractmethod def data_graph(self) -> G: ...
[docs] class PatternRule[N: Node, E: Edge, G: DataGraph]: def __init__(self, spec: "PatternRuleSpec", ir_factory: IrFactory[N, E, G]): self._spec = spec self._ir_factory = ir_factory
[docs] def __call__(self, graph: DataGraph) -> DataGraph: return self.apply(graph)
[docs] def apply(self, graph: ReadOnlyDataGraph) -> DataGraph: self._validate_pattern(self._spec.pattern) matches = list(self._spec.pattern.match(graph)) matched_subgraphs = [ RemappedSubImplementation[N, E]( mapping=match, graph=self._spec.pattern.graph.graph, data=graph.data["nodes"], node_fn=self._ir_factory.node, edge_fn=self._ir_factory.edge, ) for match in matches ] replacements = [ self._spec.create_replacement(matched) for matched in matched_subgraphs ] self._validate_replacements(replacements) new_graph, replacement_maps = self._rewrite_raw_graph_for_matches( graph, matches, replacements ) return self._build_data_graph( new_raw_graph=new_graph, original=graph, replacement_map=replacement_maps, )
def _build_data_graph( self, new_raw_graph: Graph[str], original: ReadOnlyDataGraph, replacement_map: dict[str, tuple[str, ReadOnlyDataGraph]], ) -> DataGraph: impl = self._ir_factory.data_graph() combined_map = replacement_map | { name: (name, original) for name in original.nodes if name in new_raw_graph.nodes and name } def copy_node(name): old_name, graph = combined_map[name] impl.add_node( self._ir_factory.node(name=name, data=graph.nodes[old_name].data) ) def edge_lies_within_replacement(src_dst): src, dst = src_dst if src in replacement_map and dst in replacement_map: return True return False def copy_edge(src_dst): src, dst = src_dst if edge_lies_within_replacement(src_dst): old_src_name, graph = replacement_map[src] old_dst_name, _ = replacement_map[dst] else: old_src_name, graph = src, original old_dst_name, _ = dst, original impl.add_edge( self._ir_factory.edge( src=src, dst=dst, data=graph.edges[(old_src_name, old_dst_name)].data, ) ) for node_name in new_raw_graph.nodes: copy_node(node_name) for src, dst in new_raw_graph.iter_edges(): copy_edge((src, dst)) return impl def _validate_replacements(self, replacements: list[ReadOnlyDataGraph]) -> None: for replacement in replacements: missing_interface_nodes = set(self._spec.interface) - set( replacement.nodes.keys() ) if missing_interface_nodes: raise ValueError( f"Replacement is missing interface nodes: {missing_interface_nodes}" ) def _validate_pattern(self, pattern: Pattern) -> None: missing_interface_nodes = set(self._spec.interface) - set( pattern.graph.nodes.keys() ) if missing_interface_nodes: raise ValueError( f"Pattern Graph is missing interface nodes: {missing_interface_nodes}" ) def _rewrite_raw_graph_for_matches( self, graph: ReadOnlyDataGraph, matches: list[dict[str, str]], replacements: list[ReadOnlyDataGraph], ) -> tuple[Graph[str], dict[str, tuple[str, ReadOnlyDataGraph]]]: rewritten = graph.graph full_replacement_map = {} for match, replacement in zip(matches, replacements): rewritten, 
replacement_map = rewrite( # type: ignore replacement=replacement.graph, # type: ignore original=rewritten, # type: ignore match=match, lhs={x: x for x in self._spec.interface}, rhs={x: x for x in self._spec.interface}, ) for replacement_node_name, new_node_name in replacement_map.items(): full_replacement_map[new_node_name] = ( replacement_node_name, replacement, ) return rewritten, full_replacement_map
class PatternRuleSpec:
    """Describe how a `PatternRule` derives a new `DataGraph` from an existing one.

    A spec bundles three things:

    - ``pattern``: what to look for in the original DataGraph.
    - ``replacement_fn`` (stored as ``create_replacement``): called once per
      match to build the DataGraph that takes the matched part's place. It
      receives the matched part of the original DataGraph, so the replacement
      can be built from the parameters found there. Node names in that view
      are remapped from the original graph to the pattern graph, so data is
      accessed by pattern names. E.g., if the pattern specifies a node
      ``'conv'``, you can access the data of the original node it corresponds
      to via ``matched.nodes['conv'].data``.
    - ``interface`` (taken from the pattern): nodes that belong to both the
      pattern and the replacement. Interface nodes are neither replaced nor
      initialized. They connect the pattern and the replacement.
    """

    def __init__(
        self,
        pattern: Pattern,
        replacement_fn: Callable[[ReadOnlyDataGraph], ReadOnlyDataGraph],
    ):
        # Public attributes read by PatternRule.
        self.pattern = pattern
        self.create_replacement = replacement_fn

    @property
    def interface(self) -> Collection[str]:
        """Interface nodes, delegated to the pattern."""
        return self.pattern.interface