"""MCMC convergence diagnostics: Gelman-Rubin, Geweke, outlier-chain detection."""
from __future__ import annotations
import warnings
from dataclasses import dataclass
from typing import Iterable, Optional, Sequence, Tuple
import numpy as np
from ._chains import ChainSet
from ._grubbs import iterative_grubbs
# ---------------------------------------------------------------------------
# Gelman-Rubin
# ---------------------------------------------------------------------------
@dataclass(frozen=True)
class RhatResult:
    """Immutable container for the output of :func:`gelman_rubin`.

    Attributes
    ----------
    steps:
        1-D array of length ``n_eval`` holding the truncation step counts
        at which the diagnostics were evaluated.  An entry ``s`` means
        "the first ``s`` rows of every chain were used".  Counts are
        1-based; on the automatically generated grid the smallest entry
        is the chosen ``min_steps``.
    Rhat_c:
        ``(n_eval, n_params)`` array of Brooks-Gelman corrected
        potential-scale reduction factors.
    R_interval:
        ``(n_eval, n_params)`` array of non-parametric interval-length
        ratios (credible-interval length of the pooled chains divided by
        the mean per-chain credible-interval length), evaluated at
        ``alpha_interval``.
    parameter_names:
        Parameter names, in the order of ``axis=1`` of the arrays above.
    alpha_interval:
        Significance level that was used for ``R_interval``.
    chains_used:
        ``chain_index`` of every chain that contributed, i.e. those left
        after ``drop_chains`` was applied.

    Methods
    -------
    Rhat_c_max:
        Largest :attr:`Rhat_c` over the parameters at each step; suitable
        as the input of :func:`convergence_step`.
    """

    steps: np.ndarray
    Rhat_c: np.ndarray
    R_interval: np.ndarray
    parameter_names: Tuple[str, ...]
    alpha_interval: float
    chains_used: Tuple[int, ...]

    def Rhat_c_max(self) -> np.ndarray:
        """Return the ``(n_eval,)`` per-step maximum of :attr:`Rhat_c`."""
        per_parameter = self.Rhat_c
        # Collapse the parameter axis, keeping one value per evaluation step.
        return per_parameter.max(axis=1)
