"""Multivariate-normal fit to post-burn chains and the reparameterization config writer."""
from __future__ import annotations
from dataclasses import dataclass
from pathlib import Path
from typing import Optional, Sequence, Tuple, Union
from xml.etree import ElementTree as ET
import numpy as np
from ._chains import ChainSet
# ---------------------------------------------------------------------------
# Result dataclass
# ---------------------------------------------------------------------------
[docs]
@dataclass(frozen=True)
class MVNFit:
    """Multivariate-normal fit to post-burn samples.

    Attributes
    ----------
    mean:
        ``(n_params,)`` sample mean.
    covariance:
        ``(n_params, n_params)`` sample covariance, symmetrised.
    cholesky:
        ``(n_params, n_params)`` lower-triangular Cholesky factor of
        :attr:`covariance`. Satisfies ``L @ L.T == covariance``.
    parameter_names:
        Parameter names along the axes.

    Methods
    -------
    write_reparameterization_config:
        Emit a Galacticus-style XML config that re-parameterizes the active
        parameters in terms of independent unit-normal meta parameters.
    """

    mean: np.ndarray
    covariance: np.ndarray
    cholesky: np.ndarray
    parameter_names: Tuple[str, ...]

    def write_reparameterization_config(
        self,
        out_path: Union[str, Path],
        *,
        n_sigma: float = 5.0,
        perturber_scale: float = 1.0e-5,
    ) -> Path:
        """Write a Galacticus reparameterization XML config.

        For an *n*-parameter MVN fit with mean :math:`\\mu` and Cholesky
        factor :math:`L`, the emitted config declares *n* active
        ``metaParameter{i}`` parameters with truncated unit-normal priors
        (limits :math:`\\pm n_\\sigma`), and *n* derived parameters expressing
        the original active parameters as

        .. math::

            x_i = \\mu_i + \\sum_j L_{ij} \\, m_j .

        Re-running the MCMC against this config samples in coordinates where
        the posterior is approximately spherical.

        Parameters
        ----------
        out_path:
            Destination path. Missing parent directories are created.
        n_sigma:
            Truncation half-width for the meta-parameter priors, in units of
            their (unit) standard deviation. Defaults to ``5.0``.
        perturber_scale:
            Cauchy ``scale`` of the per-meta-parameter perturber. Defaults to
            ``1.0e-5`` to match the Galacticus reference.

        Returns
        -------
        pathlib.Path
            Resolved path of the written file.

        Raises
        ------
        ValueError
            If ``n_sigma`` or ``perturber_scale`` is not a positive number
            (NaN included), if there are no parameters, or if the shapes of
            :attr:`mean` / :attr:`cholesky` do not match
            :attr:`parameter_names`.
        """
        # `not x > 0` (instead of `x <= 0`) also rejects NaN, which would
        # otherwise be written verbatim into the config as "nan".
        if not n_sigma > 0:
            raise ValueError(f"n_sigma must be positive; got {n_sigma!r}")
        if not perturber_scale > 0:
            raise ValueError(
                f"perturber_scale must be positive; got {perturber_scale!r}"
            )
        n = len(self.parameter_names)
        if n == 0:
            raise ValueError("MVNFit has no parameters to write.")
        # Guard against a hand-built fit whose arrays disagree with the names;
        # indexing below would otherwise fail opaquely or drop terms silently.
        mean_shape = np.shape(self.mean)
        chol_shape = np.shape(self.cholesky)
        if mean_shape != (n,) or chol_shape != (n, n):
            raise ValueError(
                f"mean must have shape ({n},) and cholesky ({n}, {n}) to match "
                f"{n} parameter names; got {mean_shape} and {chol_shape}."
            )

        root = ET.Element("parameters")

        # Active meta parameters: truncated unit normals with a tiny Cauchy
        # perturber each.
        for i in range(n):
            mp = ET.SubElement(root, "modelParameter", value="active")
            ET.SubElement(mp, "name", value=f"metaParameter{i}")
            prior = ET.SubElement(mp, "distributionFunction1DPrior", value="normal")
            ET.SubElement(prior, "mean", value="0.0")
            ET.SubElement(prior, "variance", value="1.0")
            ET.SubElement(prior, "limitLower", value=f"{-n_sigma:.6g}")
            ET.SubElement(prior, "limitUpper", value=f"{n_sigma:.6g}")
            ET.SubElement(mp, "operatorUnaryMapper", value="identity")
            pert = ET.SubElement(
                mp, "distributionFunction1DPerturber", value="cauchy"
            )
            ET.SubElement(pert, "median", value="0.0")
            ET.SubElement(pert, "scale", value=f"{perturber_scale:.6g}")

        # Derived parameters: the original physical parameters, x = mu + L m.
        for i, name in enumerate(self.parameter_names):
            mp = ET.SubElement(root, "modelParameter", value="derived")
            ET.SubElement(mp, "name", value=name)
            ET.SubElement(mp, "definition", value=self._derived_definition(i))

        tree = ET.ElementTree(root)
        _indent_inplace(tree.getroot(), level=0, space=" ")
        out = Path(out_path)
        out.parent.mkdir(parents=True, exist_ok=True)
        tree.write(out, encoding="utf-8", xml_declaration=True)
        return out.resolve()

    def _derived_definition(self, i: int) -> str:
        """Return the definition string ``mu_i [+-] |L_ij|*%[metaParameter{j}] ...``.

        Floats are formatted with ``repr`` (shortest string that round-trips
        exactly), so the written transform reproduces the fitted ``mean`` and
        ``cholesky`` bit-for-bit. Zero coefficients (the upper triangle of a
        Cholesky factor) are omitted.
        """
        terms = [repr(float(self.mean[i]))]
        for j, raw in enumerate(self.cholesky[i]):
            coef = float(raw)
            if coef == 0.0:
                continue
            sign = "+" if coef >= 0 else "-"
            terms.append(f"{sign}{abs(coef)!r}*%[metaParameter{j}]")
        return "".join(terms)
