Skip to content

Study

Multi-phase study orchestration.

trade_study.Phase(name, grid, filter_fn=None, n_trials=100, world=None, scorer=None) dataclass

A single phase in a multi-phase study.

Attributes:

Name Type Description
name str

Phase identifier (e.g. "discovery", "refinement").

grid list[dict[str, Any]] | str | GridCallable

Explicit config list, "carry" to re-use filtered configs from the previous phase, "adaptive" for optuna-driven search, or a callable (ResultsTable, list[Observable]) -> list[dict] that dynamically generates the grid from the previous phase's results.

filter_fn Callable[[ResultsTable, list[Observable]], NDArray[intp]] | None

Optional callable that takes a ResultsTable and the list of observables and returns the indices of the configs to carry into the next phase. If None, the phase is terminal.

n_trials int

Number of Optuna trials to run (used only in adaptive mode).

world Simulator | None

Optional phase-level simulator override. When set, this phase uses world instead of the Study-level simulator. Useful for multi-fidelity workflows (cheap surrogate first, expensive model later).

scorer Scorer | None

Optional phase-level scorer override. When set, this phase uses scorer instead of the Study-level scorer.

trade_study.Study(world, scorer, observables, phases, annotations=list(), factors=list()) dataclass

Multi-phase model criticism study.

Attributes:

Name Type Description
world Simulator

Simulator generating (truth, observations).

scorer Scorer

Scorer evaluating observables against truth.

observables list[Observable]

Observable definitions.

phases list[Phase]

Ordered list of study phases.

annotations list[Annotation]

External information (costs, constraints).

factors list[Any]

Factor definitions (needed for adaptive mode).

run(*, n_jobs=1, callback=None)

Execute all phases sequentially.

Parameters:

Name Type Description Default
n_jobs int

Number of parallel workers for grid phases.

1
callback ProgressCallback | None

Optional progress callback invoked after each trial with (trial_index, total_trials, trial_result).

None

Raises:

Type Description
ValueError

If a callable grid is used on the first phase (no previous results to pass).

Source code in src/trade_study/study.py
def run(
    self,
    *,
    n_jobs: int = 1,
    callback: ProgressCallback | None = None,
) -> None:
    """Execute all phases sequentially.

    Results of each phase are stored under ``phase.name``.  When a phase
    defines ``filter_fn``, the configs it selects are carried forward and
    become the grid of a following ``"carry"`` phase.

    Args:
        n_jobs: Number of parallel workers for grid phases.
        callback: Optional progress callback invoked after each trial
            with ``(trial_index, total_trials, trial_result)``.

    Raises:
        ValueError: If a callable grid is used on the first phase
            (no previous results to pass); if a ``"carry"`` grid has
            nothing to carry (it is the first phase, or the previous
            phase has no ``filter_fn``); or if ``grid`` is a string
            other than ``"adaptive"`` or ``"carry"``.
    """
    carry_grid: list[dict[str, Any]] | None = None
    prev_result: ResultsTable | None = None

    for phase in self.phases:
        # Resolve phase-level overrides (multi-fidelity support)
        world = phase.world if phase.world is not None else self.world
        scorer = phase.scorer if phase.scorer is not None else self.scorer

        if isinstance(phase.grid, str) and phase.grid == "adaptive":
            result = run_adaptive(
                world,
                scorer,
                self.factors,
                self.observables,
                n_trials=phase.n_trials,
            )
        else:
            # Resolve the concrete config list for a grid phase.
            if callable(phase.grid):
                if prev_result is None:
                    msg = (
                        f"Phase {phase.name!r}: callable grid requires a previous phase"
                    )
                    raise ValueError(msg)
                grid = phase.grid(prev_result, self.observables)
            elif isinstance(phase.grid, str):
                # Previously any unrecognised string silently ran an empty grid.
                if phase.grid != "carry":
                    msg = (
                        f"Phase {phase.name!r}: unknown grid mode {phase.grid!r} "
                        "(expected 'adaptive' or 'carry')"
                    )
                    raise ValueError(msg)
                # carry_grid is None only when no earlier phase filtered;
                # an empty list (filter kept nothing) is still carried.
                if carry_grid is None:
                    msg = (
                        f"Phase {phase.name!r}: 'carry' grid requires a previous "
                        "phase with a filter_fn"
                    )
                    raise ValueError(msg)
                grid = carry_grid
            else:
                # Explicit config sequence (list, tuple, ...).
                grid = list(phase.grid)

            result = run_grid(
                world,
                scorer,
                grid,
                self.observables,
                annotations=self.annotations or None,
                n_jobs=n_jobs,
                callback=callback,
            )

        self._results[phase.name] = result
        prev_result = result

        if phase.filter_fn is not None:
            keep = phase.filter_fn(result, self.observables)
            carry_grid = [result.configs[i] for i in keep]
        else:
            carry_grid = None

