Skip to content

Visualization

Plotting utilities for trade-study results.

trade_study.plot_front(results, directions, *, ax=None, front_kw=None, dominated_kw=None)

Plot a Pareto front from a results table.

For two objectives, draws a 2-D scatter. For three objectives, draws a 3-D scatter. For four or more, draws a pairwise scatter matrix of the first three objectives.

Parameters:

Name Type Description Default
results ResultsTable

Scored results from a study phase.

required
directions list[Direction]

Optimization direction per observable.

required
ax Axes | None

Optional axes to draw on (only used for 2-D case).

None
front_kw dict[str, Any] | None

Extra keyword arguments for front-point scatter.

None
dominated_kw dict[str, Any] | None

Extra keyword arguments for dominated-point scatter.

None

Returns:

Type Description
tuple[Figure, Axes | ndarray[Any, dtype[Any]]]

Tuple of (Figure, Axes). For the pairwise matrix case the second element is an ndarray of Axes.

Raises:

Type Description
ValueError

If fewer than two objectives are present.

Source code in src/trade_study/viz.py
def plot_front(
    results: ResultsTable,
    directions: list[Direction],
    *,
    ax: Axes | None = None,
    front_kw: dict[str, Any] | None = None,
    dominated_kw: dict[str, Any] | None = None,
) -> tuple[Figure, Axes | np.ndarray[Any, np.dtype[Any]]]:
    """Plot a Pareto front from a results table.

    The kind of plot depends on how many objectives the results carry:
    two objectives give a 2-D scatter, three give a 3-D scatter, and four
    or more give a pairwise scatter matrix of the first three objectives.
    Front points and dominated points are drawn as two separate series.

    Args:
        results: Scored results from a study phase.
        directions: Optimization direction per observable.
        ax: Optional axes to draw on.  Only forwarded in the 2-D case;
            ignored when there are three or more objectives.
        front_kw: Extra keyword arguments for front-point scatter.  They
            override the defaults ``s=40``, ``zorder=3`` and
            ``label="Pareto front"``.
        dominated_kw: Extra keyword arguments for dominated-point scatter.
            They override the defaults ``s=20``, ``alpha=0.35``,
            ``color="0.6"``, ``zorder=2`` and ``label="Dominated"``.

    Returns:
        Tuple of (Figure, Axes).  For the pairwise matrix case the
        second element is an ndarray of Axes.

    Raises:
        ValueError: If fewer than two objectives are present.
    """
    _require_matplotlib()

    from ._pareto import extract_front

    scores = results.scores
    n_objectives = scores.shape[1]
    if n_objectives < 2:
        msg = "plot_front requires at least 2 objectives"
        raise ValueError(msg)

    # Boolean row mask: True where the design lies on the non-dominated front.
    on_front = np.zeros(len(scores), dtype=bool)
    on_front[extract_front(scores, directions)] = True

    # Caller-supplied keyword arguments win over the defaults.
    front_style: dict[str, Any] = {
        "s": 40,
        "zorder": 3,
        "label": "Pareto front",
        **(front_kw or {}),
    }
    dominated_style: dict[str, Any] = {
        "s": 20,
        "alpha": 0.35,
        "color": "0.6",
        "zorder": 2,
        "label": "Dominated",
        **(dominated_kw or {}),
    }

    labels = results.observable_names

    if n_objectives == 2:
        return _plot_front_2d(
            scores, on_front, labels, ax, front_style, dominated_style
        )
    if n_objectives == 3:
        return _plot_front_3d(scores, on_front, labels, front_style, dominated_style)
    return _plot_front_pairs(scores, on_front, labels, front_style, dominated_style)

trade_study.plot_parallel(results, directions, *, ax=None, cmap='viridis')

Parallel coordinates plot colored by Pareto rank.

Each vertical axis represents one observable, normalized to [0, 1] with the "better" end pointing up. Lines are colored by Pareto rank (0 = front, darker = better).

Parameters:

