# Source code for tg_model.execution.validation

"""Static validation for compiled :class:`~tg_model.execution.dependency_graph.DependencyGraph` objects."""

from __future__ import annotations

from dataclasses import dataclass, field
from typing import Any

from tg_model.execution.configured_model import ConfiguredModel
from tg_model.execution.dependency_graph import DependencyGraph, NodeKind
from tg_model.execution.external_ops import ExternalOpsError, navigate_to_part, resolve_attribute_ref_to_slot
from tg_model.execution.value_slots import ValueSlot


@dataclass
class ValidationFailure:
    """One validation problem, as collected in :class:`ValidationResult`.

    Attributes
    ----------
    category : str
        Short grouping label, e.g. ``"dependency"`` or ``"rollup"``.
    message : str
        Human-readable description of the problem.
    path : str or None
        Optional graph path (node ID) of the offending element; ``None`` when
        the problem is not tied to a single element.
    """

    category: str
    message: str
    path: str | None = field(default=None)


@dataclass
class ValidationResult:
    """Collection of :class:`ValidationFailure` rows produced by :func:`validate_graph`."""

    # The factory gives every instance its own list (no shared mutable default).
    failures: list[ValidationFailure] = field(default_factory=list)

    @property
    def passed(self) -> bool:
        """Whether no failures have been recorded."""
        return not self.failures

    def add(self, category: str, message: str, path: str | None = None) -> None:
        """Build a :class:`ValidationFailure` from the arguments and record it."""
        failure = ValidationFailure(category=category, message=message, path=path)
        self.failures.append(failure)
class GraphValidationError(Exception):
    """Signals that :func:`validate_graph` reported failures before evaluation.

    Typically raised by
    :meth:`tg_model.execution.configured_model.ConfiguredModel.evaluate` when
    ``validate=True`` and the static checks do not pass.

    It derives from :class:`Exception` rather than :class:`BaseException`, so
    ordinary ``except Exception`` handlers catch it. Catch this type
    specifically, or inspect :attr:`result`, to tell validation problems apart
    from other errors.

    Attributes
    ----------
    result : ValidationResult
        Structured failures from :func:`validate_graph`.
    """

    def __init__(self, message: str, *, result: ValidationResult) -> None:
        super().__init__(message)
        self.result = result
def validate_graph(
    graph: DependencyGraph,
    *,
    configured_model: ConfiguredModel | None = None,
) -> ValidationResult:
    """Statically check a compiled graph before it is evaluated.

    All checks append to one shared result, in this order: topological
    ordering (cycles), disconnected compute nodes, empty roll-ups, solve-group
    target slots, and slots with more than one writer. When
    ``configured_model`` is given, external ``validate_binding`` hooks run
    last.

    Parameters
    ----------
    graph : DependencyGraph
        Output of :func:`~tg_model.execution.graph_compiler.compile_graph`.
    configured_model : ConfiguredModel, optional
        When provided, invokes ``validate_binding`` on external computations
        that define it (see
        :class:`~tg_model.integrations.external_compute.ValidatableExternalCompute`).

    Returns
    -------
    ValidationResult
        Collected :class:`ValidationFailure` rows; ``passed`` is True when none
        were recorded. Soft-check problems are reported here rather than
        raised.
    """
    result = ValidationResult()
    # Order matters: it determines the order of rows in ``result.failures``.
    graph_checks = (
        _check_cycles,
        _check_orphaned_compute_nodes,
        _check_empty_rollups,
        _check_solve_group_integrity,
        _check_duplicate_slot_assignments,
    )
    for check in graph_checks:
        check(graph, result)
    if configured_model is not None:
        _check_validatable_external_bindings(graph, configured_model, result)
    return result
def _check_cycles(graph: DependencyGraph, result: ValidationResult) -> None:
    """Record a ``dependency`` failure when ``graph.topological_order()`` fails.

    A :class:`ValueError` from the topological sort (presumably raised for a
    cycle) is recorded with its message instead of propagating.
    """
    try:
        graph.topological_order()
    except ValueError as e:
        result.add("dependency", str(e))


def _check_orphaned_compute_nodes(graph: DependencyGraph, result: ValidationResult) -> None:
    """Record compute nodes that have neither dependencies nor dependents."""
    for node_id, node in graph.nodes.items():
        if not node.is_compute_node:
            continue
        deps = graph.dependencies_of(node_id)
        dependents = graph.dependents_of(node_id)
        if len(deps) == 0 and len(dependents) == 0:
            result.add(
                "dependency",
                f"Compute node '{node_id}' has no dependencies and no dependents",
                path=node_id,
            )


def _check_empty_rollups(graph: DependencyGraph, result: ValidationResult) -> None:
    """Record roll-up computations whose selector produced no child dependencies."""
    for node_id, node in graph.nodes.items():
        if node.kind != NodeKind.ROLLUP_COMPUTATION:
            continue
        deps = graph.dependencies_of(node_id)
        if len(deps) == 0:
            result.add(
                "rollup",
                f"Roll-up '{node_id}' has no child dependencies — the selector resolved to nothing",
                path=node_id,
            )


