
Regime Surrogate

Regime-conditional surrogate that interpolates factor recommendations across regime descriptors (e.g. dataset size, noise level) instead of relying on hard regime buckets. Builds on fit_surrogate.
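As a rough illustration of interpolating across a regime axis instead of bucketing it, the toy sketch below uses inverse-distance weighting over the joint (regime, factor) space. This is a deliberate stand-in for the "gp" and "rf" backends, chosen only to keep the example dependency-free. The data and all names (observations, idw_predict, recommend_lr, n_samples, lr) are hypothetical.

```python
import math

# Hypothetical past runs: regime feature n_samples, design factor lr -> observed loss.
observations = [
    ({"n_samples": 100.0, "lr": 0.1}, 0.30),
    ({"n_samples": 100.0, "lr": 0.01}, 0.50),
    ({"n_samples": 10000.0, "lr": 0.1}, 0.40),
    ({"n_samples": 10000.0, "lr": 0.01}, 0.20),
]


def features(point):
    # Work in log space so both axes are on comparable scales.
    return (math.log10(point["n_samples"]), math.log10(point["lr"]))


def idw_predict(point, power=2.0):
    """Inverse-distance-weighted loss over the joint (regime, factor) space."""
    x = features(point)
    num = den = 0.0
    for obs_point, y in observations:
        d = math.dist(x, features(obs_point))
        if d == 0.0:
            return y  # exact hit on an observed point
        w = d ** -power
        num += w * y
        den += w
    return num / den


def recommend_lr(n_samples, grid=(0.01, 0.1)):
    # Pick the design value with the lowest predicted loss at this regime.
    return min(grid, key=lambda lr: idw_predict({"n_samples": n_samples, "lr": lr}))
```

A hard-bucket approach would have to snap a query such as recommend_lr(5000.0) to one of the two observed regimes; here the prediction blends both, weighted toward the nearer one in log space.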

Install via the optional extra (same as the base surrogate):

uv pip install 'trade-study[surrogate]'

trade_study.fit_regime_surrogate(results, regime_factors, factors, *, method='gp', seed=0, n_estimators=200)

Fit a surrogate that conditions on regime features.

Internally fits a single SurrogateModel over the joint regime_factors + factors space, so observables can be interpolated across continuous regime axes.

Every config in results.configs must contain values for both the regime features and the design factors.
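Because a missing key only surfaces once the underlying fit runs, it can help to check coverage up front. A minimal sketch, assuming results.configs is a sequence of plain dicts and that factor names are available as strings; the helper find_missing_keys is hypothetical and not part of trade_study.

```python
from collections.abc import Mapping, Sequence
from typing import Any


def find_missing_keys(
    configs: Sequence[Mapping[str, Any]],
    regime_names: Sequence[str],
    factor_names: Sequence[str],
) -> dict[int, list[str]]:
    """Map each config index to the required names it lacks (empty if all complete)."""
    required = [*regime_names, *factor_names]
    missing: dict[int, list[str]] = {}
    for i, cfg in enumerate(configs):
        absent = [name for name in required if name not in cfg]
        if absent:
            missing[i] = absent
    return missing
```

For example, find_missing_keys([{"n_samples": 100, "lr": 0.1}, {"lr": 0.01}], ["n_samples"], ["lr"]) reports that the second config lacks n_samples.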

Parameters:

results (ResultsTable, required): A ResultsTable from previous study runs that spans multiple regimes.
regime_factors (list[Factor], required): Factors describing the regime (additional input dimensions of the surrogate; typically continuous).
factors (list[Factor], required): Tunable design factors. Together with regime_factors, these must cover every key referenced in results.configs.
method (str, default 'gp'): Surrogate backend, "gp" or "rf". See trade_study.fit_surrogate.
seed (int, default 0): Random seed forwarded to the backend estimators.
n_estimators (int, default 200): Number of trees for the "rf" backend.

Returns:

RegimeSurrogate: A fitted RegimeSurrogate.

Raises:

ValueError: If regime_factors is empty, if a name appears in both regime_factors and factors, or if the underlying fit_surrogate call fails.

Source code in src/trade_study/regime.py
def fit_regime_surrogate(
    results: ResultsTable,
    regime_factors: list[Factor],
    factors: list[Factor],
    *,
    method: str = "gp",
    seed: int = 0,
    n_estimators: int = 200,
) -> RegimeSurrogate:
    """Fit a surrogate that conditions on regime features.

    Internally fits a single :class:`SurrogateModel` over the joint
    ``regime_factors + factors`` space, so observables can be
    interpolated across continuous regime axes.

    Every config in ``results.configs`` must contain values for both the
    regime features and the design factors.

    Args:
        results: A :class:`ResultsTable` from previous study runs that
            spans multiple regimes.
        regime_factors: Factors describing the regime (additional input
            dimensions of the surrogate; typically continuous).
        factors: Tunable design factors. Together with
            ``regime_factors`` these must cover every key referenced in
            ``results.configs``.
        method: Surrogate backend, ``"gp"`` or ``"rf"``. See
            :func:`trade_study.fit_surrogate`.
        seed: Random seed forwarded to the backend estimators.
        n_estimators: Number of trees for the ``"rf"`` backend.

    Returns:
        A fitted :class:`RegimeSurrogate`.

    Raises:
        ValueError: If ``regime_factors`` is empty, if a name appears in
            both ``regime_factors`` and ``factors``, or if the
            underlying :func:`fit_surrogate` call fails.
    """
    if not regime_factors:
        msg = "fit_regime_surrogate: regime_factors must be non-empty"
        raise ValueError(msg)
    overlap = {f.name for f in regime_factors} & {f.name for f in factors}
    if overlap:
        msg = (
            f"fit_regime_surrogate: names appear in both regime_factors and "
            f"factors: {sorted(overlap)}"
        )
        raise ValueError(msg)
    inner = fit_surrogate(
        results,
        [*regime_factors, *factors],
        method=method,
        seed=seed,
        n_estimators=n_estimators,
    )
    return RegimeSurrogate(
        inner=inner,
        regime_factors=list(regime_factors),
        factors=list(factors),
    )

trade_study.RegimeSurrogate(inner, regime_factors, factors) (dataclass)

Surrogate that conditions on regime features.

Wraps a single SurrogateModel fit over the union of regime descriptors and design factors. Use fit_regime_surrogate to construct one.

Attributes:

inner (SurrogateModel): The underlying SurrogateModel over the joint regime_factors + factors input space.
regime_factors (list[Factor]): Factors describing the regime (additional input dimensions of the surrogate).
factors (list[Factor]): Tunable design factors that recommend optimizes at a given regime.

predict(regime, config)

Predict observables at a regime + config pair.

Parameters:

regime (dict[str, Any], required): Mapping of regime-feature names to values.
config (dict[str, Any], required): Mapping of design-factor names to values.

Returns:

dict[str, float]: Mapping from observable name to predicted scalar.

Source code in src/trade_study/regime.py
def predict(
    self,
    regime: dict[str, Any],
    config: dict[str, Any],
) -> dict[str, float]:
    """Predict observables at a regime + config pair.

    Args:
        regime: Mapping of regime-feature names to values.
        config: Mapping of design-factor names to values.

    Returns:
        Mapping from observable name to predicted scalar.
    """
    return self.inner.predict(_merge(regime, config))
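predict delegates to a private _merge helper whose body is not shown on this page. The sketch below is only an assumption about what such a merge might look like: a plain dict union that refuses conflicting keys, mirroring the disjoint-name rule that fit_regime_surrogate enforces. The name merge_regime_config is hypothetical, and the real helper may behave differently (for example, it may not check for overlaps at all).

```python
from typing import Any


def merge_regime_config(regime: dict[str, Any], config: dict[str, Any]) -> dict[str, Any]:
    """Combine regime features and design factors into one input dict (hypothetical)."""
    clash = regime.keys() & config.keys()
    if clash:
        # Regime and design names are expected to be disjoint.
        raise ValueError(f"keys appear in both regime and config: {sorted(clash)}")
    # New dict; neither input is mutated.
    return {**regime, **config}
```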

predict_batch(regime, configs)

Predict observables for a batch of configs at one regime.

Parameters:

regime (dict[str, Any], required): Mapping of regime-feature names to values.
configs (Sequence[dict[str, Any]], required): Sequence of design-factor configs to score at regime.

Returns:

dict[str, NDArray[float64]]: Mapping from observable name to a length-len(configs) array of predictions.

Source code in src/trade_study/regime.py
def predict_batch(
    self,
    regime: dict[str, Any],
    configs: Sequence[dict[str, Any]],
) -> dict[str, NDArray[np.float64]]:
    """Predict observables for a batch of configs at one regime.

    Args:
        regime: Mapping of regime-feature names to values.
        configs: Sequence of design-factor configs to score at
            ``regime``.

    Returns:
        Mapping from observable name to a length-``len(configs)``
        array of predictions.
    """
    merged = [_merge(regime, c) for c in configs]
    return self.inner.predict_batch(merged)
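The batch path broadcasts one regime across many configs and then makes a single call to the inner model. A stdlib-only sketch of that pattern, assuming a hypothetical ToySurrogate in place of the real inner model, a plain dict union in place of _merge, and lists in place of NumPy arrays:

```python
from collections.abc import Sequence
from typing import Any


class ToySurrogate:
    """Hypothetical stand-in for the inner model: loss = lr * 1000 / n_samples."""

    observable_names = ["loss"]

    def predict_batch(self, rows: Sequence[dict[str, Any]]) -> dict[str, list[float]]:
        return {"loss": [r["lr"] * 1000.0 / r["n_samples"] for r in rows]}


def predict_batch_at_regime(
    inner: ToySurrogate,
    regime: dict[str, Any],
    configs: Sequence[dict[str, Any]],
) -> dict[str, list[float]]:
    # The same regime dict is merged into every config, then all rows are scored at once.
    merged = [{**regime, **c} for c in configs]
    return inner.predict_batch(merged)
```

Scoring every candidate in one batch, rather than calling predict in a loop, is what lets recommend evaluate hundreds of candidates cheaply.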

uncertainty(regime, config)

Predictive standard deviation per observable (GP only).

Parameters:

regime (dict[str, Any], required): Mapping of regime-feature names to values.
config (dict[str, Any], required): Mapping of design-factor names to values.

Returns:

dict[str, float]: Mapping from observable name to predictive standard deviation. Propagates NotImplementedError from the underlying surrogate when the backend does not expose calibrated uncertainties (non-GP backends).

Source code in src/trade_study/regime.py
def uncertainty(
    self,
    regime: dict[str, Any],
    config: dict[str, Any],
) -> dict[str, float]:
    """Predictive standard deviation per observable (GP only).

    Args:
        regime: Mapping of regime-feature names to values.
        config: Mapping of design-factor names to values.

    Returns:
        Mapping from observable name to predictive standard deviation.
        Propagates :class:`NotImplementedError` from the underlying
        surrogate when the backend does not expose calibrated
        uncertainties (non-GP backends).
    """
    return self.inner.uncertainty(_merge(regime, config))
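Since non-GP backends raise NotImplementedError here, calling code that must work with either backend can catch it and fall back to point predictions only. A sketch under that assumption; RFLikeSurrogate, GPLikeSurrogate and predict_with_optional_sd are hypothetical stand-ins exposing the same predict/uncertainty call shape as RegimeSurrogate.

```python
from typing import Any, Optional


class RFLikeSurrogate:
    """Hypothetical non-GP backend: point predictions only."""

    def predict(self, regime: dict[str, Any], config: dict[str, Any]) -> dict[str, float]:
        return {"loss": 0.25}

    def uncertainty(self, regime: dict[str, Any], config: dict[str, Any]) -> dict[str, float]:
        raise NotImplementedError("no calibrated uncertainty for this backend")


class GPLikeSurrogate(RFLikeSurrogate):
    """Hypothetical GP backend: also reports a predictive standard deviation."""

    def uncertainty(self, regime: dict[str, Any], config: dict[str, Any]) -> dict[str, float]:
        return {"loss": 0.05}


def predict_with_optional_sd(
    surrogate: RFLikeSurrogate, regime: dict[str, Any], config: dict[str, Any]
) -> dict[str, tuple[float, Optional[float]]]:
    """Return (mean, sd) per observable; sd is None when the backend cannot supply it."""
    mean = surrogate.predict(regime, config)
    try:
        sd: dict[str, Optional[float]] = dict(surrogate.uncertainty(regime, config))
    except NotImplementedError:
        sd = {name: None for name in mean}
    return {name: (mean[name], sd[name]) for name in mean}
```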

recommend(regime, *, objective, mode='min', n_candidates=512, seed=0, candidates=None)

Recommend a design-factor config at a query regime.

Unless candidates is given, samples n_candidates configs from the design-factor space via a scrambled Sobol' sequence. In either case, scores every candidate at regime with the surrogate and returns a copy of the one whose predicted objective is best under mode.
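When the design space is small and discrete, an exhaustive pool passed through candidates (which overrides n_candidates) may be preferable to Sobol' sampling. A sketch of building such a pool with itertools.product; the helper full_grid, the factor names and the levels are all hypothetical.

```python
import itertools
from typing import Any


def full_grid(levels: dict[str, list[Any]]) -> list[dict[str, Any]]:
    """Every combination of the given factor levels, as a list of config dicts."""
    names = list(levels)
    return [
        dict(zip(names, combo))
        for combo in itertools.product(*(levels[n] for n in names))
    ]


# Hypothetical design factors: 2 learning rates x 3 depths = 6 candidates.
pool = full_grid({"lr": [0.01, 0.1], "depth": [2, 4, 8]})
# With a fitted surrogate (not constructed here), the pool would be passed as:
# best = surrogate.recommend({"n_samples": 5000}, objective="loss", candidates=pool)
```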

Parameters:

regime (dict[str, Any], required): Mapping of regime-feature names to values.
objective (str, required): Name of the observable to optimize. Must be one of self.inner.observable_names.
mode (str, default 'min'): "min" or "max".
n_candidates (int, default 512): Number of design-space samples to evaluate. Ignored when candidates is provided.
seed (int, default 0): Seed for the Sobol' sampler.
candidates (Sequence[dict[str, Any]] | None, default None): Optional explicit list of design-factor configs to score; if given, overrides n_candidates.

Returns:

dict[str, Any]: The candidate config (a copy) achieving the best predicted objective under mode.

Raises:

ValueError: If objective is not a fitted observable, if mode is not "min" or "max", or if there are no candidates to score.

Source code in src/trade_study/regime.py
def recommend(
    self,
    regime: dict[str, Any],
    *,
    objective: str,
    mode: str = "min",
    n_candidates: int = 512,
    seed: int = 0,
    candidates: Sequence[dict[str, Any]] | None = None,
) -> dict[str, Any]:
    """Recommend a design-factor config at a query regime.

    Samples ``n_candidates`` configs from the design-factor space via
    a scrambled Sobol' sequence and returns the one whose surrogate
    prediction for ``objective`` is best under ``mode``.

    Args:
        regime: Mapping of regime-feature names to values.
        objective: Name of the observable to optimize. Must be one
            of ``self.inner.observable_names``.
        mode: ``"min"`` or ``"max"``.
        n_candidates: Number of design-space samples to evaluate.
            Ignored when ``candidates`` is provided.
        seed: Seed for the Sobol' sampler.
        candidates: Optional explicit list of design-factor configs
            to score; if given, overrides ``n_candidates``.

    Returns:
        The candidate config (a copy) achieving the best predicted
        ``objective`` under ``mode``.

    Raises:
        ValueError: If ``objective`` is not a fitted observable, if
            ``mode`` is not ``"min"`` or ``"max"``, or if there are
            no candidates to score.
    """
    if mode not in _SUPPORTED_MODES:
        msg = f"mode must be one of {sorted(_SUPPORTED_MODES)}; got {mode!r}"
        raise ValueError(msg)
    if objective not in self.inner.observable_names:
        msg = (
            f"objective {objective!r} is not a fitted observable; "
            f"available: {self.inner.observable_names}"
        )
        raise ValueError(msg)
    pool = (
        list(candidates)
        if candidates is not None
        else build_grid(
            self.factors,
            method="sobol",
            n_samples=n_candidates,
            seed=seed,
        )
    )
    if not pool:
        msg = "recommend: no candidates to score"
        raise ValueError(msg)
    preds = self.predict_batch(regime, pool)[objective]
    idx = int(np.argmin(preds)) if mode == "min" else int(np.argmax(preds))
    return dict(pool[idx])
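The final step of recommend reduces to an arg-min or arg-max over one batch of predictions. The stdlib-only sketch below mirrors that step under a simplifying assumption: the surrogate is replaced by a hypothetical batch-scoring callable, and the objective check and Sobol' sampling via build_grid are left out. The names pick_best and SUPPORTED_MODES are illustrative only.

```python
from collections.abc import Callable, Sequence
from typing import Any

SUPPORTED_MODES = {"min", "max"}


def pick_best(
    candidates: Sequence[dict[str, Any]],
    score: Callable[[Sequence[dict[str, Any]]], Sequence[float]],
    mode: str = "min",
) -> dict[str, Any]:
    """Return a copy of the candidate with the best score under mode."""
    if mode not in SUPPORTED_MODES:
        raise ValueError(f"mode must be one of {sorted(SUPPORTED_MODES)}; got {mode!r}")
    pool = list(candidates)
    if not pool:
        raise ValueError("no candidates to score")
    preds = list(score(pool))  # one batched call, like predict_batch
    choose = min if mode == "min" else max
    # Ties resolve to the earliest candidate, matching np.argmin / np.argmax.
    idx = choose(range(len(pool)), key=preds.__getitem__)
    return dict(pool[idx])
```

Returning dict(pool[idx]) rather than the element itself means a caller can mutate the recommendation without touching the candidate list.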