results(phase)

Get results for a specific phase.

Returns:

Type Description
ResultsTable

ResultsTable for the given phase.

Source code in src/trade_study/study.py
def results(self, phase: str) -> ResultsTable:
    """Return the stored results table for one phase.

    Args:
        phase: Name of a phase whose results were stored by ``run``.

    Returns:
        ResultsTable for the given phase.

    Raises:
        KeyError: If no results are stored under ``phase`` (raised by the
            underlying mapping lookup).
    """
    table = self._results[phase]
    return table

front(phase)

Get Pareto front indices for a phase.

Returns:

Type Description
NDArray[intp]

Integer array of Pareto-optimal row indices.

Source code in src/trade_study/study.py
def front(self, phase: str) -> NDArray[np.intp]:
    """Return the Pareto-optimal row indices of one phase's results.

    Directions and weights are taken from the study's observables, in the
    order of ``self.observables``.

    Args:
        phase: Name of a phase whose results were stored by ``run``.

    Returns:
        Integer array of Pareto-optimal row indices.
    """
    table = self._results[phase]
    directions = [obs.direction for obs in self.observables]
    obs_weights = [obs.weight for obs in self.observables]
    return extract_front(table.scores, directions, obs_weights)

front_hypervolume(phase, ref_point)

Compute hypervolume of the Pareto front for a phase.

Returns:

Type Description
float

Hypervolume value.

Source code in src/trade_study/study.py
def front_hypervolume(
    self,
    phase: str,
    ref_point: NDArray[np.floating[Any]],
) -> float:
    """Compute the hypervolume dominated by one phase's Pareto front.

    Args:
        phase: Name of a phase whose results were stored by ``run``.
        ref_point: Reference point forwarded unchanged to ``hypervolume``.

    Returns:
        Hypervolume value.
    """
    table = self._results[phase]
    directions = [obs.direction for obs in self.observables]
    obs_weights = [obs.weight for obs in self.observables]

    # Only the Pareto-optimal rows contribute to the hypervolume.
    pareto_rows = extract_front(table.scores, directions, obs_weights)
    front_scores = table.scores[pareto_rows]
    return hypervolume(front_scores, ref_point, directions, obs_weights)

stack(phase, *, maximize=False)

Compute score-based stacking weights for a phase.

Returns:

Type Description
NDArray[floating[Any]]

Array of stacking weights.

Source code in src/trade_study/study.py
def stack(
    self,
    phase: str,
    *,
    maximize: bool = False,
) -> NDArray[np.floating[Any]]:
    """Compute score-based stacking weights for one phase.

    The phase's score matrix (one column per observable) is transposed
    before being handed to ``stack_scores``.

    Args:
        phase: Name of a phase whose results were stored by ``run``.
        maximize: Forwarded unchanged to ``stack_scores``.

    Returns:
        Array of stacking weights.
    """
    per_observable = self._results[phase].scores.T
    return stack_scores(per_observable, maximize=maximize)

summary()

Per-phase summary: n_trials, n_front, observable ranges.

Returns:

Type Description
dict[str, dict[str, Any]]

Dictionary mapping phase names to summary statistics.