def gelman_rubin(
    chains: ChainSet,
    *,
    drop_chains: Sequence[int] = (),
    step_grid: Optional[Sequence[int]] = None,
    n_grid: int = 200,
    min_steps: int = 10,
    alpha_interval: float = 0.15,
) -> RhatResult:
    """Evaluate the Brooks-Gelman corrected Rhat along the simulation.

    At every truncation point ``s`` the first ``s`` rows of each retained
    chain are passed to the helpers that compute the corrected
    potential-scale reduction factor :math:`\\hat{R}_c` and the
    non-parametric interval-length ratio :math:`R_{\\rm interval}`
    (Brooks & Gelman 1998).

    Parameters
    ----------
    chains:
        :class:`ChainSet` to analyse.  At least two chains must remain
        after ``drop_chains`` is applied, and the shortest of them must
        have at least ``min_steps`` rows.
    drop_chains:
        ``chain_index`` values to leave out, e.g. the indices reported by
        :func:`outlier_chains`.
    step_grid:
        Explicit truncation step counts (1-based).  When supplied,
        ``n_grid`` is not used to build the grid; every entry must lie
        between 2 and the shortest retained chain length.  ``min_steps``
        is still validated as described below.
    n_grid:
        Number of evenly spaced evaluation points for the automatic grid,
        capped at ``shortest chain length - min_steps + 1``.  Points that
        coincide after integer truncation are merged.
    min_steps:
        Smallest truncation step count of the automatic grid.  Must be
        ``>= 2``.
    alpha_interval:
        Two-sided significance level for ``R_interval``; the quantiles
        ``alpha_interval / 2`` and ``1 - alpha_interval / 2`` are used
        (default 0.15, i.e. 85 % intervals).

    Returns
    -------
    RhatResult

    Raises
    ------
    ValueError
        If ``min_steps < 2``, fewer than two chains are retained, the
        shortest retained chain has fewer than ``min_steps`` rows, or
        ``step_grid`` holds an entry below 2 or above the shortest
        retained chain length.
    """
    if min_steps < 2:
        raise ValueError(f"min_steps must be >= 2; got {min_steps}")

    excluded = {int(index) for index in drop_chains}
    survivors = [chain for chain in chains if chain.chain_index not in excluded]
    n_survivors = len(survivors)
    if n_survivors < 2:
        raise ValueError(
            f"gelman_rubin requires at least 2 chains; got {n_survivors} "
            f"after dropping {sorted(excluded)!r}."
        )

    shortest = min(chain.n_steps for chain in survivors)
    if shortest < min_steps:
        raise ValueError(
            f"Shortest surviving chain has {shortest} steps; min_steps={min_steps}."
        )

    if step_grid is not None:
        steps = np.asarray(list(step_grid), dtype=int)
        if (steps < 2).any():
            raise ValueError("step_grid entries must all be >= 2.")
        if (steps > shortest).any():
            raise ValueError(
                f"step_grid contains values exceeding the shortest chain "
                f"length ({shortest})."
            )
    else:
        n_points = min(int(n_grid), shortest - min_steps + 1)
        # Integer truncation can map neighbouring points onto the same
        # step, so duplicates are removed.
        steps = np.unique(np.linspace(min_steps, shortest, n_points, dtype=int))

    width = chains.n_params
    # Common-length stack of shape (n_chains, shortest, n_params); a plain
    # slice along axis 1 then yields the truncated chains for any step.
    truncated = np.stack([chain.state[:shortest] for chain in survivors], axis=0)

    out_shape = (steps.size, width)
    Rhat_c = np.empty(out_shape, dtype=float)
    R_interval = np.empty(out_shape, dtype=float)

    half_alpha = alpha_interval / 2.0
    lower_q = half_alpha
    upper_q = 1.0 - half_alpha

    for row, n_rows in enumerate(steps):
        window = truncated[:, :n_rows, :]
        Rhat_c[row] = _brooks_gelman_corrected(window)
        R_interval[row] = _interval_ratio(window, lower_q, upper_q)

    return RhatResult(
        steps=steps,
        Rhat_c=Rhat_c,
        R_interval=R_interval,
        parameter_names=chains.config.parameter_names,
        alpha_interval=alpha_interval,
        chains_used=tuple(chain.chain_index for chain in survivors),
    )
