"""Source code for elasticai.creator.graph.vf2.state."""

from collections.abc import Iterator
from typing import Generic, TypeVar

from elasticai.creator.graph.graph import Graph

# T: node type of the graph a State tracks; TP: type of the partner nodes
# those graph nodes get matched to.
T, TP = TypeVar("T"), TypeVar("TP")


class State(Generic[T, TP]):
    """Mutable matching state for one graph in a VF2-style search.

    Every node of ``graph`` gets a dense integer id (its position in
    ``graph.nodes``). Per id the state keeps:

    * ``core``: the partner node it is currently matched to, or ``None``;
    * ``in_nodes`` / ``out_nodes``: the depth at which the node was last
      marked as an "in" / "out" terminal node, or ``0`` when unmarked.

    Because ``0`` doubles as "unmarked", no terminal set exists at depth 0.
    """

    def __init__(self, graph: Graph[T]) -> None:
        self.order: dict[T, int] = {n: i for i, n in enumerate(graph.nodes)}
        self.order_back: dict[int, T] = {v: k for k, v in self.order.items()}
        self.core: list[TP | None] = [None for _ in graph.nodes]
        self.in_nodes: list[int] = [0 for _ in graph.nodes]
        self.out_nodes: list[int] = [0 for _ in graph.nodes]
        self.graph: Graph[T] = graph
        self.current_depth = 0

    # -- private helpers -------------------------------------------------

    def _is_matched(self, node: T) -> bool:
        """True if ``node`` belongs to the graph and is currently matched."""
        node_id = self.order.get(node)
        return node_id is not None and self.core[node_id] is not None

    def _marked_now(self, markers: list[int], node: T) -> bool:
        """True if ``node``'s entry in ``markers`` equals the current depth.

        Depth 0 never counts: a marker of 0 means "unmarked", so treating
        it as membership would put every node into the terminal sets.
        """
        depth = self.current_depth
        return depth != 0 and markers[self.order[node]] == depth

    def _is_unseen(self, node: T) -> bool:
        """True if ``node`` is in the graph, unmatched and in no terminal set."""
        if node not in self.order:
            return False
        if self._is_matched(node):
            return False
        return not (
            self._marked_now(self.in_nodes, node)
            or self._marked_now(self.out_nodes, node)
        )

    # -- pair management -------------------------------------------------

    def remove_pair(self, a: T, b: TP) -> None:
        """Unmatch ``a``. ``b`` is accepted for symmetry with ``add_pair`` but unused.

        At depth 1 the in/out markers of ``a`` are cleared as well; at other
        depths they are left untouched.
        """
        node_id = self.order[a]
        if self.current_depth == 1:
            self.in_nodes[node_id] = 0
            self.out_nodes[node_id] = 0
        self.core[node_id] = None

    def contains_pair(self, a: T, b: TP) -> bool:
        """Return whether ``a`` is currently matched to ``b`` (compared with ``==``)."""
        return self.core[self.order[a]] == b

    def contains_node(self, node: T) -> bool:
        """Return whether ``node`` is currently matched. Raises ``KeyError`` if unknown."""
        return self.core[self.order[node]] is not None

    def add_pair(self, a: T, b: TP) -> None:
        """Match ``a`` to ``b`` and stamp ``a``'s in/out markers with the current depth."""
        node_id = self.order[a]
        self.core[node_id] = b
        self.in_nodes[node_id] = self.current_depth
        self.out_nodes[node_id] = self.current_depth

    # -- queries on the partial match ------------------------------------

    def partial_nodes(self) -> set[T]:
        """Return the set of currently matched nodes."""
        return set(self.iter_nodes())

    def partial_predecessors(self, n: T) -> set[T]:
        """Return the predecessors of ``n`` that are currently matched."""
        # Filter the neighbour list directly instead of intersecting with a
        # freshly rebuilt set of all matched nodes (O(deg) instead of O(V)).
        return {p for p in self.graph.predecessors[n] if self._is_matched(p)}

    def partial_successors(self, n: T) -> set[T]:
        """Return the successors of ``n`` that are currently matched."""
        return {s for s in self.graph.successors[n] if self._is_matched(s)}

    def unseen_nodes(self) -> set[T]:
        """Return nodes that are unmatched and in no terminal set at the current depth."""
        return {node for node in self.order if self._is_unseen(node)}

    def in_node_successors(self, n: T) -> set[T]:
        """Return successors of ``n`` marked as "in" nodes at the current depth."""
        return {s for s in self.graph.successors[n] if self._marked_now(self.in_nodes, s)}

    def in_node_predecessors(self, n: T) -> set[T]:
        """Return predecessors of ``n`` marked as "in" nodes at the current depth."""
        return {
            p for p in self.graph.predecessors[n] if self._marked_now(self.in_nodes, p)
        }

    def out_node_successors(self, n: T) -> set[T]:
        """Return successors of ``n`` marked as "out" nodes at the current depth."""
        return {
            s for s in self.graph.successors[n] if self._marked_now(self.out_nodes, s)
        }

    def out_node_predecessors(self, n: T) -> set[T]:
        """Return predecessors of ``n`` marked as "out" nodes at the current depth."""
        return {
            p for p in self.graph.predecessors[n] if self._marked_now(self.out_nodes, p)
        }

    def unseen_successors(self, n: T) -> set[T]:
        """Return successors of ``n`` that are unmatched and in no terminal set."""
        return {s for s in self.graph.successors[n] if self._is_unseen(s)}

    def unseen_predecessors(self, n: T) -> set[T]:
        """Return predecessors of ``n`` that are unmatched and in no terminal set."""
        return {p for p in self.graph.predecessors[n] if self._is_unseen(p)}

    def iter_nodes(self) -> Iterator[T]:
        """Yield matched nodes in id order."""
        for n in self.order:
            if self.contains_node(n):
                yield n

    def is_complete(self) -> bool:
        """Return whether every node of the graph is matched."""
        return all(partner is not None for partner in self.core)

    def iter_matched_pairs(self) -> Iterator[tuple[T, TP]]:
        """Yield ``(node, partner)`` for every matched node, in ``graph.nodes`` order."""
        for node in self.graph.nodes:
            matched_node = self.core[self.order[node]]
            if matched_node is not None:
                yield node, matched_node

    # -- depth / terminal-set bookkeeping --------------------------------

    def restore(self) -> None:
        """Reset all changes introduced in the current depth.

        This is responsible for backtracking: every in/out marker equal to the
        current depth is cleared to ``0`` and the depth is decremented.
        ``core`` is not touched here; use ``remove_pair`` to undo a match.
        """
        # NOTE(review): update_* overwrites older markers with the current
        # depth, so after backtracking such nodes end up at 0 rather than at
        # their previous depth. Presumably callers re-run update_inout_nodes;
        # verify against the matcher.
        depth = self.current_depth
        for node_id in range(len(self.core)):
            if self.in_nodes[node_id] == depth:
                self.in_nodes[node_id] = 0
            if self.out_nodes[node_id] == depth:
                self.out_nodes[node_id] = 0
        self.current_depth -= 1

    def update_in_nodes(self) -> None:
        """Mark unmatched nodes that have an edge into the current match.

        Every unmatched node with at least one matched successor gets its
        ``in_nodes`` marker set to the current depth.
        """
        depth = self.current_depth
        for node, node_id in self.order.items():
            if self.contains_node(node):
                continue
            if any(self.contains_node(s) for s in self.graph.successors[node]):
                self.in_nodes[node_id] = depth

    def update_out_nodes(self) -> None:
        """Mark unmatched nodes that are reachable by an edge from the current match.

        Every unmatched node with at least one matched predecessor gets its
        ``out_nodes`` marker set to the current depth.
        """
        depth = self.current_depth
        for node, node_id in self.order.items():
            if self.contains_node(node):
                continue
            if any(self.contains_node(p) for p in self.graph.predecessors[node]):
                self.out_nodes[node_id] = depth

    def next_depth(self) -> None:
        """Descend one level and recompute the in/out markers at the new depth."""
        self.current_depth += 1
        self.update_inout_nodes()

    def update_inout_nodes(self) -> None:
        """Recompute both the in and out markers for the current depth."""
        self.update_in_nodes()
        self.update_out_nodes()

    def __iter_nodes(self, depth: int, in_out: list[int]) -> Iterator[T]:
        # A marker of 0 means "unmarked", so no node is a terminal node at depth 0.
        if depth == 0:
            return
        for node, node_id in self.order.items():
            if in_out[node_id] >= depth:
                yield node

    def iter_in_nodes(self) -> Iterator[T]:
        """Yield nodes whose "in" marker is at least the current depth."""
        yield from self.__iter_nodes(self.current_depth, self.in_nodes)

    def iter_out_nodes(self) -> Iterator[T]:
        """Yield nodes whose "out" marker is at least the current depth."""
        yield from self.__iter_nodes(self.current_depth, self.out_nodes)