Source code in src/trade_study/study.py
def summary(self) -> dict[str, dict[str, Any]]:
    """Per-phase summary: n_trials, n_front, observable ranges.

    Each entry holds:

    * ``n_trials`` - number of configs evaluated in the phase.
    * ``n_front`` - number of Pareto-optimal rows under the study's
      observable directions and weights.
    * ``observable_ranges`` - NaN-ignoring ``min``/``max`` per observable
      column.

    A phase with no result rows reports ``n_front == 0`` and NaN for every
    min/max instead of raising.

    Returns:
        Dictionary mapping phase names to summary statistics.
    """
    # Directions/weights depend only on the study, not on the phase.
    dirs = [o.direction for o in self.observables]
    wts = [o.weight for o in self.observables]
    nan = float("nan")

    out: dict[str, dict[str, Any]] = {}
    for name, r in self._results.items():
        if r.scores.shape[0] == 0:
            # Empty table: np.nanmin/np.nanmax raise on zero-size reductions.
            n_front = 0
            ranges = {obs: {"min": nan, "max": nan} for obs in r.observable_names}
        else:
            n_front = len(extract_front(r.scores, dirs, wts))
            ranges = {
                obs: {
                    "min": float(np.nanmin(r.scores[:, i])),
                    "max": float(np.nanmax(r.scores[:, i])),
                }
                for i, obs in enumerate(r.observable_names)
            }
        out[name] = {
            "n_trials": len(r.configs),
            "n_front": n_front,
            "observable_ranges": ranges,
        }
    return out

trade_study.top_k_pareto_filter(k, objective_names=None)

Create a filter that keeps the top-K configs by Pareto rank.

Parameters:

Name Type Description Default
k int

Maximum number of configs to keep.

required
objective_names list[str] | None

Subset of observables to use for ranking. If None, uses all observables.

None

Returns:

Type Description
Callable[[ResultsTable, list[Observable]], NDArray[intp]]

Filter function compatible with Phase.filter_fn.

Source code in src/trade_study/study.py
def top_k_pareto_filter(
    k: int,
    objective_names: list[str] | None = None,
) -> Callable[[ResultsTable, list[Observable]], NDArray[np.intp]]:
    """Create a filter that keeps the top-K configs by Pareto rank.

    Configs are ordered by ascending rank as returned by ``pareto_rank``.
    Ties within a rank keep their original row order.

    Args:
        k: Maximum number of configs to keep. Must be non-negative.
        objective_names: Subset of observables to use for ranking.
            If None, uses all observables.

    Returns:
        Filter function compatible with ``Phase.filter_fn``.

    Raises:
        ValueError: If ``k`` is negative.
    """
    # A negative k would silently act as "drop the last |k| rows" via slicing.
    if k < 0:
        msg = f"k must be non-negative, got {k}"
        raise ValueError(msg)

    def _filter(
        results: ResultsTable,
        observables: list[Observable],
    ) -> NDArray[np.intp]:
        if objective_names is not None:
            cols = [results.observable_names.index(n) for n in objective_names]
            scores = results.scores[:, cols]
            # Look observables up by name so directions/weights line up with
            # ``cols`` (which follows ``objective_names`` order), rather than
            # with the order in which ``observables`` happens to be listed.
            obs_lookup = {o.name: o for o in observables}
            subset = [obs_lookup[n] for n in objective_names]
        else:
            scores = results.scores
            subset = observables
        dirs = [o.direction for o in subset]
        wts = [o.weight for o in subset]

        ranks = pareto_rank(scores, dirs, wts)
        # Stable sort: deterministic tie-breaking by original row order.
        order = np.argsort(ranks, kind="stable")
        return order[:k].astype(np.intp)

    return _filter

trade_study.weighted_sum_filter(weights, k)

Create a filter that keeps the top-K configs by weighted sum.

Scalarises multiple objectives into a single score via a weighted sum and keeps the k best configs. Scores are min-max normalised before weighting so that objectives on different scales are comparable. MAXIMIZE objectives are negated before normalisation so that lower normalised values are always better.

