"""Parameter sweeps over a fixed dependency graph (Phase 5).
Reuses :class:`tg_model.execution.evaluator.Evaluator` for each sample; every
sample gets a fresh :class:`tg_model.execution.run_context.RunContext`.
**Coherence:** Pass ``configured_model`` whenever you have it; the library then
verifies sweep :class:`ValueSlot` handles match ``compile_graph(configured_model)``.
**Throughput:** Samples run sequentially. This is not a parallel study runner.
**Pruning:** ``prune_to_slots`` evaluates an *upstream-closed* subgraph only.
Constraint (and other) nodes outside that closure are *not* executed —
``RunResult.constraint_results`` may be empty. Do not treat a pruned sweep as a
compliance run unless you know what you excluded.
"""
from __future__ import annotations
import itertools
from collections.abc import Callable, Mapping, Sequence
from dataclasses import dataclass
from typing import Any
from tg_model.analysis._coherence import assert_slots_align_with_graph, plan_sweep_axes
from tg_model.execution.configured_model import ConfiguredModel
from tg_model.execution.dependency_graph import DependencyGraph
from tg_model.execution.evaluator import Evaluator, RunResult
from tg_model.execution.run_context import RunContext
from tg_model.execution.value_slots import ValueSlot
[docs]
@dataclass(frozen=True)
class SweepRecord:
    """One row from :func:`sweep` / :func:`sweep_async`.

    A record is built once per evaluated sample and handed to the optional
    ``sink`` and/or appended to the returned list. The dataclass is frozen, so
    its fields cannot be rebound after construction; the ``inputs`` mapping is
    nevertheless a plain ``dict`` and is not itself made immutable.
    """

    # Zero-based position of the sample in the Cartesian-product enumeration.
    index: int
    # Swept values used for this sample, keyed by each slot's ``stable_id``.
    inputs: dict[str, Any]
    # Result object returned by the evaluator for this sample.
    result: RunResult
def _value_node_id(slot: ValueSlot) -> str:
    """Return the graph node id for ``slot``: its ``path_string`` prefixed with ``val:``."""
    path = slot.path_string
    return f"val:{path}"
def _prepare_pruned(
    graph: DependencyGraph,
    handlers: dict[str, Any],
    prune_to_slots: Sequence[ValueSlot] | None,
) -> tuple[DependencyGraph, dict[str, Any]]:
    """Restrict ``graph`` and ``handlers`` to the dependency closure of the prune targets.

    Parameters
    ----------
    graph, handlers
        The graph and its handler table to restrict.
    prune_to_slots
        Slots whose value nodes seed the closure. ``None`` or an empty
        sequence disables pruning.

    Returns
    -------
    tuple
        ``(graph, handlers)`` unchanged when pruning is disabled; otherwise the
        induced subgraph and a new dict holding only the handlers whose keys
        are nodes of that subgraph.

    Raises
    ------
    ValueError
        If the node id derived from any prune slot is not in ``graph.nodes``.
    """
    if not prune_to_slots:
        return graph, handlers

    # Derive every node id up front, then report the first one the graph lacks.
    seed_ids = [_value_node_id(slot) for slot in prune_to_slots]
    missing = next((node_id for node_id in seed_ids if node_id not in graph.nodes), None)
    if missing is not None:
        raise ValueError(
            f"Prune target {missing!r} is not a graph node (expected val:<instance.path> for a compiled value slot)."
        )

    pruned_graph = graph.induced_subgraph(graph.dependency_closure(seed_ids))

    # Nodes without a registered handler are simply left out of the new table.
    pruned_handlers: dict[str, Any] = {}
    for node_id in pruned_graph.nodes:
        if node_id in handlers:
            pruned_handlers[node_id] = handlers[node_id]
    return pruned_graph, pruned_handlers
def _maybe_assert_coherence(
    configured_model: ConfiguredModel | None,
    graph: DependencyGraph,
    parameter_slots: Sequence[ValueSlot],
    prune_slots: Sequence[ValueSlot] | None,
) -> None:
    """Run the slot/graph alignment check when a configured model is available.

    Does nothing when ``configured_model`` is ``None``. Otherwise the parameter
    slots, followed by any prune slots, are passed as one list to
    :func:`assert_slots_align_with_graph` with ``context="sweep"``.
    """
    if configured_model is None:
        return

    # ``None`` and an empty sequence both contribute no extra slots.
    slots_to_verify = [*parameter_slots, *(prune_slots or ())]
    assert_slots_align_with_graph(
        configured_model,
        graph,
        slots_to_verify,
        context="sweep",
    )
