Source code for tg_model.execution.dependency_graph

"""Bipartite dependency graph for one configured model (:class:`DependencyGraph`).

Edges point **from dependency → dependent** (inputs must resolve before consumers).
:class:`~tg_model.execution.evaluator.Evaluator` consumes :meth:`DependencyGraph.topological_order`.
"""

from __future__ import annotations

import heapq
from collections.abc import Iterable
from enum import Enum
from typing import Any


class NodeKind(Enum):
    """Classify a dependency-graph node as holding a value or performing a computation.

    Each member's value is the lower-case form of its name. The grouping
    below mirrors the tuples checked by ``DependencyNode.is_value_node`` and
    ``DependencyNode.is_compute_node``.
    """

    # Kinds treated as value nodes.
    INPUT_PARAMETER = "input_parameter"
    ATTRIBUTE_VALUE = "attribute_value"
    CONSTRAINT_RESULT = "constraint_result"

    # Kinds treated as compute nodes.
    LOCAL_EXPRESSION = "local_expression"
    ROLLUP_COMPUTATION = "rollup_computation"
    SOLVE_GROUP = "solve_group"
    CONSTRAINT_CHECK = "constraint_check"
    EXTERNAL_COMPUTATION = "external_computation"
class DependencyNode:
    """One vertex of a :class:`DependencyGraph`, tagged as a value or a compute kind.

    Parameters
    ----------
    node_id
        Identifier of the node.
    kind
        Classification of the node.
    slot_id
        Optional slot identifier; stored as given (``None`` by default).
    metadata
        Optional mapping of extra data. ``None`` or any other falsy value
        (for example an empty dict) is replaced by a fresh empty dict; a
        non-empty mapping is stored by reference, not copied.
    """

    __slots__ = ("kind", "metadata", "node_id", "slot_id")

    def __init__(
        self,
        node_id: str,
        kind: NodeKind,
        slot_id: str | None = None,
        metadata: dict[str, Any] | None = None,
    ) -> None:
        self.node_id, self.kind, self.slot_id = node_id, kind, slot_id
        self.metadata = metadata if metadata else {}

    @property
    def is_value_node(self) -> bool:
        """Whether ``kind`` is one of the three value kinds."""
        value_kinds = (
            NodeKind.INPUT_PARAMETER,
            NodeKind.ATTRIBUTE_VALUE,
            NodeKind.CONSTRAINT_RESULT,
        )
        return self.kind in value_kinds

    @property
    def is_compute_node(self) -> bool:
        """Whether ``kind`` is one of the five compute kinds."""
        compute_kinds = (
            NodeKind.LOCAL_EXPRESSION,
            NodeKind.ROLLUP_COMPUTATION,
            NodeKind.SOLVE_GROUP,
            NodeKind.CONSTRAINT_CHECK,
            NodeKind.EXTERNAL_COMPUTATION,
        )
        return self.kind in compute_kinds

    def __repr__(self) -> str:
        """Return a short debug form showing the node id and the kind's value."""
        return f"<DependencyNode: {self.node_id} ({self.kind.value})>"