Parameters:

Name Type Description Default
weights dict[str, float]

Mapping from observable name to its scalarisation weight. Only the named observables are used; the rest are ignored.

required
k int

Maximum number of configs to keep.

required

Returns:

Type Description
Callable[[ResultsTable, list[Observable]], NDArray[intp]]

Filter function compatible with Phase.filter_fn.

Source code in src/trade_study/study.py
def weighted_sum_filter(
    weights: dict[str, float],
    k: int,
) -> Callable[[ResultsTable, list[Observable]], NDArray[np.intp]]:
    """Create a filter that keeps the top-K configs by weighted sum.

    Scalarises multiple objectives into a single score via a weighted sum
    and keeps the ``k`` best configs.  Scores are min-max normalised
    before weighting so that objectives on different scales are
    comparable.  MAXIMIZE objectives are negated before normalisation so
    that lower normalised values are always better.  Ties are broken by
    original row order, and configs whose scalarised score is NaN are
    ranked last.

    Args:
        weights: Mapping from observable name to its scalarisation weight.
            Only the named observables are used; the rest are ignored.
        k: Maximum number of configs to keep. Must be non-negative.

    Returns:
        Filter function compatible with ``Phase.filter_fn``.  Applied to an
        empty results table, it returns an empty index array.

    Raises:
        ValueError: If ``k`` is negative.
    """
    # A negative k would silently act as "drop the last |k| rows" via slicing.
    if k < 0:
        msg = f"k must be non-negative, got {k}"
        raise ValueError(msg)

    def _filter(
        results: ResultsTable,
        observables: list[Observable],
    ) -> NDArray[np.intp]:
        cols = [results.observable_names.index(n) for n in weights]
        raw = results.scores[:, cols].copy()
        if raw.shape[0] == 0:
            # No configs to rank; np.nanmin/np.nanmax raise on an empty axis.
            return np.empty(0, dtype=np.intp)

        # Flip MAXIMIZE objectives so lower is always better
        obs_lookup = {o.name: o for o in observables}
        for j, name in enumerate(weights):
            if obs_lookup[name].direction == Direction.MAXIMIZE:
                raw[:, j] = -raw[:, j]

        # Min-max normalise each column to [0, 1]
        col_min = np.nanmin(raw, axis=0)
        col_max = np.nanmax(raw, axis=0)
        span = col_max - col_min
        span[span == 0] = 1.0  # avoid division by zero for constant cols
        normed = (raw - col_min) / span

        # Same iteration order as ``cols`` (both iterate the same dict).
        w = np.array(list(weights.values()), dtype=float)
        scalar = normed @ w
        # Stable sort for deterministic ties; NaN scores sort to the end.
        order = np.argsort(scalar, kind="stable")
        return order[:k].astype(np.intp)

    return _filter

trade_study.feasibility_filter(constraints)

Create a filter that keeps only designs satisfying all constraints.

Parameters:

Name Type Description Default
constraints list[Constraint]

Constraint objects to evaluate against results.

required

Returns:

Type Description
Callable[[ResultsTable, list[Observable]], NDArray[intp]]

Filter function compatible with Phase.filter_fn.

Source code in src/trade_study/study.py
def feasibility_filter(
    constraints: list[Constraint],
) -> Callable[[ResultsTable, list[Observable]], NDArray[np.intp]]:
    """Build a filter that retains only designs meeting every constraint.

    Args:
        constraints: Constraint objects passed to ``ResultsTable.feasible``.

    Returns:
        Filter function compatible with ``Phase.filter_fn``.  It returns
        the row indices at which ``results.feasible(constraints)`` is
        true; the observables argument is ignored.
    """

    def _filter(
        results: ResultsTable,
        _observables: list[Observable],
    ) -> NDArray[np.intp]:
        feasible_mask = results.feasible(constraints)
        kept_rows = np.nonzero(feasible_mask)[0]
        return kept_rows.astype(np.intp)

    return _filter