Source code for dendros._mcmc._analysis

"""Per-chain analyses that operate on post-burn samples."""
from __future__ import annotations

from dataclasses import dataclass
from typing import List, Optional, Sequence, Tuple, Union

import numpy as np

from ._chains import ChainSet


def acceptance_rate(
    chains: ChainSet,
    *,
    post_burn: Optional[int] = None,
) -> np.ndarray:
    """Fraction of accepted transitions in each chain after burn-in.

    A transition is counted as accepted when *any* parameter component
    differs from the row before it. Galacticus emits the same row when a
    proposal is rejected, so comparing consecutive rows is the canonical
    acceptance count.

    Parameters
    ----------
    chains:
        :class:`ChainSet`.
    post_burn:
        Number of leading rows to skip in each chain. ``None`` triggers
        automatic detection via :func:`gelman_rubin` /
        :func:`convergence_step`; if convergence is not reached a
        :class:`UserWarning` is emitted and ``0`` is used.

    Returns
    -------
    np.ndarray
        ``(n_chains,)`` array. ``NaN`` for any chain with fewer than two
        post-burn rows.
    """
    from ._convergence import _resolve_post_burn

    n_skip = _resolve_post_burn(chains, post_burn)
    # Start from NaN so chains that are too short keep the "undefined" marker.
    rates = np.full(len(chains), np.nan, dtype=float)
    for idx, chain in enumerate(chains):
        rows = chain.state[n_skip:]
        if rows.shape[0] >= 2:
            # One boolean per transition: did any component change?
            changed = (rows[1:] != rows[:-1]).any(axis=1)
            rates[idx] = changed.mean()
    return rates
def acceptance_rate_trace(
    chains: ChainSet,
    *,
    window: int = 30,
    post_burn: int = 0,
) -> List[np.ndarray]:
    """Running acceptance rate along each chain.

    Entry ``i`` of a chain's trace is the fraction of the *window*
    transitions ending at row ``i`` that were accepted (i.e. changed at
    least one parameter). The first ``window`` entries are
    :data:`numpy.nan` because the window is not yet full there.

    Parameters
    ----------
    chains:
        :class:`ChainSet`.
    window:
        Sliding-window length in steps. Defaults to ``30`` to match the
        Galacticus Perl reference.
    post_burn:
        Number of leading rows to skip in each chain before evaluation.

    Returns
    -------
    list of np.ndarray
        One 1-D array per chain, each of length ``n_steps_post_burn``.
        A list is used because chains may have different post-burn
        lengths.

    Raises
    ------
    ValueError
        If *window* is smaller than 1 or *post_burn* is negative.
    """
    if window < 1:
        raise ValueError(f"window must be >= 1; got {window!r}")
    if post_burn < 0:
        raise ValueError(f"post_burn must be non-negative; got {post_burn!r}")

    traces: List[np.ndarray] = []
    for chain in chains:
        rows = chain.state[post_burn:]
        n_rows = rows.shape[0]
        rate = np.full(n_rows, np.nan, dtype=float)
        if n_rows > 1:
            # accepted[j] == 1 iff the move from row j to row j + 1 changed
            # at least one parameter.
            accepted = (rows[1:] != rows[:-1]).any(axis=1).astype(int)
            # running[k] is the number of accepted moves among the first k
            # transitions, so a window sum is a difference of two entries.
            running = np.concatenate(([0], np.cumsum(accepted)))
            if n_rows > window:
                rate[window:] = (running[window:] - running[:-window]) / window
        traces.append(rate)
    return traces
# ---------------------------------------------------------------------------
# Maximum posterior / likelihood
# ---------------------------------------------------------------------------
@dataclass(frozen=True)
class MaxResult:
    """Immutable record describing the best row found across chains.

    Returned by :func:`maximum_posterior` and :func:`maximum_likelihood`.

    Attributes
    ----------
    state:
        ``(n_params,)`` parameter vector at the maximizing step.
    log_posterior:
        Log posterior evaluated at that step.
    log_likelihood:
        Log likelihood evaluated at that step.
    chain_index:
        ``chain_index`` of the chain that holds the maximum.
    step:
        Simulation-step value of the maximizing row (1-based, matching
        :attr:`Chain.step`).
    parameter_names:
        Names of the entries of ``state``, in order.
    """

    state: np.ndarray
    log_posterior: float
    log_likelihood: float
    chain_index: int
    step: int
    parameter_names: Tuple[str, ...]
