"""Cross-variant evaluation: same workflow, isolated runs, aligned outputs (Phase 5).
Each scenario compiles its own graph from its :class:`ConfiguredModel` — structure
may differ per variant. There is no shared :class:`tg_model.execution.run_context.RunContext`.
By default, :func:`validate_graph` runs (with ``configured_model``) before each
evaluation so ill-posed variants fail with :class:`CompareVariantsValidationError`
instead of obscure runtime errors.
"""
from __future__ import annotations
from collections.abc import Mapping, Sequence
from dataclasses import dataclass
from typing import Any, TypeAlias
from tg_model.execution.configured_model import ConfiguredModel
from tg_model.execution.dependency_graph import DependencyGraph
from tg_model.execution.evaluator import Evaluator, RunResult
from tg_model.execution.graph_compiler import compile_graph
from tg_model.execution.run_context import RunContext
from tg_model.execution.validation import validate_graph
from tg_model.execution.value_slots import ValueSlot
VariantScenario: TypeAlias = tuple[str, ConfiguredModel, Mapping[str, Any]] # noqa: UP040
class CompareVariantsValidationError(Exception):
    """Raised when :func:`~tg_model.execution.validation.validate_graph` fails for one variant."""

    def __init__(self, label: str, failures: list[str]) -> None:
        """Attach ``label`` and ``failures`` to the exception for programmatic handling."""
        self.label = label
        # Copy so later mutation of the caller's list cannot change the record.
        self.failures = list(failures)
        super().__init__(
            f"validate_graph failed for variant {label!r}: {'; '.join(failures)}"
        )
@dataclass(frozen=True)
class CapturedSlotOutput:
    """Resolved output for one ``output_paths`` entry in :class:`VariantComparisonRow`."""

    # The resolved slot value; None when the slot was absent from the run outputs.
    value: Any | None
    # True iff the slot's ``stable_id`` was present in ``RunResult.outputs``.
    present_in_run_outputs: bool

    @property
    def realized(self) -> bool:
        """Whether the run actually produced a value for this slot."""
        return self.present_in_run_outputs
@dataclass(frozen=True)
class VariantComparisonRow:
    """One scenario row from :func:`compare_variants`."""

    # Scenario label, echoed back unchanged from the input tuple.
    label: str
    # Requested dotted path -> resolved slot output for this scenario's run.
    outputs: dict[str, CapturedSlotOutput]
    # Full evaluator result for this scenario's isolated run.
    result: RunResult
def _assert_same_root_if_requested(
scenarios: Sequence[VariantScenario],
*,
require_same_root_definition_type: bool,
) -> None:
if not require_same_root_definition_type or len(scenarios) <= 1:
return
t0 = scenarios[0][1].root.definition_type
for label, cm, _ in scenarios[1:]:
if cm.root.definition_type is not t0:
raise ValueError(
f"compare_variants: scenario {label!r} root type "
f"{cm.root.definition_type!r} differs from the first scenario ({t0!r}). "
"Set require_same_root_definition_type=False to compare structurally "
"different roots, or align your configured models."
)
def _compile_and_maybe_validate(
    label: str,
    cm: ConfiguredModel,
    *,
    validate_before_run: bool,
) -> tuple[DependencyGraph, dict[str, Any]]:
    """Compile ``cm`` to a graph, optionally gating it through :func:`validate_graph`.

    Raises
    ------
    CompareVariantsValidationError
        When validation is enabled and reports any failure for this variant.
    """
    graph, handlers = compile_graph(cm)
    if not validate_before_run:
        return graph, handlers
    report = validate_graph(graph, configured_model=cm)
    if report.passed:
        return graph, handlers
    raise CompareVariantsValidationError(label, [failure.message for failure in report.failures])
