# Surrogate
Cheap regression surrogates, fitted over a `ResultsTable`, for predicting observables at untested configurations.
Install via the optional extra:

## `trade_study.fit_surrogate(results, factors, *, method='gp', seed=0, n_estimators=200)`
Fit a per-observable surrogate over a `ResultsTable`.

Rows whose score column contains NaN are dropped on a per-observable basis, so a partially evaluated trial still contributes to the observables it does have.
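The per-observable NaN handling can be illustrated with a minimal NumPy sketch. It is not the library's implementation: a plain least-squares line stands in for the GP or tree backend, and `fit_per_observable` is a hypothetical helper. Each observable gets its own row mask, so a trial missing one score still feeds the fits of the others.

```python
import numpy as np


def fit_per_observable(X, Y):
    """Hypothetical helper: fit one least-squares line per observable column of Y.

    Rows where a given observable is NaN are dropped for that observable only.
    """
    X = np.asarray(X, dtype=float)
    Y = np.asarray(Y, dtype=float)
    # Prepend an intercept column to the encoded factors.
    design = np.column_stack([np.ones(len(X)), X])
    coefs, n_used = [], []
    for j in range(Y.shape[1]):
        keep = ~np.isnan(Y[:, j])  # per-observable row mask
        beta, *_ = np.linalg.lstsq(design[keep], Y[keep, j], rcond=None)
        coefs.append(beta)
        n_used.append(int(keep.sum()))
    return coefs, n_used


X = [[0.0], [1.0], [2.0], [3.0]]
Y = [
    [1.0, 0.0],
    [3.0, np.nan],  # this trial is missing the second observable
    [5.0, 4.0],
    [7.0, 6.0],
]
coefs, n_used = fit_per_observable(X, Y)
print(n_used)  # [4, 3]
```

The first observable is fitted on all four rows, the second on the three rows where it was recorded.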

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `results` | `ResultsTable` | A `ResultsTable` … | *required* |
| `factors` | `list[Factor]` | Factor definitions used to encode … | *required* |
| `method` | `str` | | `'gp'` |
| `seed` | `int` | Random seed forwarded to the backend estimators. | `0` |
| `n_estimators` | `int` | Number of trees for the … | `200` |

Returns:

| Type | Description |
|---|---|
| `SurrogateModel` | A fitted `SurrogateModel` … |

Raises:

| Type | Description |
|---|---|
| `ValueError` | If … |
Source code in src/trade_study/surrogate.py

## `trade_study.SurrogateModel(method, encoder, observable_names, models)`

*dataclass*

Fitted surrogate over a `ResultsTable`.

Use `fit_surrogate` to construct one. Per-observable backend estimators are stored in `models`; the encoding is shared across them.

Attributes:

| Name | Type | Description |
|---|---|---|
| `method` | `str` | |
| `encoder` | `_FactorEncoder` | Factor encoder used at fit time. |
| `observable_names` | `list[str]` | Column names of the predicted observables. |
| `models` | `list[Any]` | One fitted scikit-learn estimator per observable. |
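The layout described above, one shared encoder plus one model per observable, can be sketched with a small dataclass. This is a hypothetical stand-in, not `SurrogateModel` itself: `ToySurrogate`, the `encode` function, and the factor names `width` and `depth` are invented for illustration, and plain callables stand in for the fitted scikit-learn estimators.

```python
from dataclasses import dataclass
from typing import Any, Callable

import numpy as np


@dataclass
class ToySurrogate:
    """Hypothetical stand-in mirroring SurrogateModel's fields."""

    method: str
    encoder: Callable[[dict[str, Any]], np.ndarray]  # shared by every model
    observable_names: list[str]
    models: list[Callable[[np.ndarray], float]]  # one per observable, same order

    def predict(self, config: dict[str, Any]) -> dict[str, float]:
        x = self.encoder(config)  # encode once, reuse for every observable
        return {
            name: float(model(x))
            for name, model in zip(self.observable_names, self.models)
        }


def encode(config: dict[str, Any]) -> np.ndarray:
    # Hypothetical encoding: two numeric factors in a fixed order.
    return np.array([config["width"], config["depth"]], dtype=float)


model = ToySurrogate(
    method="toy",
    encoder=encode,
    observable_names=["cost", "latency"],
    models=[lambda x: 2.0 * x[0] + x[1], lambda x: x[0] * x[1]],
)
print(model.predict({"width": 3, "depth": 4}))  # {'cost': 10.0, 'latency': 12.0}
```

Sharing the encoder guarantees every observable sees the same feature layout that was used at fit time.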

### `predict(config)`
Predict observables for a single config.

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `config` | `dict[str, Any]` | Factor-keyed config dict. | *required* |

Returns:

| Type | Description |
|---|---|
| `dict[str, float]` | Mapping from observable name to predicted scalar. |
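Before a backend estimator can score a factor-keyed config dict, the dict has to become a numeric feature row. The encoding scheme of the private `_FactorEncoder` is not documented here, so the following is only a sketch under an assumed scheme: numeric factors pass through as floats and categorical factors are one-hot encoded. `ToyEncoder` and the factor names `lr` and `opt` are hypothetical.

```python
import numpy as np


class ToyEncoder:
    """Hypothetical encoder: numeric factors pass through, categorical ones are one-hot."""

    def __init__(self, numeric, categorical):
        self.numeric = list(numeric)  # factor names encoded as floats
        # Factor name -> ordered list of allowed levels.
        self.categorical = {k: list(v) for k, v in categorical.items()}

    def encode(self, config):
        parts = [float(config[name]) for name in self.numeric]
        for name, levels in self.categorical.items():
            # One column per level; exactly one is 1.0 when the value is a known level.
            parts.extend(1.0 if config[name] == level else 0.0 for level in levels)
        return np.array(parts)


enc = ToyEncoder(numeric=["lr"], categorical={"opt": ["sgd", "adam"]})
row = enc.encode({"lr": 0.1, "opt": "adam"})  # columns: lr, opt==sgd, opt==adam
```

Fixing the column order at construction time is what lets the same encoder be reused at prediction time.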

### `predict_batch(configs)`
Predict observables for a batch of configs.

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `configs` | `Sequence[dict[str, Any]]` | Sequence of factor-keyed config dicts. | *required* |

Returns:

| Type | Description |
|---|---|
| `dict[str, NDArray[float64]]` | Mapping from observable name to an array of predictions, with one entry per config in `configs`. |
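The shape of the batch return value can be sketched as follows: encode every config once, stack the rows into an `(n_configs, n_features)` matrix, and let each per-observable model predict the whole matrix in one call. This is a hypothetical stand-in, not the library's code; the `predict_batch` helper, the linear weight vectors, and the factor name `x` are assumptions.

```python
import numpy as np


def predict_batch(encode, observable_names, weights, configs):
    """Hypothetical batch predictor: one linear model (weight vector) per observable."""
    # Stack encoded configs into an (n_configs, n_features) matrix.
    X = np.vstack([encode(c) for c in configs])
    # Each observable maps to an array with one prediction per config.
    return {name: X @ w for name, w in zip(observable_names, weights)}


def encode(config):
    return np.array([1.0, config["x"]])  # intercept + a single numeric factor


out = predict_batch(encode, ["y"], [np.array([1.0, 2.0])], [{"x": 0}, {"x": 1}, {"x": 2}])
print(out["y"])  # [1. 3. 5.]
```

Building one matrix is cheaper than calling the model once per config, and it matches how scikit-learn estimators' `predict` accepts a 2-D array of samples.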

### `uncertainty(config)`
Predictive standard deviation per observable (GP only).

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `config` | `dict[str, Any]` | Factor-keyed config dict. | *required* |

Returns:

| Type | Description |
|---|---|
| `dict[str, float]` | Mapping from observable name to predictive standard deviation. |

Raises:

| Type | Description |
|---|---|
| `NotImplementedError` | If the backend does not expose calibrated uncertainties (currently any `method` other than `'gp'`). |
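Why the predictive standard deviation is GP-only can be seen from the GP posterior variance, $\sigma^2(x) = k(x, x) - k_*^\top K^{-1} k_*$, which a fitted GP yields in closed form. The NumPy sketch below computes it for a zero-mean GP with a unit-variance squared-exponential kernel. It is not the library's backend; the kernel, `length_scale`, `noise`, and the `uncertainty` helper are assumptions chosen for illustration.

```python
import numpy as np


def rbf(A, B, length_scale=1.0):
    """Squared-exponential kernel k(a, b) = exp(-||a - b||^2 / (2 l^2))."""
    sq = ((A[:, None, :] - B[None, :, :]) ** 2).sum(-1)
    return np.exp(-0.5 * sq / length_scale**2)


def gp_predictive_std(X_train, x_query, noise=1e-6, length_scale=1.0):
    """Posterior std of a zero-mean, unit-variance GP at one query point."""
    X_train = np.atleast_2d(np.asarray(X_train, dtype=float))
    x = np.atleast_2d(np.asarray(x_query, dtype=float))
    K = rbf(X_train, X_train, length_scale) + noise * np.eye(len(X_train))
    k_star = rbf(X_train, x, length_scale)[:, 0]  # shape (n_train,)
    L = np.linalg.cholesky(K)  # K = L L^T
    v = np.linalg.solve(L, k_star)  # v^T v = k_*^T K^{-1} k_*
    var = 1.0 - v @ v  # k(x, x) = 1 for this kernel
    return float(np.sqrt(max(var, 0.0)))


def uncertainty(method, X_train, x_query):
    # Hypothetical dispatch: only the GP path has a closed-form predictive variance.
    if method != "gp":
        raise NotImplementedError("only the GP backend exposes calibrated uncertainties")
    return gp_predictive_std(X_train, x_query)
```

The standard deviation is near zero at a training point and approaches the prior value of 1 far from the data; tree ensembles have no equivalent calibrated quantity, which is why other backends raise.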