Source code for dendros._mcmc._plots

"""Corner-plot wrapper around the ``corner`` package."""
from __future__ import annotations

from typing import Iterable, Optional, Sequence

import numpy as np

from ._chains import ChainSet


def _import_corner():
    try:
        import corner  # noqa: F401
    except ImportError as exc:
        raise ImportError(
            "corner_plot requires the `corner` package.  Install it with: "
            "pip install 'dendros[mcmc]'  (which also installs matplotlib)."
        ) from exc
    return corner


def corner_plot(
    chains: ChainSet,
    *,
    parameters: Optional[Iterable[str]] = None,
    post_burn: Optional[int] = None,
    drop_chains: Sequence[int] = (),
    labels: Optional[Sequence[str]] = None,
    **corner_kwargs,
):
    """Draw a corner plot from the pooled post-burn samples of *chains*.

    Every retained chain contributes its rows from the burn-in cut-off
    onwards; the rows are stacked into one sample array, reduced to the
    requested parameter columns, and handed to :func:`corner.corner`.

    Parameters
    ----------
    chains:
        :class:`ChainSet` supplying the chains and, through ``chains.config``,
        the parameter definitions.
    parameters:
        Names of the parameters to plot, in the order they should appear.
        ``None`` selects every parameter in ``chains.config.parameters``.
    post_burn:
        Number of leading rows to discard from each chain.  The value is
        resolved by ``_resolve_post_burn``; ``None`` requests automatic
        detection (via :func:`gelman_rubin` / :func:`convergence_step`).
    drop_chains:
        ``chain_index`` values of chains to leave out of the pool.
    labels:
        Axis labels, one per plotted parameter.  ``None`` builds them from
        each parameter's :attr:`ModelParameter.display_label`, wrapped in
        ``$...$`` so that they are rendered in math mode.
    **corner_kwargs:
        Extra keyword arguments passed straight through to
        :func:`corner.corner`.

    Returns
    -------
    matplotlib.figure.Figure
        Whatever :func:`corner.corner` returns for the pooled samples.

    Raises
    ------
    ImportError
        If the optional ``corner`` package is not installed.  Install via
        ``pip install 'dendros[mcmc]'``.
    KeyError
        If a name in *parameters* does not match any configured parameter.
    ValueError
        If no chain contributes post-burn samples, or if *labels* is given
        with a length different from the number of plotted parameters.
    """
    corner = _import_corner()
    from ._convergence import _resolve_post_burn

    burn = _resolve_post_burn(chains, post_burn)
    excluded = {int(idx) for idx in drop_chains}
    model_params = chains.config.parameters

    # Translate the requested names into column positions of the state array.
    if parameters is None:
        selected = list(range(len(model_params)))
    else:
        column_of = {param.name: pos for pos, param in enumerate(model_params)}
        selected = []
        for wanted in parameters:
            position = column_of.get(wanted)
            if position is None:
                raise KeyError(
                    f"Unknown parameter name {wanted!r}; "
                    f"available: {list(column_of)!r}"
                )
            selected.append(position)

    # Pool the post-burn rows of every chain that is neither dropped nor
    # too short to have any rows left after the cut-off.
    pooled = []
    for chain in chains:
        if chain.chain_index in excluded or chain.n_steps <= burn:
            continue
        pooled.append(chain.state[burn:])
    if not pooled:
        raise ValueError(
            "No post-burn samples available — every chain was dropped or "
            "shorter than post_burn."
        )

    stacked = np.concatenate(pooled, axis=0)
    samples = stacked[:, selected]

    if labels is not None:
        labels = list(labels)
        if len(labels) != len(selected):
            raise ValueError(
                f"labels has length {len(labels)}; expected {len(selected)} "
                "(one per chosen parameter)."
            )
    else:
        labels = [f"${model_params[col].display_label}$" for col in selected]

    return corner.corner(samples, labels=labels, **corner_kwargs)