from collections.abc import Callable, Iterable, Iterator, Mapping
from typing import (
Any,
Generic,
MutableMapping,
ParamSpec,
Protocol,
Self,
TypeVar,
cast,
overload,
)
from elasticai.creator.graph import Graph
from elasticai.creator.ir.base import Attribute, IrData, read_only_field
from .core import Edge, Node
# Node type produced by a `NodeFn`; bound to `Node` and covariant.
N = TypeVar("N", bound=Node, covariant=True)
# Edge type produced by an `EdgeFn`; bound to `Edge` and covariant.
E = TypeVar("E", bound=Edge, covariant=True)
# NOTE(review): `StoredT`, `VisibleT` and `P` are not referenced in the visible
# part of this module -- confirm whether other code imports them before removing.
StoredT = TypeVar("StoredT", bound=Attribute)
VisibleT = TypeVar("VisibleT")
P = ParamSpec("P")
class NodeFn(Protocol[N]):
    """Structural type of a node factory.

    Any callable accepting a node name and that node's attribute dictionary
    and returning an object of type `N` satisfies this protocol.
    """

    def __call__(self, name: str, data: dict[str, Attribute]) -> N:
        """Build the node called `name` from its attribute dictionary `data`."""
class EdgeFn(Protocol[E]):
    """Structural type of an edge factory.

    Any callable accepting the source name, the destination name and the
    edge's attribute dictionary and returning an object of type `E`
    satisfies this protocol.
    """

    def __call__(self, src: str, dst: str, data: dict[str, Attribute]) -> E:
        """Build the edge `src -> dst` from its attribute dictionary `data`."""