Name Type Description Default
results ResultsTable

Scored results from a study phase.

required
directions list[Direction]

Optimization direction per observable.

required
ax Axes | None

Optional axes to draw on.

None
cmap str

Matplotlib colormap name for Pareto-rank coloring.

'viridis'

Returns:

Type Description
tuple[Figure, Axes]

Tuple of (Figure, Axes).

Source code in src/trade_study/viz.py
def plot_parallel(
    results: ResultsTable,
    directions: list[Direction],
    *,
    ax: Axes | None = None,
    cmap: str = "viridis",
) -> tuple[Figure, Axes]:
    """Parallel coordinates plot colored by Pareto rank.

    Every observable gets its own vertical axis, normalized to [0, 1]
    with the "better" end pointing up.  Each design is drawn as one
    polyline, colored by its Pareto rank (0 = front, darker = better).
    A colorbar labelled "Pareto rank" is attached, flipped so that
    rank 0 sits at the top.

    Args:
        results: Scored results from a study phase.
        directions: Optimization direction per observable.
        ax: Optional axes to draw on.  When omitted, a new figure is
            created whose width grows with the number of directions.
        cmap: Matplotlib colormap name for Pareto-rank coloring.

    Returns:
        Tuple of (Figure, Axes).
    """
    _require_matplotlib()
    import matplotlib.pyplot as plt
    from matplotlib.collections import LineCollection
    from matplotlib.colors import Normalize as MplNormalize

    from ._pareto import pareto_rank

    if ax is None:
        width = max(6, len(directions) * 1.5)
        fig, ax = plt.subplots(figsize=(width, 5))
    else:
        fig = ax.get_figure()  # type: ignore[assignment]

    scores = results.scores
    n_axes = scores.shape[1]
    normalized = _normalize_parallel(scores, directions)
    rank_of = pareto_rank(scores, directions)
    # Floor at 1 so the colorbar norm never collapses to a zero-width range
    # (e.g. when every design is on the front and all ranks are 0).
    top_rank = max(int(rank_of.max()), 1)

    colormap = plt.get_cmap(cmap)
    line_segments, line_colors = _build_parallel_lines(normalized, rank_of, colormap)
    ax.add_collection(
        LineCollection(line_segments, colors=line_colors, linewidths=1.2, alpha=0.7)
    )

    # Observables sit at integer x positions 0 .. n_axes - 1.
    ax.set_xlim(-0.1, n_axes - 0.9)
    ax.set_ylim(-0.05, 1.05)
    ax.set_xticks(np.arange(n_axes))
    ax.set_xticklabels(results.observable_names)
    ax.set_ylabel("Normalized score (↑ better)")

    # A data-less mappable drives the colorbar for the rank scale.
    rank_mappable = plt.cm.ScalarMappable(
        cmap=colormap, norm=MplNormalize(0, top_rank)
    )
    rank_mappable.set_array([])
    colorbar = fig.colorbar(rank_mappable, ax=ax, label="Pareto rank")
    # Flip so rank 0 (the front) appears at the top of the colorbar.
    colorbar.ax.invert_yaxis()

    return fig, ax

trade_study.plot_scores(results, observable, directions=None, *, ax=None)

Strip plot of one observable across all configurations.

Each dot is one trial. If directions are provided, Pareto-front designs are highlighted.

Parameters:

Name Type Description Default
results ResultsTable

Scored results from a study phase.

required
observable str

Name of the observable to plot.

required
directions list[Direction] | None

If given, highlight Pareto-front designs.

None
ax Axes | None

Optional axes to draw on.

None

Returns:

Type Description
tuple[Figure, Axes]

Tuple of (Figure, Axes).

Raises:

Type Description
ValueError

If the observable name is not found.