def _collect_outputs(cm: ConfiguredModel, result: RunResult, paths: Sequence[str]) -> dict[str, CapturedSlotOutput]:
    """Resolve each dotted path to a ``ValueSlot`` and capture its value from ``result``.

    Raises ``TypeError`` for any path that does not resolve to a ``ValueSlot``.
    Absent slots are recorded with ``value=None`` and a False presence flag, so
    callers can tell "never produced" apart from "produced None".
    """
    captured: dict[str, CapturedSlotOutput] = {}
    for path in paths:
        slot = cm.handle(path)
        if not isinstance(slot, ValueSlot):
            raise TypeError(f"output_paths entry {path!r} must resolve to a ValueSlot, got {type(slot).__name__}")
        realized = slot.stable_id in result.outputs
        captured[path] = CapturedSlotOutput(
            value=result.outputs[slot.stable_id] if realized else None,
            present_in_run_outputs=realized,
        )
    return captured
def compare_variants(
    *,
    scenarios: Sequence[VariantScenario],
    output_paths: Sequence[str],
    validate_before_run: bool = True,
    require_same_root_definition_type: bool = False,
) -> list[VariantComparisonRow]:
    """Evaluate each ``(label, configured_model, inputs)`` with a fresh graph and context.

    Parameters
    ----------
    scenarios
        Sequence of ``(label, configured_model, inputs)`` tuples; each is
        compiled and evaluated in isolation (no shared state between runs).
    output_paths
        Dotted paths that must resolve to :class:`~tg_model.execution.value_slots.ValueSlot`.
    validate_before_run : bool, default True
        Run :func:`~tg_model.execution.validation.validate_graph` before each evaluation.
    require_same_root_definition_type : bool, default False
        When True, all roots must share the same Python type.

    Returns
    -------
    list of VariantComparisonRow
        One row per scenario, in input order.

    Raises
    ------
    CompareVariantsValidationError
        When validation fails for a labeled scenario.
    ValueError
        When ``require_same_root_definition_type`` is violated.
    TypeError
        When an ``output_paths`` entry does not resolve to a ``ValueSlot``.

    Notes
    -----
    ``inputs`` maps ``ValueSlot.stable_id`` strings to values, same as
    :meth:`~tg_model.execution.evaluator.Evaluator.evaluate`. ``CapturedSlotOutput``
    distinguishes absent outputs from ``None`` values via ``present_in_run_outputs``.
    """
    _assert_same_root_if_requested(
        scenarios,
        require_same_root_definition_type=require_same_root_definition_type,
    )
    rows: list[VariantComparisonRow] = []
    for label, cm, inputs in scenarios:
        graph, handlers = _compile_and_maybe_validate(
            label,
            cm,
            validate_before_run=validate_before_run,
        )
        # Fresh Evaluator + RunContext per scenario keeps runs fully isolated.
        run = Evaluator(graph, compute_handlers=handlers).evaluate(
            RunContext(), inputs=dict(inputs)
        )
        rows.append(
            VariantComparisonRow(
                label=label,
                outputs=_collect_outputs(cm, run, output_paths),
                result=run,
            )
        )
    return rows
async def compare_variants_async(
    *,
    scenarios: Sequence[VariantScenario],
    output_paths: Sequence[str],
    validate_before_run: bool = True,
    require_same_root_definition_type: bool = False,
) -> list[VariantComparisonRow]:
    """Async variant of :func:`compare_variants` (uses ``evaluate_async`` per scenario).

    Raises
    ------
    CompareVariantsValidationError, ValueError, TypeError
        Same families as :func:`compare_variants`.
    """
    _assert_same_root_if_requested(
        scenarios,
        require_same_root_definition_type=require_same_root_definition_type,
    )
    rows: list[VariantComparisonRow] = []
    for label, cm, inputs in scenarios:
        graph, handlers = _compile_and_maybe_validate(
            label,
            cm,
            validate_before_run=validate_before_run,
        )
        # NOTE(review): evaluate_async receives configured_model=cm while the sync
        # path's evaluate() does not — presumably intentional; confirm against the
        # Evaluator API before aligning the two call sites.
        run = await Evaluator(graph, compute_handlers=handlers).evaluate_async(
            RunContext(),
            configured_model=cm,
            inputs=dict(inputs),
        )
        rows.append(
            VariantComparisonRow(
                label=label,
                outputs=_collect_outputs(cm, run, output_paths),
                result=run,
            )
        )
    return rows