def _brooks_gelman_corrected(sub: np.ndarray) -> np.ndarray:
    """Brooks-Gelman corrected Rhat for *sub* of shape ``(m, n, n_params)``.

    Every parameter is treated independently.  With ``B`` the between-chain
    variance (``n`` times the sample variance of the chain means) and ``W``
    the mean within-chain variance, the pooled variance estimate is::

        Vhat = (n - 1) / n * W + (m + 1) / (m * n) * B

    and the returned value is::

        sqrt((d + 3) / (d + 1) * Vhat / W),   d = 2 * Vhat**2 / Var(Vhat)

    i.e. the degrees-of-freedom correction is applied to the variance
    ratio before the square root is taken (the convention of Brooks &
    Gelman 1998 and of ``coda::gelman.diag``).

    Parameters
    ----------
    sub:
        ``(m, n, n_params)`` array: ``m`` chains of ``n`` rows each.  The
        sample variances use ``ddof=1`` along both the chain and the row
        axis, so ``m >= 2`` and ``n >= 2`` are needed for finite output.

    Returns
    -------
    np.ndarray
        ``(n_params,)`` array.  ``NaN`` marks parameters whose mean
        within-chain variance ``W`` is not positive (e.g. constant chains).
    """
    m, n, _ = sub.shape
    chain_means = sub.mean(axis=1)          # (m, n_params)
    chain_vars = sub.var(axis=1, ddof=1)    # (m, n_params), s_j^2
    grand_mean = chain_means.mean(axis=0)   # (n_params,)

    B = n * chain_means.var(axis=0, ddof=1)
    W = chain_vars.mean(axis=0)

    # Vhat = w_coef * W + b_coef * B.
    w_coef = (n - 1) / n
    b_coef = (m + 1) / (m * n)
    Vhat = w_coef * W + b_coef * B

    # Var(Vhat) = w_coef^2 Var(W) + b_coef^2 Var(B) + 2 w_coef b_coef Cov(W, B).
    var_W = chain_vars.var(axis=0, ddof=1) / m
    var_B = 2.0 * B ** 2 / (m - 1)
    # Cov(W, B) = (n / m) * [cov(s^2, xbar^2) - 2 * mu * cov(s^2, xbar)].
    # The bracket equals cov(s^2, (xbar - mu)^2), which is evaluated
    # directly here because it avoids subtracting two large terms.
    sq_dev = (chain_means - grand_mean) ** 2
    cov_s2_dev = (
        (chain_vars - W) * (sq_dev - sq_dev.mean(axis=0))
    ).sum(axis=0) / (m - 1)
    cov_WB = (n / m) * cov_s2_dev
    var_Vhat = (
        w_coef ** 2 * var_W
        + b_coef ** 2 * var_B
        + 2.0 * w_coef * b_coef * cov_WB
    )

    # d grows without bound as Var(Vhat) -> 0+, so a non-positive estimate
    # is treated as d = inf (no small-sample correction) rather than d = 0,
    # which would inflate the result by a factor of three.
    usable = var_Vhat > 0
    with np.errstate(over="ignore"):
        d = np.where(
            usable,
            2.0 * Vhat ** 2 / np.where(usable, var_Vhat, 1.0),
            np.inf,
        )
    # (d + 3) / (d + 1) written so that d = inf yields exactly 1.
    correction = 1.0 + 2.0 / (d + 1.0)

    # A non-positive W (constant chains) has no meaningful ratio; report NaN
    # so the caller can detect the degenerate parameter.
    safe_W = np.where(W > 0, W, np.nan)
    return np.sqrt(correction * Vhat / safe_W)
def _sample_covariance(a: np.ndarray, b: np.ndarray) -> np.ndarray:
    """Column-wise sample covariance of two ``(m, n_params)`` arrays.

    Each column of *a* is paired with the same column of *b*; the rows are
    the observations.  The divisor is ``m - 1``.
    """
    n_rows = a.shape[0]
    centred_a = a - a.mean(axis=0)
    centred_b = b - b.mean(axis=0)
    cross_products = centred_a * centred_b
    return cross_products.sum(axis=0) / (n_rows - 1)
def _interval_ratio(sub: np.ndarray, lo_q: float, hi_q: float) -> np.ndarray:
    """Pooled credible-interval length over the mean per-chain length.

    Parameters
    ----------
    sub:
        ``(m, n, n_params)`` array.
    lo_q, hi_q:
        Lower and upper quantile probabilities in ``[0, 1]``.

    Returns
    -------
    np.ndarray
        Per-parameter ratio.  ``NaN`` where the mean per-chain interval
        length is not positive.
    """
    n_chains, n_rows, _ = sub.shape

    def interval_length(values: np.ndarray, axis: int) -> np.ndarray:
        # Distance between the two requested quantiles along *axis*.
        lower = np.quantile(values, lo_q, axis=axis)
        upper = np.quantile(values, hi_q, axis=axis)
        return upper - lower

    # All chains merged into one sample of n_chains * n_rows draws.
    pooled = sub.reshape(n_chains * n_rows, -1)
    pooled_length = interval_length(pooled, 0)

    # One interval per chain, shape (m, n_params), averaged over the chains.
    average_length = interval_length(sub, 1).mean(axis=0)

    denominator = np.where(average_length > 0, average_length, np.nan)
    return pooled_length / denominator
