
Surrogate

Cheap regression surrogates fit over a ResultsTable for predicting observables at untested configurations.

Install via the optional extra:

uv pip install 'trade-study[surrogate]'

trade_study.fit_surrogate(results, factors, *, method='gp', seed=0, n_estimators=200)

Fit a per-observable surrogate over a ResultsTable.

A row with NaN in an observable's score column is dropped only from that observable's training set. A partially evaluated trial therefore still contributes to the observables it does have.
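The following is a minimal NumPy sketch of this per-observable masking. It runs on a made-up score matrix, and the observable names and values are illustrative only. It follows the mask-and-skip loop in the source listing, but records which rows each observable would train on instead of fitting anything:

```python
import numpy as np

# Hypothetical scores: 4 trials x 2 observables; NaN marks a missing evaluation.
observable_names = ["latency", "accuracy"]
scores = np.array([
    [1.0, 0.90],
    [2.0, np.nan],     # trial 1 only has "latency"
    [3.0, 0.70],
    [np.nan, np.nan],  # trial 3 contributes to neither observable
])

rows_used: dict[str, list[int]] = {}
for j, name in enumerate(observable_names):
    mask = ~np.isnan(scores[:, j])
    # Same rule as fit_surrogate: skip observables with fewer than 2 usable rows.
    if int(mask.sum()) < 2:
        continue
    rows_used[name] = np.flatnonzero(mask).tolist()

print(rows_used)  # {'latency': [0, 1, 2], 'accuracy': [0, 2]}
```

Trial 1 still trains the latency model even though its accuracy score is missing.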

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| results | ResultsTable | A ResultsTable from a previous study run. | required |
| factors | list[Factor] | Factor definitions used to encode results.configs. Must cover every key in the configs. | required |
| method | str | "gp" for a Gaussian process (Matern 1.5 + WhiteKernel) or "rf" for a random forest. | 'gp' |
| seed | int | Random seed forwarded to the backend estimators. | 0 |
| n_estimators | int | Number of trees for the "rf" backend; ignored for "gp". | 200 |
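These docs name the backends but do not say exactly how they are constructed. The models attribute holds scikit-learn estimators, so one plausible stand-in for the private _make_estimator helper looks like the sketch below. The function make_estimator is hypothetical. Everything beyond the kernel family, seed and n_estimators (for example default length scales and noise bounds) is a guess.

```python
import numpy as np
from sklearn.ensemble import RandomForestRegressor
from sklearn.gaussian_process import GaussianProcessRegressor
from sklearn.gaussian_process.kernels import Matern, WhiteKernel


def make_estimator(method: str, *, seed: int = 0, n_estimators: int = 200):
    """Hypothetical stand-in for trade_study's private _make_estimator."""
    if method == "gp":
        # Matern with nu=1.5 assumes a once-differentiable response;
        # WhiteKernel lets the GP absorb observation noise.
        kernel = Matern(nu=1.5) + WhiteKernel()
        return GaussianProcessRegressor(kernel=kernel, random_state=seed)
    if method == "rf":
        # n_estimators only matters for the forest backend.
        return RandomForestRegressor(n_estimators=n_estimators, random_state=seed)
    raise ValueError(f"Unknown surrogate method {method!r}")


# Toy 1-D data just to show both backends fit and predict the same way.
x = np.linspace(0.0, 1.0, 8).reshape(-1, 1)
y = np.sin(2 * np.pi * x).ravel()
for method in ("gp", "rf"):
    model = make_estimator(method, seed=0, n_estimators=50)
    model.fit(x, y)
    print(method, model.predict(x[:2]).shape)  # (2,)
```

Both estimators share the fit/predict interface, which is what lets a single surrogate class store either kind in one list.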

Returns:

| Type | Description |
|---|---|
| SurrogateModel | A fitted SurrogateModel. |

Raises:

| Type | Description |
|---|---|
| ValueError | If method is unknown, results is empty, or no observable has at least two non-NaN training rows. |

Source code in src/trade_study/surrogate.py
def fit_surrogate(
    results: ResultsTable,
    factors: list[Factor],
    *,
    method: str = "gp",
    seed: int = 0,
    n_estimators: int = 200,
) -> SurrogateModel:
    """Fit a per-observable surrogate over a :class:`ResultsTable`.

    Rows whose score column contains ``NaN`` are dropped on a
    per-observable basis (so a partially-evaluated trial still
    contributes to the observables it does have).

    Args:
        results: A :class:`ResultsTable` from a previous study run.
        factors: Factor definitions used to encode ``results.configs``.
            Must cover every key in the configs.
        method: ``"gp"`` for Gaussian process (Matern 1.5 + WhiteKernel)
            or ``"rf"`` for a random forest.
        seed: Random seed forwarded to the backend estimators.
        n_estimators: Number of trees for the ``"rf"`` backend; ignored
            for ``"gp"``.

    Returns:
        A fitted :class:`SurrogateModel`.

    Raises:
        ValueError: If ``method`` is unknown, ``results`` is empty, or
            no observable has at least two non-NaN training rows.
    """
    if method not in _SUPPORTED_METHODS:
        msg = (
            f"Unknown surrogate method {method!r}. "
            f"Supported: {sorted(_SUPPORTED_METHODS)}"
        )
        raise ValueError(msg)
    if not results.configs:
        msg = "fit_surrogate: results table is empty"
        raise ValueError(msg)

    encoder = _FactorEncoder.from_factors(factors)
    x_full = encoder.transform(results.configs)

    models: list[Any] = []
    fitted_obs: list[str] = []
    for j, name in enumerate(results.observable_names):
        y = results.scores[:, j]
        mask = ~np.isnan(y)
        if int(mask.sum()) < 2:
            continue
        model = _make_estimator(method, seed=seed, n_estimators=n_estimators)
        model.fit(x_full[mask], y[mask])
        models.append(model)
        fitted_obs.append(name)

    if not models:
        msg = (
            "fit_surrogate: no observable has at least 2 non-NaN training "
            "rows; nothing to fit"
        )
        raise ValueError(msg)

    return SurrogateModel(
        method=method,
        encoder=encoder,
        observable_names=fitted_obs,
        models=models,
    )

trade_study.SurrogateModel(method, encoder, observable_names, models) dataclass

Fitted surrogate over a ResultsTable.

Use fit_surrogate to construct one. The per-observable backend estimators are stored in models, and a single encoder is shared across all of them.

Attributes:

| Name | Type | Description |
|---|---|---|
| method | str | "gp" or "rf". |
| encoder | _FactorEncoder | Factor encoder used at fit time. |
| observable_names | list[str] | Column names of the predicted observables. |
| models | list[Any] | One fitted scikit-learn estimator per observable. |

predict(config)

Predict observables for a single config.

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| config | dict[str, Any] | Factor-keyed config dict. | required |

Returns:

| Type | Description |
|---|---|
| dict[str, float] | Mapping from observable name to predicted scalar. |

Source code in src/trade_study/surrogate.py
def predict(self, config: dict[str, Any]) -> dict[str, float]:
    """Predict observables for a single config.

    Args:
        config: Factor-keyed config dict.

    Returns:
        Mapping from observable name to predicted scalar.
    """
    x = self.encoder.transform([config])
    return {
        name: float(model.predict(x)[0])
        for name, model in zip(self.observable_names, self.models, strict=True)
    }

predict_batch(configs)

Predict observables for a batch of configs.

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| configs | Sequence[dict[str, Any]] | Sequence of factor-keyed config dicts. | required |

Returns:

| Type | Description |
|---|---|
| dict[str, NDArray[float64]] | Mapping from observable name to a length-len(configs) array of predictions. |

Source code in src/trade_study/surrogate.py
def predict_batch(
    self,
    configs: Sequence[dict[str, Any]],
) -> dict[str, NDArray[np.float64]]:
    """Predict observables for a batch of configs.

    Args:
        configs: Sequence of factor-keyed config dicts.

    Returns:
        Mapping from observable name to a length-``len(configs)``
        array of predictions.
    """
    x = self.encoder.transform(configs)
    return {
        name: np.asarray(model.predict(x), dtype=np.float64)
        for name, model in zip(self.observable_names, self.models, strict=True)
    }

uncertainty(config)

Predictive standard deviation per observable (GP only).

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| config | dict[str, Any] | Factor-keyed config dict. | required |

Returns:

| Type | Description |
|---|---|
| dict[str, float] | Mapping from observable name to predictive standard deviation. |

Raises:

| Type | Description |
|---|---|
| NotImplementedError | If the backend does not expose calibrated uncertainties (currently anything other than "gp"). |

Source code in src/trade_study/surrogate.py
def uncertainty(self, config: dict[str, Any]) -> dict[str, float]:
    """Predictive standard deviation per observable (GP only).

    Args:
        config: Factor-keyed config dict.

    Returns:
        Mapping from observable name to predictive standard deviation.

    Raises:
        NotImplementedError: If the backend does not expose calibrated
            uncertainties (currently anything other than ``"gp"``).
    """
    if self.method != "gp":
        msg = (
            f"uncertainty() is only supported for method='gp'; "
            f"this surrogate uses method={self.method!r}"
        )
        raise NotImplementedError(msg)
    x = self.encoder.transform([config])
    out: dict[str, float] = {}
    for name, model in zip(self.observable_names, self.models, strict=True):
        _, std = model.predict(x, return_std=True)
        out[name] = float(std[0])
    return out
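uncertainty() relies on scikit-learn's GaussianProcessRegressor.predict(X, return_std=True), which returns a (mean, std) pair. RandomForestRegressor.predict takes no return_std argument, which fits the GP-only restriction. The following standalone sketch uses toy data and a kernel that only mirrors the documented Matern 1.5 + WhiteKernel choice. It shows the predictive standard deviation growing away from the training data. The data and query points are assumptions for illustration.

```python
import numpy as np
from sklearn.gaussian_process import GaussianProcessRegressor
from sklearn.gaussian_process.kernels import Matern, WhiteKernel

# Toy 1-D training data on [0, 1].
x_train = np.linspace(0.0, 1.0, 10).reshape(-1, 1)
y_train = np.sin(2 * np.pi * x_train).ravel()

gp = GaussianProcessRegressor(kernel=Matern(nu=1.5) + WhiteKernel(), random_state=0)
gp.fit(x_train, y_train)

# The same call uncertainty() makes, but for two query points at once:
# one inside the training range, one far outside it.
x_query = np.array([[0.5], [3.0]])
mean, std = gp.predict(x_query, return_std=True)
print(mean.shape, std.shape)  # (2,) (2,)
print(std[1] > std[0])  # extrapolation is less certain
```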