Notes on Bayesian Inference
tip
- git clone the repository, then cd into it
- pip install -r requirements.txt (scipy, jax; the code imports jax.numpy as jnp)
Understanding Uncertainty via Probabilities
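- sum rule: \(P(X) = P(X,Y) + P(X,\neg Y)\)
- product rule: \(P(X,Y) = P(X) \cdot P(Y \mid X)\)
- Bayes' theorem on data D: \[ P(X \mid D) = \frac{P(D \mid X)\, P(X)}{P(D)}, \quad P(D) = P(D \mid X)\, P(X) + P(D \mid \neg X)\, P(\neg X) \]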
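- the same rules hold for continuous variables, with sums replaced by integrals over densities; the discrete case can be treated as a special case of the continuous one
- this extends deductive reasoning (propositional logic) to plausible reasoning under uncertainty (probabilities)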
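The three rules can be checked numerically on a small discrete example. This is a minimal sketch assuming a made-up joint table over two binary variables, with Y playing the role of the data D; the table values are illustrative only.

```python
import numpy as np

# Hypothetical joint table P(X, Y) for two binary variables (made-up numbers).
# Rows: X in {x, not x}; columns: Y in {y, not y}.
joint = np.array([[0.3, 0.1],
                  [0.2, 0.4]])

# Sum rule: P(X) = P(X, Y) + P(X, not Y), i.e. sum over the Y axis
p_x = joint.sum(axis=1)
p_y = joint.sum(axis=0)

# Product rule: P(X, Y) = P(X) * P(Y | X)
p_y_given_x = joint / p_x[:, None]
assert np.allclose(p_x[:, None] * p_y_given_x, joint)

# Bayes' theorem with Y playing the role of the data D:
# P(X | Y) = P(Y | X) P(X) / P(Y)
p_x_given_y = p_y_given_x * p_x[:, None] / p_y[None, :]
print(p_x_given_y[:, 0])  # posterior over X after observing Y = y, approx. [0.6 0.4]
```

Each column of `p_x_given_y` is a normalized posterior over X for one possible observation.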
Exponential Families and Conjugate Priors
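random variable X taking values \(x \in \mathcal{X} \subseteq \mathbb{R}^n\)
probability distribution for X with pdf of the following (exponential family) form, with base measure h(x), sufficient statistics \(\phi(x)\) and partition function Z:
\[ p(x \mid \theta) = h(x) \exp\left[ \phi(x)^T \eta(\theta) - \log Z(\eta(\theta)) \right], \quad Z(\eta) = \int_{\mathcal{X}} h(x) \exp\left[ \phi(x)^T \eta \right] dx \]
for notational convenience, reparametrize from the canonical parameters \(\theta\) to the natural parameters \(w := \eta(\theta)\), so that the exponent is linear in w:
\[ p(x \mid w) = h(x) \exp\left[ \phi(x)^T w - \log Z(w) \right] \]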
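using an exponential family \((h(x), \phi(x))\) as the model for some data x guarantees that a conjugate prior over w exists, although its normalization is not always tractable:
\[ p(w \mid \alpha, \nu) = \exp\left[ w^T \alpha - \nu \log Z(w) - \log F(\alpha, \nu) \right], \quad F(\alpha, \nu) = \int \exp\left[ w^T \alpha - \nu \log Z(w) \right] dw \]
conjugate priors allow analytic Bayesian inference, provided we can compute the partition function Z(w) of the likelihood and the partition function \(F(\alpha, \nu)\) of the conjugate prior: after n i.i.d. observations \(x_1, \dots, x_n\) the posterior has the same form, with
\[ \alpha' = \alpha + \sum_{i=1}^n \phi(x_i), \quad \nu' = \nu + n, \qquad p(x_1, \dots, x_n) = \prod_{i=1}^n h(x_i) \, \frac{F(\alpha', \nu')}{F(\alpha, \nu)} \]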
the main challenge is computing the normalization constant, i.e. the partition function F of the posterior
reduce Bayesian inference to:
- modelling: computing sufficient statistics \(\phi(x)\) and partition function Z(w)
- evaluating the posterior: computing the partition function F of the conjugate prior at the updated parameters \((\alpha', \nu')\)
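As a concrete instance of the two steps above, here is a minimal sketch assuming a Bernoulli likelihood written in exponential-family form (h(x) = 1, φ(x) = x, Z(w) = 1 + e^w); the model choice, prior parameters and data are illustrative assumptions, not taken from the notes. Substituting π = sigmoid(w) shows that this conjugate prior is a Beta(α, ν − α) distribution over π, so the quadrature estimate of F can be checked against the Beta function.

```python
import numpy as np
from scipy import integrate, special

# Assumed model for illustration: Bernoulli likelihood in exponential-family form,
# p(x | w) = exp[x * w - log Z(w)] with h(x) = 1, phi(x) = x, Z(w) = 1 + e^w.
def log_Z(w):
    return np.logaddexp(0.0, w)  # log(1 + e^w), computed stably

def log_F(alpha, nu):
    """log F(alpha, nu): log of the integral of exp[w * alpha - nu * log Z(w)] dw."""
    value, _ = integrate.quad(lambda w: np.exp(w * alpha - nu * log_Z(w)),
                              -np.inf, np.inf)
    return np.log(value)

def conjugate_update(alpha, nu, x):
    """Posterior parameters: alpha' = alpha + sum_i phi(x_i), nu' = nu + n."""
    x = np.asarray(x, dtype=float)
    return alpha + x.sum(), nu + x.size

alpha, nu = 2.0, 4.0              # hypothetical prior parameters (need 0 < alpha < nu)
data = [1, 0, 1, 1, 0, 1]         # hypothetical coin flips
alpha_post, nu_post = conjugate_update(alpha, nu, data)  # (6.0, 10.0)

# Log evidence: log p(x_1..x_n) = log F(alpha', nu') - log F(alpha, nu), since h = 1
log_evidence = log_F(alpha_post, nu_post) - log_F(alpha, nu)

# Sanity check: for this family F(alpha, nu) = B(alpha, nu - alpha) (Beta function)
print(log_F(alpha_post, nu_post), special.betaln(alpha_post, nu_post - alpha_post))
```

Modelling only required `log_Z` (and φ); the posterior itself is just the parameter update plus one evaluation of F.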
if F is not tractable \(\Longrightarrow\) use a Laplace approximation:
- find the mode \(\hat{w}\) of the posterior by solving the root-finding problem \(\nabla_w \log p(w \mid x) = \alpha' - \nu' \nabla_w \log Z(w) = 0\)
- evaluate the Hessian \(\Psi = \nabla_w \nabla_w^T \log p(w \mid x)\) at the mode \(\hat{w}\)
- approximate the posterior as \(\mathcal{N}(w; \hat{w}, -\Psi^{-1})\) and the conjugate partition function (d = dimension of w) as:
\[ F(\alpha', \nu') \approx \sqrt{(2\pi)^d \det(-\Psi^{-1})} \cdot \exp\left[ \hat{w}^T \alpha' - \nu' \log Z(\hat{w}) \right] \]
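The three Laplace steps can be sketched for the same illustrative Bernoulli family (an assumption, chosen because the exact F is then a Beta function to compare against). For d = 1 the Hessian is Ψ = −ν' σ(ŵ)(1 − σ(ŵ)); the root is found with SciPy's `brentq`, with a bracket of [−50, 50] that is an arbitrary choice valid when 0 < α' < ν'.

```python
import numpy as np
from scipy import optimize, special

# Assumed family for illustration: Bernoulli, log Z(w) = log(1 + e^w).
def log_Z(w):
    return np.logaddexp(0.0, w)

def grad_log_Z(w):
    return special.expit(w)  # expected sufficient statistic E[phi(x) | w]

def hess_log_Z(w):
    s = special.expit(w)
    return s * (1.0 - s)

def laplace_log_F(alpha, nu):
    """Laplace approximation of log F(alpha, nu) for a 1-D natural parameter."""
    # 1) mode: root of the gradient alpha - nu * d/dw log Z(w)
    w_hat = optimize.brentq(lambda w: alpha - nu * grad_log_Z(w), -50.0, 50.0)
    # 2) Hessian of the log posterior at the mode (a negative scalar here)
    psi = -nu * hess_log_Z(w_hat)
    # 3) F ~ sqrt((2 pi)^d det(-Psi^{-1})) * exp[w_hat * alpha - nu * log Z(w_hat)]
    log_peak = w_hat * alpha - nu * log_Z(w_hat)
    return 0.5 * np.log(2.0 * np.pi / -psi) + log_peak, w_hat

alpha_post, nu_post = 8.0, 12.0  # hypothetical posterior parameters
approx, w_hat = laplace_log_F(alpha_post, nu_post)
exact = special.betaln(alpha_post, nu_post - alpha_post)  # exact log F for this family
print(w_hat, approx, exact)
```

Here the mode is ŵ = log 2 (where σ(ŵ) = α'/ν'), and the approximation lands within a few hundredths of the exact log F; the gap shrinks as ν' grows and the posterior becomes more Gaussian.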
important
Laplace approximations reveal that Bayesian inference is about capturing the geometry (curvature, encoded in the Hessian \(\Psi\)) of the log posterior around its peak (mode), not just the location of the peak itself
Uncertainty is better understood as keeping track of all plausible solutions at once rather than fixating on a single point estimate: what matters is the volume of plausible solutions, not individual points
Gaussians
Gaussian inference is linear algebra at its core
- products of Gaussian densities are (rescaled) Gaussians
\[ \mathcal{N}(x;a,A) \mathcal{N}(x;b,B) = \mathcal{N}(x;c,C) \mathcal{N}(a;b, A+B) \] \[ C = (A^{-1} + B^{-1})^{-1}, \quad c = C(A^{-1}a + B^{-1}b) \]
- linear maps/projections of Gaussian variables are Gaussian variables \[ p(z) = \mathcal{N}(z; \mu, \Sigma) \Longrightarrow p(Az) = \mathcal{N}(Az; A\mu, A\Sigma A^T) \]
- marginals of Gaussians are Gaussians \[ \int \mathcal{N}\left( \begin{bmatrix} x \\ y \end{bmatrix}; \begin{bmatrix} \mu_x \\ \mu_y \end{bmatrix}, \begin{bmatrix} \Sigma_{xx} & \Sigma_{xy} \\ \Sigma_{yx} & \Sigma_{yy} \end{bmatrix} \right) dy = \mathcal{N}(x; \mu_x, \Sigma_{xx}) \]
- linear conditionals of Gaussians are Gaussians \[ p(x | y) = \frac{p(x,y)}{p(y)} = \mathcal{N}(x; \mu_x + \Sigma_{xy}\Sigma_{yy}^{-1}(y - \mu_y),\Sigma_{xx}-\Sigma_{xy}\Sigma_{yy}^{-1}\Sigma_{yx}) \]
little b is a shift in the observation whilst little c is a shift of the prediction
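The product and conditioning identities can be spot-checked numerically. This is a minimal sketch with made-up 2-D parameters (chosen only to be positive definite), using SciPy's `multivariate_normal` density; it evaluates both sides of each identity at one arbitrary point.

```python
import numpy as np
from scipy.stats import multivariate_normal as mvn

# Made-up 2-D parameters, chosen only so that A and B are positive definite.
a, A = np.array([0.0, 1.0]), np.array([[2.0, 0.3], [0.3, 1.0]])
b, B = np.array([1.0, -1.0]), np.array([[1.0, -0.2], [-0.2, 0.5]])

# Product of Gaussian densities: N(x;a,A) N(x;b,B) = N(x;c,C) N(a;b,A+B)
C = np.linalg.inv(np.linalg.inv(A) + np.linalg.inv(B))
C = 0.5 * (C + C.T)  # remove round-off asymmetry
c = C @ (np.linalg.solve(A, a) + np.linalg.solve(B, b))

x = np.array([0.4, -0.7])  # arbitrary evaluation point
lhs = mvn.pdf(x, a, A) * mvn.pdf(x, b, B)
rhs = mvn.pdf(x, c, C) * mvn.pdf(a, b, A + B)

# Conditional of a 2-D joint over (x, y): p(x | y) = p(x, y) / p(y)
mu = np.array([0.5, -1.0])
S = np.array([[1.5, 0.6], [0.6, 1.0]])
xv, yv = 0.2, 0.3
cond_mean = mu[0] + S[0, 1] / S[1, 1] * (yv - mu[1])
cond_var = S[0, 0] - S[0, 1] * S[1, 0] / S[1, 1]
ratio = mvn.pdf([xv, yv], mu, S) / mvn.pdf(yv, mu[1], S[1, 1])

print(np.isclose(lhs, rhs), np.isclose(ratio, mvn.pdf(xv, cond_mean, cond_var)))
```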
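if the prior over a random variable x is Gaussian, \(p(x) = \mathcal{N}(x; \mu, \Sigma)\), and the observations y depend linearly on x with Gaussian noise, \(p(y \mid x) = \mathcal{N}(y; Ax, \Lambda)\), then all conditionals, joints and marginals are Gaussian, with means and covariances computable by linear algebra; in particular the posterior is
\[ p(x \mid y) = \mathcal{N}\left(x; \mu + \Sigma A^T (A \Sigma A^T + \Lambda)^{-1} (y - A\mu), \; \Sigma - \Sigma A^T (A \Sigma A^T + \Lambda)^{-1} A \Sigma \right) \]
so Bayesian inference becomes linear algebra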
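The posterior formula follows from the conditioning rule applied to the joint of (x, y), whose blocks are Σ_xy = ΣA^T and Σ_yy = AΣA^T + Λ. The sketch below, with made-up numbers, computes it that way and cross-checks it against the information (precision) form of the same posterior, a second route swapped in here purely as a check.

```python
import numpy as np

# Hypothetical prior p(x) = N(mu, Sigma) and linear observation model
# p(y | x) = N(y; A x, Lambda); all numbers are made up for illustration.
mu = np.array([0.0, 1.0, -0.5])
Sigma = np.array([[1.0, 0.2, 0.0],
                  [0.2, 2.0, 0.3],
                  [0.0, 0.3, 0.5]])
A = np.array([[1.0, 0.0, 1.0],
              [0.0, 1.0, -1.0]])
Lambda = 0.1 * np.eye(2)
y = np.array([0.3, 0.8])

# Blocks of the Gaussian joint over (x, y), obtained by linear algebra
Sigma_xy = Sigma @ A.T               # Cov[x, y]
Sigma_yy = A @ Sigma @ A.T + Lambda  # Cov[y]
mu_y = A @ mu                        # E[y]

# Gaussian conditional p(x | y), i.e. the posterior
gain = np.linalg.solve(Sigma_yy, Sigma_xy.T).T  # Sigma_xy @ inv(Sigma_yy)
post_mean = mu + gain @ (y - mu_y)
post_cov = Sigma - gain @ Sigma_xy.T

# Cross-check: information form, precision = inv(Sigma) + A^T inv(Lambda) A
info_cov = np.linalg.inv(np.linalg.inv(Sigma) + A.T @ np.linalg.inv(Lambda) @ A)
info_mean = info_cov @ (np.linalg.solve(Sigma, mu) + A.T @ np.linalg.solve(Lambda, y))
print(np.allclose(post_mean, info_mean), np.allclose(post_cov, info_cov))  # True True
```

The first route only inverts an N×N matrix (N observations), the second a D×D one, which is why the conditioning form is usually preferred when there are fewer observations than unknowns.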
import dataclasses, functools
import jax
import jax.scipy.linalg
from jax import numpy as jnp

@dataclasses.dataclass
class Gaussian:
    """Gaussian distribution with mean mu and covariance Sigma"""
    mu: jnp.ndarray  # shape (D,)
    Sigma: jnp.ndarray  # shape (D, D)

    @functools.cached_property
    def L(self):
        """Lower-triangular Cholesky factor L with Sigma = L @ L.T"""
        return jnp.linalg.cholesky(self.Sigma)

    @functools.cached_property
    def L_factor(self):
        """Cholesky factorization of Sigma in the (c, lower) form used by cho_solve"""
        return jax.scipy.linalg.cho_factor(self.Sigma, lower=True)

    def condition(self, A, y, Lambda):
        """Posterior given the observation y = A x + noise, noise ~ N(0, Lambda).

        A: observation matrix, shape (N, D)
        y: observation, shape (N,)
        Lambda: observation noise covariance, shape (N, N)
        """
        Gram = A @ self.Sigma @ A.T + Lambda  # marginal covariance of y, shape (N, N)
        L = jax.scipy.linalg.cho_factor(Gram, lower=True)
        mu = self.mu + self.Sigma @ A.T @ jax.scipy.linalg.cho_solve(
            L, y - A @ self.mu
        )
        Sigma = self.Sigma - self.Sigma @ A.T @ jax.scipy.linalg.cho_solve(
            L, A @ self.Sigma
        )
        return Gaussian(mu=mu, Sigma=Sigma)

    def sample(self, key, num_samples=1):
        """Draw samples from the Gaussian, returned with shape (num_samples, D)"""
        # method="svd" builds the square root of Sigma from its SVD, which also
        # works for (near-)singular covariances; it is not a dimensionality reduction
        return jax.random.multivariate_normal(
            key, mean=self.mu, cov=self.Sigma, shape=(num_samples,), method="svd"
        )
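To illustrate what `condition` computes without requiring JAX, here is a NumPy/SciPy port of the same update (an assumption: it mirrors the method above line by line, using `scipy.linalg.cho_factor`/`cho_solve`). The usage shows a property worth knowing: with independent observation noise, conditioning on two observations at once gives the same posterior as conditioning on them one after the other. The prior, observation matrices and noise level are made up.

```python
import numpy as np
from scipy.linalg import cho_factor, cho_solve

def condition(mu, Sigma, A, y, Lambda):
    """NumPy/SciPy port (illustrative) of Gaussian.condition."""
    gram = A @ Sigma @ A.T + Lambda  # marginal covariance of y
    factor = cho_factor(gram, lower=True)
    new_mu = mu + Sigma @ A.T @ cho_solve(factor, y - A @ mu)
    new_Sigma = Sigma - Sigma @ A.T @ cho_solve(factor, A @ Sigma)
    return new_mu, new_Sigma

# Hypothetical 2-D prior and two scalar observations with independent noise
mu0, Sigma0 = np.zeros(2), np.array([[1.0, 0.5], [0.5, 1.0]])
A1, y1 = np.array([[1.0, 0.0]]), np.array([0.7])
A2, y2 = np.array([[1.0, 1.0]]), np.array([1.5])
noise = np.array([[0.2]])

# Condition on both observations at once ...
mu_batch, S_batch = condition(mu0, Sigma0, np.vstack([A1, A2]),
                              np.concatenate([y1, y2]), 0.2 * np.eye(2))
# ... or one after the other
mu1, S1 = condition(mu0, Sigma0, A1, y1, noise)
mu_seq, S_seq = condition(mu1, S1, A2, y2, noise)
print(np.allclose(mu_batch, mu_seq), np.allclose(S_batch, S_seq))  # True True
```

The equivalence relies on the batch noise covariance being block-diagonal; correlated noise across observations would have to be handled in a single joint update.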