def _argmax_across_chains(
    chains: ChainSet,
    *,
    drop_chains: Sequence[int],
    field: str,
) -> Tuple[int, int]:
    """Return ``(chain_position, row)`` of the global maximum of *field*.

    Every chain whose ``chain_index`` is not listed in *drop_chains* is
    searched, and the location of the single largest value is returned.
    *chain_position* is the index into ``chains`` (not the chain's own
    ``chain_index``). Ties are resolved in favour of the earliest chain
    and, within a chain, the earliest row.

    NaN entries are ignored. ``np.argmax`` would otherwise return the
    position of the NaN, and the NaN would then lose every comparison,
    hiding the chain's real maximum. Values of ``-inf`` are legitimate
    candidates, so a set of chains holding only ``-inf`` still yields a
    result.

    Parameters
    ----------
    chains:
        :class:`ChainSet` to search.
    drop_chains:
        ``chain_index`` values to exclude from the search.
    field:
        Name of the per-chain 1-D array attribute to maximize (for
        example ``"log_posterior"``).

    Returns
    -------
    tuple of int
        ``(chain_position, row)``.

    Raises
    ------
    ValueError
        If no chain contributes a candidate, i.e. every chain was
        dropped, empty, or contained only NaN values.
    """
    dropped = {int(i) for i in drop_chains}
    # ``None`` (rather than -inf) marks "nothing seen yet" so that a
    # genuine -inf value can still be selected.
    best_val: Optional[float] = None
    best_chain_pos = -1
    best_row = -1
    for pos, chain in enumerate(chains):
        if chain.chain_index in dropped:
            continue
        values = np.asarray(getattr(chain, field))
        # Restrict the search to non-NaN rows, then map the winner back
        # to its index in the full array.
        candidates = np.flatnonzero(~np.isnan(values))
        if candidates.size == 0:
            continue
        row = int(candidates[np.argmax(values[candidates])])
        val = float(values[row])
        if best_val is None or val > best_val:
            best_val = val
            best_chain_pos = pos
            best_row = row
    if best_val is None:
        raise ValueError(
            "No chains contributed a maximum: every chain was either dropped, "
            "empty, or contained only NaN values."
        )
    return best_chain_pos, best_row
def maximum_posterior(
    chains: ChainSet,
    *,
    drop_chains: Sequence[int] = (),
) -> MaxResult:
    """Locate the row with the highest log posterior among surviving chains.

    Parameters
    ----------
    chains:
        :class:`ChainSet`.
    drop_chains:
        ``chain_index`` values to exclude from the search (e.g. the
        result of :func:`outlier_chains`).

    Returns
    -------
    MaxResult
        The state, log posterior, log likelihood, chain index and step
        of the maximizing row, together with the parameter names.
    """
    chain_pos, row_idx = _argmax_across_chains(
        chains, drop_chains=drop_chains, field="log_posterior"
    )
    winner = chains[chain_pos]
    # Copy so the result does not alias the chain's own storage.
    best_state = winner.state[row_idx].copy()
    return MaxResult(
        state=best_state,
        log_posterior=float(winner.log_posterior[row_idx]),
        log_likelihood=float(winner.log_likelihood[row_idx]),
        chain_index=int(winner.chain_index),
        step=int(winner.step[row_idx]),
        parameter_names=chains.config.parameter_names,
    )
def maximum_likelihood(
    chains: ChainSet,
    *,
    drop_chains: Sequence[int] = (),
) -> MaxResult:
    """Locate the row with the highest log likelihood among surviving chains.

    Parameters
    ----------
    chains:
        :class:`ChainSet`.
    drop_chains:
        ``chain_index`` values to exclude from the search.

    Returns
    -------
    MaxResult
        The state, log posterior, log likelihood, chain index and step
        of the maximizing row, together with the parameter names.
    """
    chain_pos, row_idx = _argmax_across_chains(
        chains, drop_chains=drop_chains, field="log_likelihood"
    )
    winner = chains[chain_pos]
    # Copy so the result does not alias the chain's own storage.
    best_state = winner.state[row_idx].copy()
    return MaxResult(
        state=best_state,
        log_posterior=float(winner.log_posterior[row_idx]),
        log_likelihood=float(winner.log_likelihood[row_idx]),
        chain_index=int(winner.chain_index),
        step=int(winner.step[row_idx]),
        parameter_names=chains.config.parameter_names,
    )