class Implementation(IrData, Generic[N, E], create_init=False):
    """Attribute data combined with a graph, exposed as typed nodes and edges.

    `graph` holds the structure (node names and the edges between them).
    `data["nodes"]` maps node names to attribute dictionaries and
    `data["edges"]` maps ``src -> dst -> attributes``. Nodes and edges are
    materialised on access by calling `node_fn` / `edge_fn` with the stored
    attributes.
    """

    __slots__ = (
        "data",
        "graph",
        "_node_fn",
        "_edge_fn",
        "_nodes",
        "_edges",
    )
    name: str
    type: str

    @overload
    def __init__(
        self: "Implementation[N, E]",
        *,
        graph: Graph[str],
        edge_fn: EdgeFn[E],
        node_fn: NodeFn[N],
        data: dict[str, Attribute] | None = None,
    ) -> None: ...

    @overload
    def __init__(
        self: "Implementation[N, Edge]",
        *,
        graph: Graph[str],
        node_fn: NodeFn[N],
        data: dict[str, Attribute] | None = None,
    ) -> None: ...

    @overload
    def __init__(
        self: "Implementation[Node, E]",
        *,
        graph: Graph[str],
        edge_fn: EdgeFn[E],
        data: dict[str, Attribute] | None = None,
    ) -> None: ...

    @overload
    def __init__(
        self: "Implementation[Node, Edge]",
        *,
        graph: Graph[str],
        data: dict[str, Attribute] | None = None,
    ) -> None: ...

    def __init__(
        self,
        *,
        graph: Graph[str],
        node_fn: NodeFn = Node,
        edge_fn: EdgeFn = Edge,
        data: dict[str, Attribute] | None = None,
    ) -> None:
        """Create a new Implementation. Nodes and edges from `data` will not be automatically added to the graph.

        The constructor will combine `graph` and `data` so we can access nodes and edges of type `N` and `E` respectively,
        and use these to manipulate the underlying graph structure.
        It is the caller's responsibility to keep `graph` and `data` in sync.

        This might seem like a limitation at first, but in fact it is a feature.
        It allows us to combine existing graph structures with new underlying data.
        A common use case for this is to handle subgraphs as in the following example:

        ```python
        subgraph = Graph()
        for node in original_impl.nodes.values():
            if fulfills_constraint(node):
                subgraph.add_node(node.name)
        for src, dst in original_impl.edges:
            if fulfills_constraint(original_impl.nodes[src]) and fulfills_constraint(original_impl.nodes[dst]):
                subgraph.add_edge(src, dst)

        new_impl = Implementation(graph=subgraph, data=original_impl.data)
        for node in new_impl.nodes.values():  # now we can iterate over the nodes of the subgraph
            do_something(node)
        ```

        If `data` is given, missing ``"nodes"`` / ``"edges"`` entries are
        added to it in place.
        """
        if data is None:
            data = {}
        # Both containers must exist before the base class stores `data`.
        data.setdefault("nodes", {})
        data.setdefault("edges", {})
        super().__init__(data)
        self.graph = graph
        self._node_fn = node_fn
        self._edge_fn = edge_fn
        self._cache_data_views()

    def _cache_data_views(self) -> None:
        """Ensure `data` has "nodes"/"edges" entries and cache references to them.

        Must be called whenever `self.data` is rebound, otherwise `_nodes` and
        `_edges` keep pointing at the previous dictionaries.
        """
        self.data.setdefault("nodes", {})
        self.data.setdefault("edges", {})
        self._nodes: dict[str, Attribute] = cast(
            dict[str, Attribute], self.data["nodes"]
        )
        self._edges: dict[str, dict[str, Attribute]] = cast(
            dict[str, dict[str, Attribute]], self.data["edges"]
        )

    @property
    def _node_data(self) -> dict[str, Attribute]:
        """The ``name -> attributes`` dictionary currently stored in `data`."""
        return cast(dict[str, Attribute], self.data["nodes"])

    @property
    def _edge_data(self) -> MutableMapping[tuple[str, str], Attribute]:
        """The edge attributes of `data`, addressed by ``(src, dst)`` tuples."""
        return _NestedDictToTupleKeyAdapter(
            cast(dict[str, dict[str, Attribute]], self.data["edges"])
        )

    @overload
    def add_node(self, n: Node) -> Self: ...

    @overload
    def add_node(self, *, name: str, data: dict[str, Attribute] = ...) -> Self: ...

    def add_node(self, *args, **kwargs) -> Self:
        """Add a node to the graph and store its attributes.

        Accepts either a single node object (positionally, or as keyword `n`
        or `node`) whose `name` and `data` are used, or a `name` together
        with an optional `data` dictionary (defaults to an empty dict).

        Raises:
            TypeError: if the arguments match neither form.
        """
        if len(args) + len(kwargs) == 1:
            if args:
                node = args[0]
                return self._add_node(node.name, node.data)
            # `n` is the name in the overload, `node` the name historically
            # accepted by the implementation; honour both.
            for alias in ("n", "node"):
                if alias in kwargs:
                    node = kwargs[alias]
                    return self._add_node(node.name, node.data)
        bound = _bind_args(["name", "data"], {"data": {}}, *args, **kwargs)
        return self._add_node(bound["name"], bound["data"])

    def _add_node(self, name: str, data: dict[str, Attribute]) -> Self:
        """Register `name` in the graph and store `data` (by reference) for it."""
        self.graph.add_node(name)
        self._node_data[name] = data
        return self

    @overload
    def add_edge(self, e: Edge) -> Self: ...

    @overload
    def add_edge(
        self, src: str, dst: str, data: dict[str, Attribute] = ...
    ) -> Self: ...

    def add_edge(self, *args, **kwargs) -> Self:
        """Add an edge to the graph and store its attributes.

        Accepts either a single edge object (positionally, or as keyword `e`
        or `edge`) whose `src`, `dst` and `data` are used, or `src` and `dst`
        together with an optional `data` dictionary (defaults to an empty dict).

        Raises:
            TypeError: if the arguments match neither form.
        """
        if len(args) + len(kwargs) == 1:
            if args:
                edge = args[0]
                return self.add_edge(edge.src, edge.dst, edge.data)
            # `e` is the name in the overload, `edge` the name historically
            # accepted by the implementation; honour both.
            for alias in ("e", "edge"):
                if alias in kwargs:
                    edge = kwargs[alias]
                    return self.add_edge(edge.src, edge.dst, edge.data)
        bound = _bind_args(["src", "dst", "data"], {"data": {}}, *args, **kwargs)
        src, dst = bound["src"], bound["dst"]
        self.graph.add_edge(src, dst)
        self._edge_data[(src, dst)] = bound["data"]
        return self

    def add_nodes(self, ns: Iterable[Node]) -> Self:
        """Add every node in `ns` via `add_node`."""
        for n in ns:
            self.add_node(n)
        return self

    def add_edges(self, es: Iterable[Edge]) -> Self:
        """Add every edge in `es` via `add_edge`."""
        for e in es:
            self.add_edge(e)
        return self

    def successors(self, node: str | Node) -> Mapping[str, N]:
        """Return a read-only mapping of the successors of `node`.

        `node` may be a node name or a node object. The graph is queried each
        time the returned mapping is iterated.
        """
        name = node if isinstance(node, str) else node.name

        def _successors():
            return self.graph.successors[name]

        return self.get_node_mapping(_successors)

    def predecessors(self, node: str | Node) -> Mapping[str, N]:
        """Return a read-only mapping of the predecessors of `node`.

        `node` may be a node name or a node object. The graph is queried each
        time the returned mapping is iterated.
        """
        name = node if isinstance(node, str) else node.name

        def _predecessors():
            return self.graph.predecessors[name]

        return self.get_node_mapping(_predecessors)

    def get_node_mapping(self, keys: Callable[[], Iterable[str]]) -> Mapping[str, N]:
        """Create a read-only mapping of names to Nodes in the order of `keys()`."""
        return _ReadOnlyMappingInOrderAsIterable(keys, self._node_data, self._node_fn)

    def get_edge_mapping(
        self, keys: Callable[[], Iterable[tuple[str, str]]]
    ) -> Mapping[tuple[str, str], E]:
        """Create a read-only mapping of (src, dst) pairs to Edges in the order of `keys()`."""
        return _ReadOnlyMappingInOrderAsIterable(
            keys, self._edge_data, self._construct_edge
        )

    def _construct_edge(
        self, src_dst: tuple[str, str], data: dict[str, Attribute]
    ) -> E:
        """Adapt the ``(key, data)`` constructor call to the `edge_fn` signature."""
        src, dst = src_dst
        return self._edge_fn(src, dst, data)

    @read_only_field
    def nodes(self, _: dict[str, Attribute]) -> Mapping[str, N]:
        return self.get_node_mapping(lambda: self.graph.nodes)

    @read_only_field
    def edges(self, _: dict[str, Attribute]) -> Mapping[tuple[str, str], E]:
        return self.get_edge_mapping(self.graph.iter_edges)

    def as_dict(self) -> dict[str, Attribute]:
        """Return a shallow copy of `data` with a freshly built "edges" entry.

        The nested ``src -> dst -> attributes`` dictionaries are rebuilt, so
        sources without any edge do not appear; the attribute dictionaries
        themselves are shared with this object.
        """
        data = self.data.copy()
        edges: dict[str, dict[str, Attribute]] = {}
        for (src, dst), attributes in self._edge_data.items():
            edges.setdefault(src, {})[dst] = attributes
        data["edges"] = cast(Attribute, edges)
        return data

    def sync_data_with_graph(self) -> Self:
        """Removes nodes/edges from data that are not in the graph, add empty fields for new nodes/edges."""
        # Collect first, delete afterwards: a dict must not shrink while iterated.
        stale_nodes = [n for n in self._nodes if n not in self.graph.nodes]
        for n in stale_nodes:
            del self._nodes[n]
        for n in self.graph.nodes:
            if n not in self._nodes:
                self._nodes[n] = {}

        graph_edges = set(self.graph.iter_edges())
        stale_edges = [
            (src, dst)
            for src, targets in self._edges.items()
            for dst in targets
            if (src, dst) not in graph_edges
        ]
        for src, dst in stale_edges:
            del self._edges[src][dst]
            if not self._edges[src]:
                del self._edges[src]
        # Walk the graph (not the set) so new entries appear in graph order.
        for src, dst in self.graph.iter_edges():
            self._edges.setdefault(src, {}).setdefault(dst, {})
        return self

    def load_from_dict(
        self,
        d: dict[str, Any],
    ) -> "Implementation[N, E]":
        """Load attributes, nodes and edges from a dictionary, that was created by `as_dict`.

        Opposed to the constructor, this will add all nodes and edges to the underlying graph.
        A missing ``"nodes"`` or ``"edges"`` entry is treated as empty.

        :::{important}
        This will override all state of the current object.
        :::
        """
        self.data = d.copy()
        # Rebind the cached views; otherwise `sync_data_with_graph` would keep
        # operating on the dictionaries of the data that was just replaced.
        self._cache_data_views()
        self.graph = self.graph.new()
        for n in self._node_data:
            self.graph.add_node(n)
        for src, dst in self._edge_data:
            self.graph.add_edge(src, dst)
        return self