[docs]
def sweep(
    *,
    graph: DependencyGraph,
    handlers: dict[str, Any],
    parameter_values: Mapping[ValueSlot, Sequence[Any]],
    configured_model: ConfiguredModel | None = None,
    prune_to_slots: Sequence[ValueSlot] | None = None,
    collect: bool = True,
    sink: Callable[[SweepRecord], None] | None = None,
) -> list[SweepRecord]:
    """Evaluate the graph once, synchronously, for every point of a Cartesian grid.

    Parameters
    ----------
    graph, handlers
        From :func:`~tg_model.execution.graph_compiler.compile_graph`.
    parameter_values
        Maps each parameter :class:`~tg_model.execution.value_slots.ValueSlot` to the
        sequence of values for that axis. Axis order is the one returned by
        ``plan_sweep_axes`` (sorted by ``stable_id``).
    configured_model : ConfiguredModel, optional
        When passed, the sweep and prune slots are checked against the graph
        before any sample is evaluated.
    prune_to_slots : sequence of ValueSlot, optional
        Evaluate only the upstream closure of these slots (see module warnings).
    collect : bool, default True
        When False, nothing is accumulated and the returned list is empty;
        records are delivered through ``sink`` only.
    sink : callable, optional
        Called with each :class:`SweepRecord` right after it is produced.

    Returns
    -------
    list of SweepRecord
        Every sample in enumeration order when ``collect`` is True, else ``[]``.

    Raises
    ------
    ValueError
        If ``collect=False`` without ``sink``, or prune targets are not graph nodes.
    """
    if sink is None and not collect:
        raise ValueError("sweep(..., collect=False) requires a sink callable")

    axis_slots, axis_values = plan_sweep_axes(parameter_values)
    _maybe_assert_coherence(configured_model, graph, axis_slots, prune_to_slots)
    active_graph, active_handlers = _prepare_pruned(graph, handlers, prune_to_slots)
    runner = Evaluator(active_graph, compute_handlers=active_handlers)

    collected: list[SweepRecord] = []
    for sample_index, point in enumerate(itertools.product(*axis_values)):
        sample_inputs: dict[str, Any] = {
            slot.stable_id: value for slot, value in zip(axis_slots, point, strict=True)
        }
        # Every sample is evaluated against its own fresh run context.
        outcome = runner.evaluate(RunContext(), inputs=sample_inputs)
        record = SweepRecord(index=sample_index, inputs=dict(sample_inputs), result=outcome)
        if sink is not None:
            sink(record)
        if collect:
            collected.append(record)
    return collected
[docs]
async def sweep_async(
    *,
    configured_model: ConfiguredModel,
    graph: DependencyGraph,
    handlers: dict[str, Any],
    parameter_values: Mapping[ValueSlot, Sequence[Any]],
    prune_to_slots: Sequence[ValueSlot] | None = None,
    collect: bool = True,
    sink: Callable[[SweepRecord], None] | None = None,
) -> list[SweepRecord]:
    """Asynchronous counterpart of :func:`sweep`.

    Each sample is awaited through
    :meth:`~tg_model.execution.evaluator.Evaluator.evaluate_async` before the
    next one starts, so samples still run one at a time.

    Parameters
    ----------
    configured_model : ConfiguredModel
        Required; forwarded to every ``evaluate_async`` call and used for the
        coherence check.
    graph, handlers, parameter_values, prune_to_slots, collect, sink
        Same semantics as :func:`sweep`. ``sink`` is invoked as a plain
        (non-awaited) callable.

    Returns
    -------
    list of SweepRecord
        Same as :func:`sweep`.

    Raises
    ------
    ValueError
        Same as :func:`sweep`.
    """
    if sink is None and not collect:
        raise ValueError("sweep_async(..., collect=False) requires a sink callable")

    axis_slots, axis_values = plan_sweep_axes(parameter_values)
    _maybe_assert_coherence(configured_model, graph, axis_slots, prune_to_slots)
    active_graph, active_handlers = _prepare_pruned(graph, handlers, prune_to_slots)
    runner = Evaluator(active_graph, compute_handlers=active_handlers)

    collected: list[SweepRecord] = []
    for sample_index, point in enumerate(itertools.product(*axis_values)):
        sample_inputs: dict[str, Any] = {
            slot.stable_id: value for slot, value in zip(axis_slots, point, strict=True)
        }
        # Every sample is evaluated against its own fresh run context.
        outcome = await runner.evaluate_async(
            RunContext(),
            configured_model=configured_model,
            inputs=sample_inputs,
        )
        record = SweepRecord(index=sample_index, inputs=dict(sample_inputs), result=outcome)
        if sink is not None:
            sink(record)
        if collect:
            collected.append(record)
    return collected