# ---------------------------------------------------------------------------
# Posterior sampling
# ---------------------------------------------------------------------------
@dataclass(frozen=True)
class PosteriorSamples:
    """Immutable bundle of rows drawn from the pooled post-burn chains.

    Attributes
    ----------
    state:
        ``(n_samples, n_params)`` parameter values.
    log_posterior:
        ``(n_samples,)`` log posterior of each sampled row.
    log_likelihood:
        ``(n_samples,)`` log likelihood of each sampled row.
    chain_index:
        ``(n_samples,)`` index of the chain each row came from.
    step:
        ``(n_samples,)`` simulation-step value of each row in its
        source chain.
    parameter_names:
        Parameter names along ``state.shape[1]``.

    Notes
    -----
    Adjacent steps in an MCMC chain are correlated, so *N* draws here
    represent significantly fewer independent samples from the
    posterior. Use :func:`effective_sample_size` to estimate the
    equivalent count, and thin by the integrated autocorrelation time
    if independence is required.
    """

    state: np.ndarray
    log_posterior: np.ndarray
    log_likelihood: np.ndarray
    chain_index: np.ndarray
    step: np.ndarray
    parameter_names: Tuple[str, ...]
def posterior_samples(
    chains: ChainSet,
    n: int,
    *,
    post_burn: Optional[int] = None,
    drop_chains: Sequence[int] = (),
    rng: Optional[np.random.Generator] = None,
    replace: Optional[bool] = None,
) -> PosteriorSamples:
    """Pick *n* rows uniformly at random from the pooled post-burn chains.

    The post-burn rows of every surviving chain are concatenated into a
    single pool, and row indices are drawn from that pool.

    Parameters
    ----------
    chains:
        :class:`ChainSet`.
    n:
        Number of samples to draw. Must be positive.
    post_burn:
        Number of leading rows to skip in each chain. ``None`` triggers
        automatic detection via :func:`gelman_rubin` /
        :func:`convergence_step`; if convergence is not reached a
        :class:`UserWarning` is emitted and ``0`` is used.
    drop_chains:
        Iterable of ``chain_index`` values to exclude.
    rng:
        :class:`numpy.random.Generator`. Defaults to
        :func:`numpy.random.default_rng()`, which seeds from system
        entropy. Pass an explicit generator for reproducibility.
    replace:
        Whether to sample with replacement. ``None`` (default) means
        "with replacement only when *n* exceeds the available pool",
        matching the common case where a small ``n`` is desired and
        identical rows would be misleading.

    Returns
    -------
    PosteriorSamples

    Raises
    ------
    ValueError
        If *n* is non-positive, or if all chains are dropped, or if
        ``replace=False`` and *n* exceeds the pool size.
    """
    from ._convergence import _resolve_post_burn

    if n <= 0:
        raise ValueError(f"n must be positive; got {n!r}")
    generator = np.random.default_rng() if rng is None else rng

    n_skip = _resolve_post_burn(chains, post_burn)
    excluded = {int(idx) for idx in drop_chains}

    # A chain contributes only if it is not excluded and still has rows
    # left once the burn-in prefix is removed.
    contributors = [
        chain
        for chain in chains
        if chain.chain_index not in excluded and chain.n_steps > n_skip
    ]
    if not contributors:
        raise ValueError(
            "All chains dropped or shorter than post_burn; nothing to sample."
        )

    all_states = np.concatenate(
        [chain.state[n_skip:] for chain in contributors], axis=0
    )
    all_logp = np.concatenate(
        [chain.log_posterior[n_skip:] for chain in contributors], axis=0
    )
    all_logl = np.concatenate(
        [chain.log_likelihood[n_skip:] for chain in contributors], axis=0
    )
    # Tag every pooled row with the index of the chain it came from.
    all_chain_ids = np.concatenate(
        [
            np.full(chain.n_steps - n_skip, chain.chain_index, dtype=np.int64)
            for chain in contributors
        ],
        axis=0,
    )
    all_steps = np.concatenate(
        [chain.step[n_skip:] for chain in contributors], axis=0
    )
    available = all_states.shape[0]

    # With no explicit choice, fall back to replacement only when the pool
    # is too small to supply n distinct rows.
    with_replacement = n > available if replace is None else replace
    if not with_replacement and n > available:
        raise ValueError(
            f"Requested n={n} samples without replacement, but only "
            f"{available} post-burn rows are available."
        )

    picks = generator.choice(available, size=n, replace=with_replacement)
    return PosteriorSamples(
        state=all_states[picks].copy(),
        log_posterior=all_logp[picks].copy(),
        log_likelihood=all_logl[picks].copy(),
        chain_index=all_chain_ids[picks].copy(),
        step=all_steps[picks].copy(),
        parameter_names=chains.config.parameter_names,
    )