# ---------------------------------------------------------------------------
# Convergence step
# ---------------------------------------------------------------------------
def convergence_step(
    rhat_max: np.ndarray,
    *,
    threshold: float = 1.1,
    sustained_for: int = 1,
) -> Optional[int]:
    """Locate the first grid index at which convergence is sustained.

    The result is the smallest index ``i`` for which all of
    ``rhat_max[i : i + sustained_for]`` are at or below ``threshold``.

    Parameters
    ----------
    rhat_max:
        1-D array of (max-over-parameters) Rhat values, e.g. the output of
        :meth:`RhatResult.Rhat_c_max`.
    threshold:
        Convergence threshold; values equal to it count as converged.
        Defaults to ``1.1``.
    sustained_for:
        Number of consecutive grid points that must all satisfy the
        threshold.  Defaults to ``1`` (first crossing).

    Returns
    -------
    int or None
        Grid index where the sustained run starts, or ``None`` when no such
        run exists (including when the grid has fewer than
        ``sustained_for`` points).

    Raises
    ------
    ValueError
        If ``rhat_max`` is not 1-D or ``sustained_for < 1``.

    Notes
    -----
    The return value indexes the Rhat grid; look it up in
    :attr:`RhatResult.steps` to obtain a simulation-step count.
    """
    values = np.asarray(rhat_max, dtype=float)
    if values.ndim != 1:
        raise ValueError(f"rhat_max must be 1-D; got shape {values.shape!r}")
    if sustained_for < 1:
        raise ValueError(f"sustained_for must be >= 1; got {sustained_for}")

    n_points = values.size
    if n_points < sustained_for:
        return None

    # NaN entries compare False and therefore never count as converged.
    converged = values <= threshold
    window_starts = range(n_points - sustained_for + 1)
    return next(
        (
            start
            for start in window_starts
            if converged[start : start + sustained_for].all()
        ),
        None,
    )
# ---------------------------------------------------------------------------
# Geweke
# ---------------------------------------------------------------------------
def geweke(
    chains: ChainSet,
    *,
    first: float = 0.1,
    last: float = 0.5,
) -> np.ndarray:
    """Geweke z-scores contrasting the start and the end of every chain.

    For each chain and each parameter the score is

    .. math::

        z = \\frac{\\bar{x}_1 - \\bar{x}_2}{\\sqrt{s^2_1/n_1 + s^2_2/n_2}}

    where segment 1 is the leading ``first`` fraction of the chain and
    segment 2 the trailing ``last`` fraction.  A large ``|z|`` for any
    parameter suggests the chain has not reached a stationary
    distribution.

    Parameters
    ----------
    chains:
        :class:`ChainSet`.
    first, last:
        Segment lengths as fractions of the chain, each strictly between
        0 and 1 and with ``first + last <= 1``.  Defaults are ``0.1`` and
        ``0.5``.

    Returns
    -------
    np.ndarray
        ``(n_chains, n_params)`` array of z-scores.  A chain for which
        either segment has fewer than 2 rows yields a row of ``NaN``; a
        parameter whose combined standard error is not positive yields
        ``NaN`` as well.

    Raises
    ------
    ValueError
        If ``first`` or ``last`` is outside ``(0, 1)`` or their sum
        exceeds 1.

    Notes
    -----
    The variances are plain sample variances (``ddof=1``), which treat the
    draws as independent; autocorrelation within a chain is not accounted
    for, so correlated chains give inflated ``|z|``.
    """
    for label, fraction in (("first", first), ("last", last)):
        if not (0.0 < fraction < 1.0):
            raise ValueError(f"{label} must be in (0, 1); got {fraction}")
    if first + last > 1.0:
        raise ValueError(
            f"first + last must be <= 1; got first={first}, last={last}"
        )

    width = chains.n_params
    scores = np.full((len(chains), width), np.nan)

    for row, chain in enumerate(chains):
        total = chain.n_steps
        head_len = int(total * first)
        tail_len = int(total * last)
        if min(head_len, tail_len) < 2:
            # Too few rows for a sample variance; the row stays NaN.
            continue

        head = chain.state[:head_len]
        tail = chain.state[-tail_len:]
        mean_gap = head.mean(axis=0) - tail.mean(axis=0)
        std_err = np.sqrt(
            head.var(axis=0, ddof=1) / head_len
            + tail.var(axis=0, ddof=1) / tail_len
        )

        # Divide by a dummy 1.0 where the standard error is unusable, then
        # overwrite those entries with NaN.
        usable = std_err > 0
        scores[row] = np.where(
            usable, mean_gap / np.where(usable, std_err, 1.0), np.nan
        )

    return scores
