Skip to content

Design

Experimental design: factors, grids, and screening.

trade_study.FactorType

Bases: Enum

Type of design factor.

trade_study.Factor(name, factor_type, levels=None, bounds=None) dataclass

A single design factor.

Attributes:

Name Type Description
name str

Factor identifier (e.g. "alpha", "layer1_method").

factor_type FactorType

Continuous, discrete, or categorical.

levels list[Any] | None

For categorical/discrete: list of allowed values.

bounds tuple[float, float] | None

For continuous: a (low, high) tuple of finite values with low < high.

__post_init__()

Validate factor constraints.

Raises:

Type Description
ValueError

If name is empty, continuous factor has missing or invalid bounds, or discrete/categorical factor has empty levels.

Source code in src/trade_study/design.py
def __post_init__(self) -> None:
    """Validate factor constraints.

    Continuous factors must carry ``bounds`` as a ``(low, high)`` pair of
    finite values with ``low < high``. Discrete and categorical factors
    must carry a non-empty ``levels`` collection.

    Raises:
        ValueError: If name is empty; if a continuous factor has missing
            bounds, bounds that are not exactly two values, non-finite
            bounds, or ``low >= high``; or if a discrete/categorical factor
            has missing or empty levels.
    """
    if not self.name:
        msg = "Factor name must be a non-empty string"
        raise ValueError(msg)
    if self.factor_type == FactorType.CONTINUOUS:
        if self.bounds is None:
            msg = f"Continuous factor '{self.name}' requires bounds"
            raise ValueError(msg)
        # Check arity before unpacking so a malformed tuple produces an error
        # that names the factor instead of a bare "values to unpack" message.
        if len(self.bounds) != 2:
            msg = (
                f"Continuous factor '{self.name}' bounds must be a "
                f"(low, high) pair, got {self.bounds!r}"
            )
            raise ValueError(msg)
        lo, hi = self.bounds
        if not (np.isfinite(lo) and np.isfinite(hi)):
            msg = f"Continuous factor '{self.name}' bounds must be finite"
            raise ValueError(msg)
        if lo >= hi:
            msg = f"Continuous factor '{self.name}' requires lo < hi"
            raise ValueError(msg)
    else:
        # Every non-continuous type (discrete / categorical) is defined by
        # its explicit list of allowed values.
        if self.levels is None:
            msg = f"Factor '{self.name}' of type {self.factor_type} requires levels"
            raise ValueError(msg)
        if len(self.levels) == 0:
            msg = f"Factor '{self.name}' levels must be non-empty"
            raise ValueError(msg)

trade_study.build_grid(factors, *, method='full', n_samples=100, seed=42, scramble=True)

Build an experimental design grid.

Parameters:

Name Type Description Default
factors list[Factor]

List of design factors.

required
method str

Design method. One of:

- "full": Full factorial (categorical/discrete only).
- "lhs": Latin hypercube sampling (continuous factors, maps categorical factors to uniform random selection).
- "sobol": Sobol' sequence via scipy.stats.qmc (scrambled unless scramble=False).
- "halton": Halton sequence via scipy.stats.qmc (scrambled unless scramble=False).

'full'
n_samples int

Number of samples for LHS / QMC methods.

100
seed int

Random seed.

42
scramble bool

Whether to apply scrambling to QMC sequences (Sobol / Halton). Ignored for other methods.

True

Returns:

Type Description
list[dict[str, Any]]

List of config dictionaries, one per design point.

Raises:

Type Description
ValueError

If an unknown design method is specified.

Source code in src/trade_study/design.py
def build_grid(
    factors: list[Factor],
    *,
    method: str = "full",
    n_samples: int = 100,
    seed: int = 42,
    scramble: bool = True,
) -> list[dict[str, Any]]:
    """Construct the design points of an experiment.

    Args:
        factors: Factors spanning the design space.
        method: How to generate design points:
            - "full": Full factorial (categorical/discrete only).
            - "lhs": Latin hypercube sampling (continuous factors, maps
              categorical factors to uniform random selection).
            - "sobol": Sobol' sequence via ``scipy.stats.qmc``.
            - "halton": Halton sequence via ``scipy.stats.qmc``.
        n_samples: Number of points drawn by the "lhs", "sobol" and
            "halton" methods; unused by "full".
        seed: Random seed for the sampling methods; unused by "full".
        scramble: Forwarded only to the "sobol" / "halton" samplers to
            toggle scrambling; ignored by every other method.

    Returns:
        One config dictionary per design point.

    Raises:
        ValueError: If *method* is not one of the names listed above.
    """
    # The full factorial enumerates levels directly; no size/seed controls.
    if method == "full":
        return _full_factorial(factors)

    # Every remaining method is a sampler sharing the same size/seed knobs.
    sampling = {"n_samples": n_samples, "seed": seed}
    if method == "lhs":
        return _latin_hypercube(factors, **sampling)
    if method in {"sobol", "halton"}:
        return _qmc_sample(
            factors,
            qmc_method=method,
            scramble=scramble,
            **sampling,
        )

    msg = f"Unknown design method: {method!r}"
    raise ValueError(msg)