def _check_solve_group_integrity(graph: DependencyGraph, result: ValidationResult) -> None:
    """Record solve groups with no target slots or with repeated target slot IDs.

    ``node.metadata["target_slots"]`` is treated as a mapping whose values are
    slot IDs (``.values()`` is called on it).
    """
    for node_id, node in graph.nodes.items():
        if node.kind != NodeKind.SOLVE_GROUP:
            continue
        # ``or {}`` normalises both a missing key and an explicit falsy value
        # (e.g. ``None``), so the ``.values()`` call below cannot raise from
        # this soft check.
        target_slots = node.metadata.get("target_slots") or {}
        if not target_slots:
            result.add(
                "solve_group",
                f"Solve group '{node_id}' has no target slots for unknowns",
                path=node_id,
            )
            continue  # nothing to de-duplicate
        slot_ids = list(target_slots.values())
        if len(slot_ids) != len(set(slot_ids)):
            result.add(
                "solve_group",
                f"Solve group '{node_id}' has duplicate target slot IDs",
                path=node_id,
            )


def _collect_binding_unit_specs(
    binding: Any,
    node: Any,
    owner: Any,
    cm: ConfiguredModel,
) -> tuple[dict[str, Any], dict[str, Any]]:
    """Return ``(input_specs, output_specs)`` mapping binding names to slot ``"unit"`` metadata.

    Inputs are keyed by binding input name. Outputs are keyed by route name
    when ``binding.output_routes`` is non-empty; otherwise they are keyed by
    the node's ``output_slot_ids``. Errors from
    :func:`resolve_attribute_ref_to_slot` propagate to the caller.
    """
    input_specs: dict[str, Any] = {}
    for iname, ref in binding.inputs.items():
        slot = resolve_attribute_ref_to_slot(ref, owner, cm)
        input_specs[iname] = slot.metadata.get("unit")

    output_specs: dict[str, Any] = {}
    if binding.output_routes:
        for key, ref in binding.output_routes.items():
            out_slot = resolve_attribute_ref_to_slot(ref, owner, cm)
            output_specs[key] = out_slot.metadata.get("unit")
    else:
        # Without explicit routes, fall back to the node's output slot IDs;
        # IDs that do not resolve to a ValueSlot are skipped.
        for sid in node.metadata.get("output_slot_ids", ()):
            vs = cm.id_registry.get(sid)
            if isinstance(vs, ValueSlot):
                output_specs[sid] = vs.metadata.get("unit")
    return input_specs, output_specs


def _check_validatable_external_bindings(
    graph: DependencyGraph,
    cm: ConfiguredModel,
    result: ValidationResult,
) -> None:
    """Run each external computation's optional ``validate_binding`` hook.

    The hook is called as ``validate_binding(input_specs=..., output_specs=...)``
    with the ``"unit"`` metadata of the bound slots. Every error raised while
    resolving slots or running the hook is recorded as an ``external_binding``
    failure rather than propagated.
    """
    for node_id, node in graph.nodes.items():
        if node.kind != NodeKind.EXTERNAL_COMPUTATION:
            continue
        binding = node.metadata.get("binding")
        if binding is None:
            continue
        ext = binding.external
        validate_fn = getattr(ext, "validate_binding", None)
        if not callable(validate_fn):  # also covers a missing hook (None)
            continue
        owner_path = node.metadata.get("owner_path")
        if not owner_path:
            result.add("external_binding", f"Node '{node_id}' missing owner_path", path=node_id)
            continue
        try:
            owner = navigate_to_part(cm.root, tuple(owner_path))
            input_specs, output_specs = _collect_binding_unit_specs(binding, node, owner, cm)
            validate_fn(input_specs=input_specs, output_specs=output_specs)
        except ExternalOpsError as e:
            result.add("external_binding", str(e), path=node_id)
        except Exception as e:  # deliberate: hook errors become soft failures
            result.add(
                "external_binding",
                f"validate_binding failed for {getattr(ext, 'name', ext)!r}: {e}",
                path=node_id,
            )


def _check_duplicate_slot_assignments(graph: DependencyGraph, result: ValidationResult) -> None:
    """Check that no slot_id is written to by more than one compute node.

    External computations contribute every ID in ``metadata["output_slot_ids"]``.
    Other compute nodes contribute their ``slot_id`` when it is set. Each
    writer node is counted once per slot, so a node listing the same output
    slot twice is not misreported as multiple writers.
    """
    slot_writers: dict[str, list[str]] = {}

    def record(slot_id: str, node_id: str) -> None:
        # Insertion-ordered, de-duplicated list of writer node IDs per slot.
        writers = slot_writers.setdefault(slot_id, [])
        if node_id not in writers:
            writers.append(node_id)

    for node_id, node in graph.nodes.items():
        if node.kind == NodeKind.EXTERNAL_COMPUTATION:
            for sid in node.metadata.get("output_slot_ids", ()):
                record(sid, node_id)
            continue
        if node.is_compute_node and node.slot_id:
            record(node.slot_id, node_id)

    for slot_id, writers in slot_writers.items():
        if len(writers) > 1:
            result.add(
                "dependency",
                f"Slot '{slot_id}' is written to by multiple compute nodes: {writers}",
            )