class DependencyGraph:
    """Directed graph of value and compute nodes for one compile of a configured model.

    Notes
    -----
    ``add_edge(from_id, to_id)`` means ``from_id`` must be satisfied before ``to_id``.

    Edges may be recorded before their endpoints are registered with
    :meth:`add_node`. :meth:`topological_order` refuses to schedule a graph
    that still has edges whose endpoints were never registered.
    """

    def __init__(self) -> None:
        # Registered nodes keyed by id (insertion-ordered).
        self._nodes: dict[str, DependencyNode] = {}
        # Every edge as (from_id, to_id) in insertion order; duplicates are kept.
        self._edges: list[tuple[str, str]] = []
        # Adjacency lists: id -> ids that depend on it, and id -> ids it depends on.
        self._dependents: dict[str, list[str]] = {}
        self._dependencies: dict[str, list[str]] = {}

    def add_node(self, node: DependencyNode) -> None:
        """Register ``node`` under ``node.node_id``.

        A node already stored under the same id is replaced. Adjacency
        recorded by earlier :meth:`add_edge` calls for that id is kept.
        """
        self._nodes[node.node_id] = node
        self._dependents.setdefault(node.node_id, [])
        self._dependencies.setdefault(node.node_id, [])

    def add_edge(self, from_id: str, to_id: str) -> None:
        """Add dependency edge (``from_id`` before ``to_id``).

        Neither endpoint is checked against the registered nodes here, and
        adding the same edge twice records it twice.
        """
        self._edges.append((from_id, to_id))
        self._dependents.setdefault(from_id, []).append(to_id)
        self._dependencies.setdefault(to_id, []).append(from_id)

    def get_node(self, node_id: str) -> DependencyNode:
        """Return node by id.

        Raises
        ------
        KeyError
            If unknown.
        """
        return self._nodes[node_id]

    @property
    def nodes(self) -> dict[str, DependencyNode]:
        """Shallow copy of the id -> node mapping."""
        return dict(self._nodes)

    @property
    def edges(self) -> list[tuple[str, str]]:
        """Copy of the ``(from_id, to_id)`` edge list, in insertion order."""
        return list(self._edges)

    def dependencies_of(self, node_id: str) -> list[str]:
        """Return a copy of the direct upstream ids of ``node_id`` (empty if none recorded)."""
        return list(self._dependencies.get(node_id, ()))

    def dependents_of(self, node_id: str) -> list[str]:
        """Return a copy of the direct downstream ids of ``node_id`` (empty if none recorded)."""
        return list(self._dependents.get(node_id, ()))

    def topological_order(self) -> list[str]:
        """Deterministic topological sort (smallest ready id is always scheduled first).

        Returns
        -------
        list[str]
            Node ids in evaluation order.

        Raises
        ------
        ValueError
            If an edge references an id that was never registered with
            :meth:`add_node`, or if a cycle remains (not all nodes scheduled).
        """
        # Unregistered endpoints must be rejected up front. If they took part
        # in the in-degree bookkeeping they could be scheduled themselves and
        # make the length-based cycle check below pass for a cyclic graph.
        dangling = {
            nid for edge in self._edges for nid in edge if nid not in self._nodes
        }
        if dangling:
            raise ValueError(f"Edges reference unknown node ids: {sorted(dangling)}")

        in_degree = dict.fromkeys(self._nodes, 0)
        for _, to_id in self._edges:
            in_degree[to_id] += 1

        # A min-heap yields the same smallest-first tie-break as re-sorting
        # the ready list on every step, without the repeated sort.
        ready = [nid for nid, degree in in_degree.items() if degree == 0]
        heapq.heapify(ready)

        order: list[str] = []
        while ready:
            current = heapq.heappop(ready)
            order.append(current)
            for dependent in self._dependents.get(current, ()):
                in_degree[dependent] -= 1
                if in_degree[dependent] == 0:
                    heapq.heappush(ready, dependent)

        if len(order) != len(self._nodes):
            # Anything left has an in-degree that never reached zero: it is
            # on a cycle or downstream of one.
            remaining = set(self._nodes) - set(order)
            raise ValueError(f"Dependency cycle detected involving: {remaining}")
        return order

    def _closure(
        self,
        seeds: Iterable[str],
        adjacency: dict[str, list[str]],
        caller: str,
    ) -> set[str]:
        """Collect every id reachable from ``seeds`` through ``adjacency`` (seeds included)."""
        seen: set[str] = set()
        stack = list(seeds)
        while stack:
            nid = stack.pop()
            if nid in seen:
                continue
            if nid not in self._nodes:
                raise KeyError(f"Unknown node id in {caller} seeds: {nid!r}")
            seen.add(nid)
            stack.extend(adjacency.get(nid, ()))
        return seen

    def dependency_closure(self, seeds: Iterable[str]) -> set[str]:
        """Walk upstream dependencies from ``seeds``.

        Returns
        -------
        set[str]
            The seeds plus every id they transitively depend on.

        Raises
        ------
        KeyError
            If a seed, or an id reached while walking, is not a known node id.
        """
        return self._closure(seeds, self._dependencies, "dependency_closure")

    def dependent_closure(self, seeds: Iterable[str]) -> set[str]:
        """Walk downstream dependents from ``seeds``.

        Returns
        -------
        set[str]
            The seeds plus every id that transitively depends on them.

        Raises
        ------
        KeyError
            If a seed, or an id reached while walking, is not a known node id.
        """
        return self._closure(seeds, self._dependents, "dependent_closure")

    def induced_subgraph(self, node_ids: Iterable[str]) -> DependencyGraph:
        """Copy restricted to ``node_ids`` (internal edges only).

        Nodes are copied (with a shallow copy of their metadata) in this
        graph's insertion order, so the result does not depend on the
        iteration order of ``node_ids``.

        Raises
        ------
        KeyError
            If any requested id is unknown.
        """
        wanted = set(node_ids)
        unknown = wanted.difference(self._nodes)
        if unknown:
            raise KeyError(f"induced_subgraph: unknown nodes {unknown!r}")

        sub = DependencyGraph()
        for nid, node in self._nodes.items():
            if nid not in wanted:
                continue
            sub.add_node(
                DependencyNode(
                    node.node_id,
                    node.kind,
                    slot_id=node.slot_id,
                    metadata=dict(node.metadata),
                )
            )
        # Keep only edges whose two endpoints both survived the restriction.
        for fr, to in self._edges:
            if fr in wanted and to in wanted:
                sub.add_edge(fr, to)
        return sub