def _indent_inplace(elem: ET.Element, level: int = 0, space: str = " ") -> None:
    """Pretty-indent the subtree rooted at *elem* by rewriting whitespace.

    A stand-in for ``xml.etree.ElementTree.indent``, which does not exist on
    Python 3.8. Only ``text``/``tail`` values that are missing or
    whitespace-only are replaced; real character data is left alone. Each
    nesting level adds one copy of *space*. Unlike ``ET.indent``, an element
    with children that is processed at ``level == 0`` also gets a trailing
    newline as its ``tail``.
    """

    def _blank(s: Optional[str]) -> bool:
        # True when the slot holds no meaningful character data.
        return not s or not s.strip()

    here = "\n" + space * level
    kids = list(elem)

    if not kids:
        # Leaf: only a nested leaf needs its tail set; the parent will
        # overwrite the tail of its last child anyway.
        if level and _blank(elem.tail):
            elem.tail = here
        return

    if _blank(elem.text):
        elem.text = here + space
    if _blank(elem.tail):
        elem.tail = here

    for kid in kids:
        _indent_inplace(kid, level + 1, space)

    # The last child's tail precedes this element's closing tag, so dedent it
    # back to this element's own level.
    last = kids[-1]
    if _blank(last.tail):
        last.tail = here
# ---------------------------------------------------------------------------
# Public entry point
# ---------------------------------------------------------------------------
[docs]
def multivariate_normal_fit(
    chains: ChainSet,
    *,
    post_burn: Optional[int] = None,
    drop_chains: Sequence[int] = (),
) -> MVNFit:
    """Fit a multivariate normal to the post-burn concatenated chain.

    Parameters
    ----------
    chains:
        :class:`ChainSet`.
    post_burn:
        Number of leading rows to skip per chain. ``None`` triggers
        automatic detection via :func:`gelman_rubin` /
        :func:`convergence_step`.
    drop_chains:
        ``chain_index`` values to exclude.

    Returns
    -------
    MVNFit
        Sample mean, symmetrised sample covariance (always
        ``(n_params, n_params)``, including the single-parameter case), its
        lower Cholesky factor and the parameter names.

    Raises
    ------
    ValueError
        If no chain contributes post-burn rows, or if fewer than
        ``n_params + 1`` post-burn samples remain (the sample covariance would
        then be rank-deficient).
    np.linalg.LinAlgError
        If the sample covariance is not positive-definite (which can happen
        for parameters that are degenerate post-burn). Drop the offending
        parameter or supply more samples.
    """
    from ._convergence import _resolve_post_burn

    burn = _resolve_post_burn(chains, post_burn)
    drop = {int(i) for i in drop_chains}
    parts = [
        c.state[burn:]
        for c in chains
        if c.chain_index not in drop and c.n_steps > burn
    ]
    if not parts:
        raise ValueError(
            "No post-burn samples available — every chain was dropped or "
            "shorter than post_burn."
        )
    samples = np.concatenate(parts, axis=0)
    n_samples, n_params = samples.shape
    if n_samples < n_params + 1:
        raise ValueError(
            f"Need at least n_params + 1 = {n_params + 1} post-burn samples "
            f"to fit a multivariate normal; got {n_samples}."
        )
    names = tuple(chains.config.parameter_names)

    mean = samples.mean(axis=0)
    # np.cov squeezes its result: with a single parameter it returns a 0-d
    # array, which np.linalg.cholesky rejects. Restore the square shape.
    cov = np.cov(samples, rowvar=False).reshape(n_params, n_params)
    # Remove round-off asymmetry before factorising.
    cov = 0.5 * (cov + cov.T)
    try:
        L = np.linalg.cholesky(cov)
    except np.linalg.LinAlgError as err:
        degenerate = [
            name for name, var in zip(names, np.diag(cov)) if not var > 0
        ]
        hint = (
            f" Parameters with zero post-burn variance: {degenerate}."
            if degenerate
            else ""
        )
        raise np.linalg.LinAlgError(
            "Post-burn sample covariance is not positive-definite; drop "
            "degenerate parameters or supply more samples." + hint
        ) from err

    return MVNFit(
        mean=mean,
        covariance=cov,
        cholesky=L,
        parameter_names=names,
    )