---
name: probabilistic-programming
description: Build Bayesian models with PyMC, NumPyro, and rigorous MCMC diagnostics
---

Probabilistic Programming

PyMC model patterns, NumPyro GPU inference, prior selection, MCMC diagnostics, posterior predictive checks, variational inference, and model comparison.

PPL Decision Table

| Model Complexity | Speed Need | PPL | Why |
|---|---|---|---|
| Standard regression, hierarchical | Moderate | PyMC | Mature API, ArviZ integration, good defaults |
| Large data, GPU required | High | NumPyro | JAX backend, NUTS on GPU, fastest MCMC |
| Deep generative models | High | Pyro | PyTorch backend, flexible guide design |
| Simple conjugate models | Low | Stan (via CmdStanPy) | Gold standard HMC, extensive diagnostics |
| Educational / prototyping | Low | PyMC | Readable syntax, best documentation |
| Production serving | High | NumPyro | JIT-compiled, minimal overhead |
| Time series (structural) | Moderate | Orbit / PyMC | Specialized APIs for DLM, ETS |
| Gaussian processes | Moderate | GPyTorch / PyMC | GPyTorch for scalable exact GPs |

PyMC Model Patterns

Linear Regression

python
import pymc as pm
import numpy as np
import arviz as az

def bayesian_linear_regression(X: np.ndarray, y: np.ndarray) -> az.InferenceData:
    n_features = X.shape[1]
    with pm.Model() as model:
        intercept = pm.Normal("intercept", mu=0, sigma=10)
        betas = pm.Normal("betas", mu=0, sigma=5, shape=n_features)
        sigma = pm.HalfNormal("sigma", sigma=5)
        mu = intercept + pm.math.dot(X, betas)
        pm.Normal("y_obs", mu=mu, sigma=sigma, observed=y)
        idata = pm.sample(
            draws=2000, tune=1000, chains=4,
            target_accept=0.9, random_seed=42,
        )
    return idata

Hierarchical / Multilevel Model

python
def hierarchical_model(group_idx: np.ndarray, X: np.ndarray,
                       y: np.ndarray, n_groups: int) -> az.InferenceData:
    """Partial pooling with non-centered parameterization."""
    with pm.Model() as model:
        mu_alpha = pm.Normal("mu_alpha", mu=0, sigma=10)
        sigma_alpha = pm.HalfNormal("sigma_alpha", sigma=5)
        # Non-centered: offset * sigma + mu avoids funnel geometry
        alpha_offset = pm.Normal("alpha_offset", mu=0, sigma=1, shape=n_groups)
        alpha = pm.Deterministic("alpha", mu_alpha + sigma_alpha * alpha_offset)
        beta = pm.Normal("beta", mu=0, sigma=5)
        sigma = pm.HalfNormal("sigma", sigma=5)
        mu = alpha[group_idx] + beta * X
        pm.Normal("y_obs", mu=mu, sigma=sigma, observed=y)
        idata = pm.sample(2000, tune=1000, chains=4, target_accept=0.95)
    return idata
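
The non-centered trick used for alpha can be checked outside of any PPL. The following NumPy sketch draws group effects both ways and suggests that the two parameterizations describe the same distribution; only the geometry the sampler has to explore differs. The values of mu, sigma, and the sample size are illustrative assumptions, not taken from a fitted model.

```python
import numpy as np

rng = np.random.default_rng(0)
mu, sigma = 1.5, 0.3      # illustrative hyperparameter values (assumed)
n = 200_000

# Centered: draw alpha directly from Normal(mu, sigma).
alpha_centered = rng.normal(mu, sigma, size=n)

# Non-centered: draw a standard-normal offset, then shift and scale it.
offset = rng.normal(0.0, 1.0, size=n)
alpha_noncentered = mu + sigma * offset

# Both sets of draws should have nearly the same mean and spread.
print(alpha_centered.mean(), alpha_centered.std())
print(alpha_noncentered.mean(), alpha_noncentered.std())
```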

Gaussian Process

python
def gp_regression(X: np.ndarray, y: np.ndarray) -> az.InferenceData:
    with pm.Model() as model:
        length_scale = pm.Gamma("length_scale", alpha=2, beta=1)
        amplitude = pm.HalfNormal("amplitude", sigma=2)
        noise = pm.HalfNormal("noise", sigma=1)
        cov = amplitude**2 * pm.gp.cov.ExpQuad(input_dim=1, ls=length_scale)
        gp = pm.gp.Marginal(cov_func=cov)
        gp.marginal_likelihood("y_obs", X=X[:, None], y=y, sigma=noise)
        idata = pm.sample(1000, tune=1000, chains=4)
    return idata
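
To make the covariance function concrete, here is a NumPy sketch of the exponentiated-quadratic kernel that amplitude**2 * ExpQuad(ls=length_scale) represents, k(x, x') = amplitude^2 * exp(-(x - x')^2 / (2 * ls^2)). The function name expquad_kernel and the input values are illustrative and not part of PyMC.

```python
import numpy as np

def expquad_kernel(x, amplitude, length_scale):
    """Exponentiated-quadratic covariance matrix for 1-D inputs."""
    x = np.asarray(x, dtype=float)
    sq_dist = (x[:, None] - x[None, :]) ** 2      # pairwise squared distances
    return amplitude**2 * np.exp(-0.5 * sq_dist / length_scale**2)

x = np.linspace(0.0, 5.0, 6)                      # points 0, 1, ..., 5
K = expquad_kernel(x, amplitude=2.0, length_scale=1.0)
print(np.round(K[0], 3))   # covariance between x[0] and every point
```

Covariance decays with distance at a rate set by length_scale, which is why a prior that keeps length_scale away from zero and infinity matters.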

NumPyro for GPU-Accelerated Inference

NUTS Sampling

python
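import arviz as az
import jax
import jax.numpy as jnp
import numpyro
import numpyro.distributions as dist
from numpyro.infer import MCMC, NUTS, Predictive

# Expose 4 host (CPU) devices so the 4 chains can run in parallel.
# Call this before any JAX computation; it does not change the GPU count.
numpyro.set_host_device_count(4)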

def numpyro_regression(X: jnp.ndarray, y: jnp.ndarray):
    def model(X, y=None):
        intercept = numpyro.sample("intercept", dist.Normal(0, 10))
        betas = numpyro.sample("betas", dist.Normal(0, 5).expand([X.shape[1]]))
        sigma = numpyro.sample("sigma", dist.HalfNormal(5))
        mu = intercept + X @ betas
        numpyro.sample("y_obs", dist.Normal(mu, sigma), obs=y)

    kernel = NUTS(model, target_accept_prob=0.9)
    mcmc = MCMC(kernel, num_warmup=1000, num_samples=2000, num_chains=4)
    mcmc.run(jax.random.PRNGKey(42), X, y)
    return az.from_numpyro(mcmc)

Stochastic Variational Inference (SVI)

python
from numpyro.infer import SVI, Trace_ELBO
from numpyro.infer.autoguide import AutoNormal

def numpyro_svi(X: jnp.ndarray, y: jnp.ndarray, num_steps: int = 5000):
    """VI for fast approximate posterior when MCMC is too slow."""
    def model(X, y=None):
        intercept = numpyro.sample("intercept", dist.Normal(0, 10))
        betas = numpyro.sample("betas", dist.Normal(0, 5).expand([X.shape[1]]))
        sigma = numpyro.sample("sigma", dist.HalfNormal(5))
        numpyro.sample("y_obs", dist.Normal(intercept + X @ betas, sigma), obs=y)

    guide = AutoNormal(model)
    svi = SVI(model, guide, numpyro.optim.Adam(0.01), loss=Trace_ELBO())
    result = svi.run(jax.random.PRNGKey(0), num_steps, X, y)
    # Draw approximate posterior samples of the latent sites from the fitted guide
    predictive = Predictive(guide, params=result.params, num_samples=2000)
    return predictive(jax.random.PRNGKey(1), X, y)

Prior Selection Guide

| Parameter Type | Recommended Prior | Rationale |
|---|---|---|
| Regression intercept | Normal(0, 10) | Weakly informative, centered on zero |
| Regression slope | Normal(0, 5) | Allows moderate effects |
| Scale / std dev | HalfNormal(5) or Exponential(1) | Positive, shrinks toward zero |
| Correlation matrix | LKJCholesky(eta=2) | eta=2 weakly favors identity |
| Proportion | Beta(2, 2) | Weakly informative, avoids 0/1 edges |
| Count rate | Gamma(2, 0.5) | Positive, weakly informative |
| GP length scale | Gamma(2, 1) or InverseGamma(5, 5) | Prevents near-zero or infinite |
| Degrees of freedom (Student-t) | Gamma(2, 0.1) | Allows heavy tails, weakly informative |

Prior Predictive Workflow

python
with pm.Model() as model:
    # ... define priors and likelihood ...
    prior_pred = pm.sample_prior_predictive(draws=500)  # older PyMC releases name this argument samples
    # Check: are prior predictions in a plausible range?
    # If y is height in cm, prior predictions of 1e6 cm = priors too wide
    az.plot_ppc(prior_pred, group="prior")
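
The same check can be approximated without a sampler. This NumPy sketch assumes the linear-regression priors used earlier (Normal(0, 10) intercept, Normal(0, 5) slopes, HalfNormal(5) noise) and a hypothetical standardized design matrix. It simulates outcomes from the priors alone and reports how wide they are.

```python
import numpy as np

rng = np.random.default_rng(1)
n_obs, n_features, n_draws = 50, 3, 500
X = rng.normal(size=(n_obs, n_features))   # hypothetical standardized predictors

# One draw of every parameter per prior-predictive sample.
intercept = rng.normal(0.0, 10.0, size=n_draws)
betas = rng.normal(0.0, 5.0, size=(n_draws, n_features))
sigma = np.abs(rng.normal(0.0, 5.0, size=n_draws))   # HalfNormal(5) as |Normal(0, 5)|

mu = intercept[:, None] + betas @ X.T                # shape (n_draws, n_obs)
y_sim = rng.normal(mu, sigma[:, None])               # simulated outcomes

lo, hi = np.percentile(y_sim, [2.5, 97.5])
print(f"95% of prior-predictive outcomes fall in [{lo:.1f}, {hi:.1f}]")
```

If the printed interval is far wider than any plausible value of the outcome, that is a hint to tighten the priors or rescale the data.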

MCMC Diagnostics with ArviZ

python
import arviz as az

def full_diagnostics(idata: az.InferenceData):
    summary = az.summary(idata, hdi_prob=0.94)
    print(summary)
    rhat = az.rhat(idata)        # every value should be < 1.01
    ess = az.ess(idata)          # bulk ESS; aim for > 400 in total (about 100 per chain)
    divergences = int(idata.sample_stats["diverging"].sum())
    print(f"Divergences: {divergences}")  # should be 0
    az.plot_trace(idata, var_names=["intercept", "betas", "sigma"])
    az.plot_rank(idata, var_names=["betas"])  # rank plots often expose mixing problems that trace plots hide
    az.plot_energy(idata)
    return summary

def posterior_predictive_check(idata: az.InferenceData, model):
    with model:
        pm.sample_posterior_predictive(idata, extend_inferencedata=True)
    az.plot_ppc(idata, num_pp_samples=100, kind="cumulative")
    az.plot_loo_pit(idata, y="y_obs")  # should be uniform
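
For intuition about what R-hat measures, the sketch below computes a classic split R-hat in NumPy: each chain is halved, and between-chain variance is compared with within-chain variance. ArviZ's az.rhat uses a rank-normalized refinement of this idea, so treat the code as an illustration rather than a replacement. The function name split_rhat and the simulated chains are assumptions.

```python
import numpy as np

def split_rhat(chains: np.ndarray) -> float:
    """Classic split R-hat for an array of shape (n_chains, n_draws)."""
    n_chains, n_draws = chains.shape
    half = n_draws // 2
    # Split every chain in two so that drift within a chain also shows up.
    halves = np.concatenate([chains[:, :half], chains[:, half:2 * half]], axis=0)
    n = halves.shape[1]
    within = halves.var(axis=1, ddof=1).mean()        # W: mean within-chain variance
    between = n * halves.mean(axis=1).var(ddof=1)     # B: variance of chain means, scaled
    var_plus = (n - 1) / n * within + between / n     # pooled variance estimate
    return float(np.sqrt(var_plus / within))

rng = np.random.default_rng(2)
mixed = rng.normal(size=(4, 1000))                        # four chains on the same target
stuck = mixed + np.array([0.0, 0.0, 0.0, 3.0])[:, None]   # one chain offset by 3

print(round(split_rhat(mixed), 3))   # close to 1
print(round(split_rhat(stuck), 3))   # well above 1.05
```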

Model Comparison

python
def compare_models(models: dict[str, az.InferenceData]) -> None:
    """Compare via LOO-CV (preferred over WAIC)."""
    for name, idata in models.items():
        loo = az.loo(idata, pointwise=True)
        print(f"{name}: elpd_loo={loo.elpd_loo:.1f} +/- {loo.se:.1f}")
    comparison = az.compare(models, ic="loo")
    print(comparison)
    az.plot_compare(comparison)

# Enable log_likelihood for comparison
# PyMC: pm.sample(..., idata_kwargs={"log_likelihood": True})
# NumPyro: numpyro.infer.log_likelihood(model_fn, mcmc.get_samples(), X, y=y)
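
The elpd_diff and dse columns reported by az.compare are derived from pointwise log predictive densities. As a rough sketch of that arithmetic, the code below uses simulated pointwise values in place of real LOO output; all numbers are made up for illustration, and ArviZ's exact variance convention may differ slightly.

```python
import numpy as np

def elpd_difference(pointwise_a: np.ndarray, pointwise_b: np.ndarray):
    """Return (elpd_a - elpd_b, standard error of that difference)."""
    diff = pointwise_a - pointwise_b
    n = diff.size
    elpd_diff = diff.sum()
    se_diff = np.sqrt(n * diff.var(ddof=1))   # sample variance used here (assumption)
    return float(elpd_diff), float(se_diff)

rng = np.random.default_rng(3)
n_obs = 200
# Hypothetical pointwise elpd values; model A is slightly better on average.
pw_a = rng.normal(-1.0, 0.5, size=n_obs)
pw_b = pw_a - 0.1 + rng.normal(0.0, 0.2, size=n_obs)

diff, se = elpd_difference(pw_a, pw_b)
print(f"elpd_diff = {diff:.1f}, se = {se:.1f}")
```

A difference that is small relative to its standard error gives little reason to prefer one model over the other.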

Gotchas

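  • Non-centered parameterization: for hierarchical models, prefer offset * sigma + mu over Normal(mu, sigma), especially when groups have little data or the group-level scale is small. The centered form then tends to produce funnel geometry and divergences.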
  • Divergences are not ignorable: even a few can indicate a biased posterior. Increase target_accept (0.95-0.99), reparameterize, or simplify the model.
  • R-hat must be < 1.01: values above 1.05 indicate chains haven't mixed. Run longer or reparameterize.
  • ESS: aim for bulk ESS > 400 in total (roughly 100 per chain with 4 chains) for posterior means, and tail ESS > 400 for credible intervals. Low tail ESS means unreliable HDI bounds.
  • plot_rank > plot_trace: rank plots are more reliable for detecting convergence issues.
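  • PyMC auto-assigns samplers: continuous parameters get NUTS, while discrete parameters are assigned Metropolis-type steps in a compound sampler, which usually mix much worse. Passing step=pm.NUTS() does not change this, because NUTS needs gradients and cannot update discrete variables. Check the sampler assignment PyMC logs and marginalize discrete parameters out where possible.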
  • NumPyro works on JAX arrays: NumPy inputs are converted implicitly, so call jnp.asarray() up front to make dtype (float32 by default) and shape explicit. Shape mismatches produce cryptic tracing errors.
  • WAIC vs LOO-CV: prefer LOO. If Pareto k > 0.7, refit with moment matching or use k-fold CV.
  • Prior predictive checks: always run pm.sample_prior_predictive() first. Absurd prior predictions = priors too wide.
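
The numeric thresholds in this list can be gathered into one helper. The sketch below is plain Python; the function name, the argument names, and the reading of the ESS threshold as a total across chains are assumptions. In practice the inputs would come from az.summary, az.rhat, az.ess, and az.loo.

```python
from typing import List, Optional

def convergence_warnings(rhat_max: float, ess_bulk_min: float,
                         ess_tail_min: float, n_divergences: int,
                         pareto_k_max: Optional[float] = None,
                         min_ess: float = 400.0) -> List[str]:
    """Return human-readable warnings for the checks listed above."""
    warnings = []
    if rhat_max >= 1.01:
        warnings.append(f"R-hat {rhat_max:.3f} >= 1.01: chains may not have mixed")
    if ess_bulk_min < min_ess:
        warnings.append(f"bulk ESS {ess_bulk_min:.0f} < {min_ess:.0f}: posterior means are noisy")
    if ess_tail_min < min_ess:
        warnings.append(f"tail ESS {ess_tail_min:.0f} < {min_ess:.0f}: interval bounds are unreliable")
    if n_divergences > 0:
        warnings.append(f"{n_divergences} divergences: reparameterize or raise target_accept")
    if pareto_k_max is not None and pareto_k_max > 0.7:
        warnings.append(f"Pareto k {pareto_k_max:.2f} > 0.7: LOO estimate is unreliable")
    return warnings

# A healthy run produces no warnings; a problematic one lists every failed check.
print(convergence_warnings(1.002, 1800, 1500, 0))
print(convergence_warnings(1.08, 120, 90, 7, pareto_k_max=0.9))
```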