# ---------------------------------------------------------------------------
# Outlier chains
# ---------------------------------------------------------------------------
def outlier_chains(
    chains: ChainSet,
    *,
    alpha: float = 0.05,
    max_outliers: int = 10,
    parameters: Optional[Iterable[str]] = None,
) -> Tuple[int, ...]:
    """Flag outlier chains from their final states.

    The last row of every chain is collected into one point per chain
    (optionally restricted to selected parameter columns) and the
    resulting ``(n_chains, n_columns)`` array is handed to
    :func:`iterative_grubbs` together with ``alpha`` and ``max_outliers``.
    The row indices it returns are translated to ``chain_index`` values.

    Parameters
    ----------
    chains:
        :class:`ChainSet`.  With fewer than three chains no test is run
        and an empty tuple is returned.
    alpha:
        Significance level forwarded to :func:`iterative_grubbs`.
        Defaults to ``0.05``.
    max_outliers:
        Upper limit on the number of flagged chains, forwarded to
        :func:`iterative_grubbs`.
    parameters:
        Optional parameter names restricting the test to a subset of the
        columns.

    Returns
    -------
    tuple of int
        ``chain_index`` values of the flagged chains, in the order
        reported by :func:`iterative_grubbs`.

    Raises
    ------
    KeyError
        If ``parameters`` contains a name that is not in
        ``chains.config.parameters``.
    """
    if len(chains) < 3:
        return ()

    selected = None  # None means "use every column".
    if parameters is not None:
        requested = list(parameters)
        position_of = {
            param.name: position
            for position, param in enumerate(chains.config.parameters)
        }
        selected = []
        for name in requested:
            try:
                selected.append(position_of[name])
            except KeyError:
                raise KeyError(
                    f"Unknown parameter name {name!r}; "
                    f"available: {list(position_of)!r}"
                ) from None

    last_states = np.stack([chain.state[-1] for chain in chains], axis=0)
    points = last_states if selected is None else last_states[:, selected]
    if points.ndim == 1:
        # Present a single column as a 2-D (n_chains, 1) array.
        points = points.reshape(-1, 1)

    flagged = iterative_grubbs(points, alpha=alpha, max_outliers=max_outliers)
    return tuple(int(chains[row].chain_index) for row in flagged)
# ---------------------------------------------------------------------------
# Burn-in resolution
# ---------------------------------------------------------------------------
def _resolve_post_burn(chains: ChainSet, post_burn: Optional[int]) -> int:
    """Turn a ``post_burn`` argument into a concrete row count.

    An explicit value is validated and returned as ``int``.  ``None``
    requests automatic detection: :func:`gelman_rubin` is run on *chains*
    with its default settings and :func:`convergence_step` (default
    threshold and ``sustained_for``) is applied to the per-step maximum
    Rhat.  The step count at the detected grid index is returned.

    When :func:`gelman_rubin` raises :class:`ValueError`, or no grid point
    satisfies the threshold, a :class:`UserWarning` is issued and ``0`` is
    returned so the caller can continue with the full chain.

    Returns
    -------
    int
        Number of leading rows to drop in each chain.

    Raises
    ------
    ValueError
        If an explicit ``post_burn`` is negative.
    """
    explicit = post_burn is not None
    if explicit and post_burn < 0:
        raise ValueError(f"post_burn must be non-negative; got {post_burn!r}")
    if explicit:
        return int(post_burn)

    try:
        diagnostics = gelman_rubin(chains)
    except ValueError as exc:
        warnings.warn(
            f"Auto burn-in detection failed ({exc}); using post_burn=0.",
            UserWarning,
            stacklevel=3,
        )
        return 0

    grid_index = convergence_step(diagnostics.Rhat_c_max())
    if grid_index is not None:
        return int(diagnostics.steps[grid_index])

    warnings.warn(
        "Auto burn-in detection did not find convergence on the default "
        "grid; using post_burn=0. Pass an explicit post_burn= to silence "
        "this warning.",
        UserWarning,
        stacklevel=3,
    )
    return 0