trade_study.reduce_factors(factors, importance, *, threshold=0.1)

Keep continuous factors whose maximum importance across observables is at least threshold; non-continuous factors are always kept.

Parameters:

Name Type Description Default
factors list[Factor]

Original factor list.

required
importance dict[str, NDArray[floating[Any]]]

Output of screen().

required
threshold float

Minimum importance to retain a factor.

0.1

Returns:

Type Description
list[Factor]

Reduced list of influential factors.

Source code in src/trade_study/design.py
def reduce_factors(
    factors: list[Factor],
    importance: dict[str, NDArray[np.floating[Any]]],
    *,
    threshold: float = 0.1,
) -> list[Factor]:
    """Keep only factors whose max importance exceeds threshold.

    Args:
        factors: Original factor list.
        importance: Output of ``screen()``.
        threshold: Minimum importance to retain a factor.

    Returns:
        Reduced list of influential factors.
    """
    continuous = [f for f in factors if f.factor_type == FactorType.CONTINUOUS]
    non_continuous = [f for f in factors if f.factor_type != FactorType.CONTINUOUS]

    max_importance = np.zeros(len(continuous))
    for arr in importance.values():
        max_importance = np.maximum(max_importance, arr)

    kept = [
        f for f, imp in zip(continuous, max_importance, strict=True) if imp >= threshold
    ]
    return non_continuous + kept

trade_study.screen(run_fn, factors, *, method='morris', n_trajectories=100, seed=42)

Screen factors for influence on observables via SALib.

Parameters:

Name Type Description Default
run_fn Callable[[dict[str, Any]], dict[str, float]]

Callable that takes a config dict and returns a dict of observable name → scalar score.

required
factors list[Factor]

List of factors to screen; only continuous factors are used, and any others are ignored.

required
method str

Screening method ("morris" or "sobol").

'morris'
n_trajectories int

Number of Morris trajectories. For Sobol, this controls the base sample size N; the total number of model evaluations is N x (num_vars + 2).

100
seed int

Random seed.

42

Returns:

Type Description
dict[str, NDArray[floating[Any]]]

Dictionary mapping observable names to arrays of factor importance (mu_star for Morris, S1 for Sobol), one value per factor.

Raises:

Type Description
ValueError

If method is unknown or no continuous factors are provided.

Source code in src/trade_study/design.py
def screen(
    run_fn: Callable[[dict[str, Any]], dict[str, float]],
    factors: list[Factor],
    *,
    method: str = "morris",
    n_trajectories: int = 100,
    seed: int = 42,
) -> dict[str, NDArray[np.floating[Any]]]:
    """Rank continuous factors by their influence on each observable via SALib.

    Only factors whose type is ``FactorType.CONTINUOUS`` take part; any
    other factors in *factors* are skipped. The selected factors' names and
    bounds are packed into a SALib problem definition and passed to the
    chosen screening backend.

    Args:
        run_fn: Model callable mapping a config dict to a dict of
            observable name → scalar score.
        factors: Factors to consider; only the continuous ones are screened.
        method: Screening backend, ``"morris"`` or ``"sobol"``.
        n_trajectories: Number of Morris trajectories. For Sobol it is the
            base sample size *N*, for a total of *N* x (num_vars + 2) model
            evaluations.
        seed: Random seed forwarded to the backend.

    Returns:
        Mapping from observable name to an array holding one importance
        value per continuous factor (mu_star for Morris, S1 for Sobol).

    Raises:
        ValueError: If no continuous factor is supplied, or *method* is not
            recognised.
    """
    screened = [
        factor for factor in factors if factor.factor_type == FactorType.CONTINUOUS
    ]
    if not screened:
        msg = "Screening requires at least one continuous factor"
        raise ValueError(msg)

    # The ``bounds is not None`` filter only narrows the type for checkers;
    # Factor validation already rejects continuous factors without bounds.
    names = [factor.name for factor in screened]
    bounds = [
        list(factor.bounds) for factor in screened if factor.bounds is not None
    ]
    problem: dict[str, Any] = {
        "num_vars": len(screened),
        "names": names,
        "bounds": bounds,
    }

    if method == "morris":
        backend = _screen_morris
    elif method == "sobol":
        backend = _screen_sobol
    else:
        msg = f"Unknown screening method: {method!r}"
        raise ValueError(msg)
    return backend(run_fn, problem, n_trajectories, seed)