Source code in src/trade_study/viz.py
def plot_scores(
    results: ResultsTable,
    observable: str,
    directions: list[Direction] | None = None,
    *,
    ax: Axes | None = None,
) -> tuple[Figure, Axes]:
    """Strip plot of one observable across all configurations.

    Each trial is one dot, spread horizontally by a small random jitter.
    The jitter is seeded, so repeated calls produce the same layout.
    When *directions* is given, Pareto-front designs are drawn larger
    and on top, dominated designs are greyed out, and a legend is added.

    Args:
        results: Scored results from a study phase.
        observable: Name of the observable to plot.
        directions: If given, highlight Pareto-front designs.  The front
            is computed over all observables, not just *observable*.
        ax: Optional axes to draw on.

    Returns:
        Tuple of (Figure, Axes).

    Raises:
        ValueError: If the observable name is not found.
    """
    _require_matplotlib()
    import matplotlib.pyplot as plt

    names = results.observable_names
    try:
        column = names.index(observable)
    except ValueError:
        msg = f"Observable {observable!r} not in {names}"
        raise ValueError(msg) from None

    ys = results.scores[:, column]

    if ax is None:
        fig, ax = plt.subplots()
    else:
        fig = ax.get_figure()  # type: ignore[assignment]

    # Fixed seed: the horizontal jitter is reproducible between calls.
    rng = np.random.default_rng(0)
    xs = rng.uniform(-0.15, 0.15, size=len(ys))

    if directions is None:
        ax.scatter(xs, ys, s=20, alpha=0.6)
    else:
        from ._pareto import extract_front

        on_front = np.zeros(len(ys), dtype=bool)
        on_front[extract_front(results.scores, directions)] = True
        dominated = ~on_front

        ax.scatter(
            xs[dominated],
            ys[dominated],
            s=20,
            alpha=0.4,
            color="0.6",
            label="Dominated",
        )
        ax.scatter(
            xs[on_front],
            ys[on_front],
            s=40,
            zorder=3,
            label="Pareto front",
        )
        ax.legend()

    ax.set_ylabel(observable)
    ax.set_xticks([])  # The x position is jitter only; it carries no meaning.

    return fig, ax

trade_study.plot_calibration(nominal, empirical, *, ax=None)

Plot a calibration curve from coverage_curve() output.

Compares nominal coverage levels against empirical coverage. A well-calibrated model follows the diagonal.

Parameters:

Name Type Description Default
nominal NDArray[floating[Any]]

Nominal coverage levels, shape (n_levels,).

required
empirical NDArray[floating[Any]]

Empirical coverage values, shape (n_levels,).

required
ax Axes | None

Optional axes to draw on.

None

Returns:

Type Description
tuple[Figure, Axes]

Tuple of (Figure, Axes).

Source code in src/trade_study/viz.py
def plot_calibration(
    nominal: NDArray[np.floating[Any]],
    empirical: NDArray[np.floating[Any]],
    *,
    ax: Axes | None = None,
) -> tuple[Figure, Axes]:
    """Plot a calibration curve from ``coverage_curve()`` output.

    Draws empirical coverage against nominal coverage, together with a
    dashed identity line labelled "Ideal".  A well-calibrated model
    follows the diagonal.  Both axes are fixed to [0, 1] with equal
    aspect ratio.

    Args:
        nominal: Nominal coverage levels, shape (n_levels,).
        empirical: Empirical coverage values, shape (n_levels,).
        ax: Optional axes to draw on.

    Returns:
        Tuple of (Figure, Axes).
    """
    _require_matplotlib()
    import matplotlib.pyplot as plt

    if ax is None:
        fig, ax = plt.subplots()
    else:
        fig = ax.get_figure()  # type: ignore[assignment]

    # Reference diagonal: perfect calibration means empirical == nominal.
    ax.plot([0, 1], [0, 1], "k--", alpha=0.4, label="Ideal")
    ax.plot(nominal, empirical, "o-", markersize=3, label="Empirical")

    ax.set(
        xlabel="Nominal coverage",
        ylabel="Empirical coverage",
        xlim=(0, 1),
        ylim=(0, 1),
        aspect="equal",
    )
    ax.legend()

    return fig, ax