# Type variables for the private mapping helpers defined below.
# `_K`: key type (first tuple component in `_NestedDictToTupleKeyAdapter`).
_K = TypeVar("_K")
# `_V`: value type.
_V = TypeVar("_V")
# `_K2`: second tuple component in `_NestedDictToTupleKeyAdapter`.
_K2 = TypeVar("_K2")
class _NestedDictToTupleKeyAdapter(MutableMapping[tuple[_K, _K2], _V]):
def __init__(self, wrapped: dict[_K, dict[_K2, _V]]):
self.wrapped = wrapped
def __getitem__(self, k: tuple[_K, _K2]) -> _V:
return self.wrapped[k[0]][k[1]]
def __len__(self) -> int:
acc = 0
for v in self.wrapped.values():
acc += len(v)
return acc
def __contains__(self, k: object) -> bool:
if not isinstance(k, tuple) or len(k) != 2:
return False
return k[0] in self.wrapped and k[1] in self.wrapped[k[0]]
def __delitem__(self, key):
self.wrapped(key[0]).__delitem__(key[1])
def __iter__(self) -> Iterator[tuple[_K, _K2]]:
for k0 in self.wrapped:
for k1 in self.wrapped[k0]:
yield k0, k1
def __setitem__(self, k: tuple[_K, _K2], v: _V) -> None:
if k[0] not in self.wrapped:
self.wrapped[k[0]] = {}
self.wrapped[k[0]][k[1]] = v
class _ReadOnlyMappingInOrderAsIterable(Mapping[_K, _V]):
def __init__(
self,
iterable: Callable[[], Iterable[_K]],
d: Mapping[_K, Any],
value_constructor: Callable[[_K, Any], _V],
):
self._iterable = iterable
self._d = d
self._value_constructor = value_constructor
def __iter__(self) -> Iterator[_K]:
yield from self._iterable()
def __len__(self) -> int:
return len(self._d)
def __contains__(self, k: object) -> bool:
return k in self._d
def __getitem__(self, k: _K) -> _V:
return self._value_constructor(k, self._d[k])
def __eq__(self, other: object) -> bool:
if not isinstance(other, Mapping):
return NotImplemented
return dict(self) == dict(other)
def __repr__(self) -> str:
return f"{self.__class__.__name__}({dict(self)})"
def _bind_args(keywords: list[str], optionals: dict[str, Any], *args, **kwargs) -> dict:
bound = dict()
keys = keywords.copy()
opts = optionals.copy()
for key, arg in zip(keys, args):
bound[key] = arg
for key, value in kwargs.items():
if key in bound:
raise TypeError("failed to bind arguments: multiple values for argument")
bound[key] = value
for key, value in opts.items():
if key not in bound:
bound[key] = value
missing_args = set()
for k in keywords:
if k not in bound:
missing_args.add(k)
if len(missing_args) > 0:
raise TypeError(f"failed to bind arguments: missing argument {missing_args}")
return bound