Source code for tg_model.integrations.external_compute

"""External computation protocols, bindings, and route wiring.

Bindings connect declared :class:`~tg_model.model.refs.AttributeRef` inputs to Python
callables that return :class:`ExternalComputeResult`. Sync paths use
:class:`ExternalCompute`; async callables use :class:`AsyncExternalCompute` with
:class:`~tg_model.execution.evaluator.Evaluator.evaluate_async`.
"""

from __future__ import annotations

import inspect
from collections.abc import Mapping
from dataclasses import dataclass, field
from typing import Any, Protocol, runtime_checkable

from unitflow import Quantity

from tg_model.model.refs import AttributeRef


[docs] class ExternalComputeValidationError(ValueError): """Raised when :meth:`ValidatableExternalCompute.validate_binding` rejects a spec."""
[docs] @dataclass(frozen=True) class ExternalComputeResult: """Return value from :meth:`ExternalCompute.compute` / :meth:`AsyncExternalCompute.compute`.""" value: Quantity | Mapping[str, Quantity] provenance: Mapping[str, Any] = field(default_factory=dict)
[docs] @dataclass class ExternalComputeBinding: """Bind a callable to named input refs and optional output route refs.""" external: object inputs: dict[str, AttributeRef] output_routes: dict[str, AttributeRef] | None = None
[docs] @runtime_checkable class ExternalCompute(Protocol): """Synchronous external backend (safe for :meth:`~tg_model.execution.evaluator.Evaluator.evaluate`).""" @property def name(self) -> str: ...
[docs] def compute(self, inputs: dict[str, Quantity]) -> ExternalComputeResult: ...
[docs] @runtime_checkable class AsyncExternalCompute(Protocol): """Coroutine ``compute`` backend (requires :meth:`~tg_model.execution.evaluator.Evaluator.evaluate_async`).""" @property def name(self) -> str: ...
[docs] async def compute(self, inputs: dict[str, Quantity]) -> ExternalComputeResult: ...
[docs] @runtime_checkable class ValidatableExternalCompute(Protocol): """Optional static validation hook used by :func:`tg_model.execution.validation.validate_graph`."""
[docs] def validate_binding( self, *, input_specs: Mapping[str, Any], output_specs: Mapping[str, Any], ) -> None: """Raise :class:`ExternalComputeValidationError` or ``ValueError`` when specs are inconsistent."""
[docs] def is_async_external(obj: object) -> bool: """Return True when ``compute`` is a coroutine function. Notes ----- Async externals must not run under :meth:`~tg_model.execution.evaluator.Evaluator.evaluate`. """ fn = getattr(obj, "compute", None) return inspect.iscoroutinefunction(fn)
[docs] def assert_sync_external(obj: object, *, context: str = "") -> None: """Raise ``TypeError`` if ``obj`` is an async external. Parameters ---------- obj : object External implementation about to run under sync evaluation. context : str, optional Suffix appended to the error message. Raises ------ TypeError When :func:`is_async_external` is true. """ if is_async_external(obj): suffix = f" ({context})" if context else "" raise TypeError( f"Async external compute {getattr(obj, 'name', obj)!r} cannot run under " f"synchronous Evaluator.evaluate(){suffix}. Use Evaluator.evaluate_async(...)." )