Bayesian Inference

Notes on Bayesian Inference #

tip

  • git clone -> cd
  • pip install -r requirements.txt (scipy, jax)

Understanding Uncertainty via Probabilities #

  • sum rule: P(X) = \(\sum_{Y}\) P(X,Y); for a binary Y this reads P(X) = P(X,Y) + P(X, \(\neg\) Y)
  • product rule: P(X,Y) = P(X) \(\cdot\) P(Y | X)
  • bayes theorem on data D (a small numeric example follows this list):
\[\underbrace{P(X | D)}_{\text{posterior of X given D}} \hspace{.1cm} = \frac{\overbrace{P(D | X)}^{\text{likelihood of X under D}} \hspace{.1cm} \cdot \hspace{.1cm} \overbrace{P(X)}^{\text{prior of X}}}{\underbrace{P(D)}_{\text{marginalization or evidence of the model}}}\]
  • “discrete domain is just a subset of the continuous domain”
  • this extends deductive reasoning (logic) to plausible reasoning (probabilities)
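
A minimal sketch of these rules in action – a discrete Bayes update over a binary hypothesis, with made-up prior and likelihood values:

import jax.numpy as jnp

prior = jnp.array([0.3, 0.7])       # P(X), P(not X) -- assumed values for illustration
likelihood = jnp.array([0.9, 0.2])  # P(D | X), P(D | not X) -- assumed values

joint = likelihood * prior          # product rule: P(D, X) = P(D | X) P(X)
evidence = joint.sum()              # sum rule: P(D) = sum over X of P(D, X)
posterior = joint / evidence        # Bayes: P(X | D) = P(D | X) P(X) / P(D)

print(posterior)                    # [0.6585..., 0.3414...]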

Exponential Families and Conjugate Priors #

  • random variable X taking values \(x\) in a subset of \(\mathbb{R}^n\)

  • probability distribution for X with pdf of the following form:

\[p_{w}(x) = \overbrace{h(x)}^{\text{base measure}} \, \text{exp} \left( \overbrace{\phi(x)^T}^{\text{sufficient statistics}} \cdot \underbrace{w}_{\text{natural parameters}} - \text{log} \, \overbrace{Z(w)}^{\text{partition function}} \right)\] \[= \frac{h(x)}{Z(w)} e^{\phi(x)^T w} = p(x | w)\]
  • for notational convenience, reparametrize natural parameters w := \(\eta(\theta)\) in terms of canonical parameters \(\theta\)

  • exponential families \((h(x), \phi(x))\) as the model for some data x guarantee automatic existence of conjugate priors, although not always tractable

  • conjugate priors allow analytic Bayesian inference of probabilistic models, if we can compute the partition function Z(w) of the likelihood and the partition function F( \(\alpha, \nu\) ) of the conjugate prior

  • biggest challenge is finding the normalization constant

  • reduce Bayesian inference to:

    • modelling: computing sufficient statistics \(\phi(x)\) and partition function Z(w)
    • evaluating the posterior: assessing the partition function F of the conjugate prior
  • if F is not tractable \(\Longrightarrow\) use Laplace approximations:

    • find the mode \(ŵ\) of the posterior by solving the root-finding problem \( \nabla_{w} \, \text{log} \, p(w|x) = 0 \), which for the conjugate posterior reduces to (see the sketch after this list)
    \[ \nabla_{w} \hspace{.05cm} \text{log} \hspace{.05cm} Z(ŵ) = \frac{\alpha + \sum_{i=1}^n \phi(x_{i})}{\nu + n} \]
    • evaluate the Hessian \(\Psi = \nabla_w \nabla_w^T \hspace{.05cm} \text{log} \hspace{.05cm} p(w|x)\) at the mode ŵ

    • approximate the posterior as \(\mathcal{N}(w;ŵ, -\Psi^{-1}) \) and the conjugate partition function \(F(\alpha', \nu')\) as:

      \[ F(\alpha', \nu') \approx \sqrt{(2\pi)^d \hspace{.05cm} \text{det}(-\Psi^{-1})} \cdot \text{exp}[ ŵ^T \cdot \alpha' - \hspace{.01cm} \text{log} \hspace{.01cm} Z(ŵ)^T \hspace{.01cm} \nu' ] \]
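
A minimal sketch of these steps for a Bernoulli model written in exponential-family form ( \(\phi(x) = x\), \(\text{log} \, Z(w) = \text{log}(1 + e^w)\) ); the data and the prior parameters \(\alpha, \nu\) below are made up for the example:

import jax
import jax.numpy as jnp

log_Z = jax.nn.softplus  # log Z(w) = log(1 + e^w) for the Bernoulli natural parameter w

def log_posterior(w, alpha, nu, x):
    # unnormalized conjugate log posterior: (alpha + sum phi(x_i)) w - (nu + n) log Z(w)
    return (alpha + x.sum()) * w - (nu + x.shape[0]) * log_Z(w)

x = jnp.array([1., 1., 0., 1., 0., 1.])  # assumed observations
alpha, nu = 1.0, 2.0                     # assumed conjugate prior parameters

# mode w_hat via Newton steps on the root-finding problem grad log p(w | x) = 0
grad = jax.grad(log_posterior)
hess = jax.grad(grad)
w_hat = 0.0
for _ in range(20):
    w_hat = w_hat - grad(w_hat, alpha, nu, x) / hess(w_hat, alpha, nu, x)

Psi = hess(w_hat, alpha, nu, x)                # Hessian of log p(w | x) at the mode
laplace_mean, laplace_var = w_hat, -1.0 / Psi  # N(w; w_hat, -Psi^{-1})

# Laplace estimate of the conjugate partition function F(alpha', nu'), here with d = 1
alpha_post, nu_post = alpha + x.sum(), nu + x.shape[0]
F_approx = jnp.sqrt(2 * jnp.pi * (-1.0 / Psi)) * jnp.exp(
    w_hat * alpha_post - log_Z(w_hat) * nu_post
)

Newton's method stands in here for any root-finding routine; in this particular conjugate model the mode is also available in closed form, \( ŵ = \text{log} \frac{\alpha + \sum_i x_i}{\nu + n - \alpha - \sum_i x_i} \), which makes a convenient sanity check.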

important

  • Laplace approximations reveal that Bayesian inference prioritizes capturing the geometry of the likelihood function around its peak (mode), rather than solely focusing on the prior distribution

  • Uncertainty is better understood as tracking the whole volume of plausible solutions simultaneously, rather than fixating on a single point estimate

Gaussians #

  • Gaussian inference is linear algebra at its core

    • products of Gaussians are Gaussians

    \[ \mathcal{N}(x;a,A) \mathcal{N}(x;b,B) = \mathcal{N}(x;c,C) \mathcal{N}(a;b, A+B) \] \[ C = (A^{-1} + B^{-1})^{-1}, \quad c = C(A^{-1}a + B^{-1}b) \]

    • linear maps/projections of Gaussian variables are Gaussian variables \[ p(z) = \mathcal{N}(z; \mu, \Sigma) \Longrightarrow p(Az) = \mathcal{N}(Az; A\mu, A\Sigma A^T) \]
    • marginals of Gaussians are Gaussians \[ \int \mathcal{N} \left( \begin{bmatrix} x \\ y \end{bmatrix}; \begin{bmatrix} \mu_x \\ \mu_y \end{bmatrix}, \begin{bmatrix} \Sigma_{xx} & \Sigma_{xy} \\ \Sigma_{yx} & \Sigma_{yy} \end{bmatrix} \right) dy = \mathcal{N}(x;\mu_x, \Sigma_{xx}) \]
    • linear conditionals of Gaussians are Gaussians \[ p(x | y) = \frac{p(x,y)}{p(y)} = \mathcal{N}(x; \mu_x + \Sigma_{xy}\Sigma_{yy}^{-1}(y - \mu_y),\Sigma_{xx}-\Sigma_{xy}\Sigma_{yy}^{-1}\Sigma_{yx}) \]
  • in the product formula above, little b acts as a shift in the observation (the likelihood mean), whilst little c is the shifted mean of the prediction (the posterior)

  • if the prior over a random variable is Gaussian and the observations are linearly related to it, then all conditionals, joints and marginals are Gaussian, with means and covariances computable by linear algebra expressions – Bayesian inference becomes linear algebra, as the numeric check and the Gaussian class below illustrate
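
These identities are easy to verify numerically. A quick sketch with made-up means and covariances, checking the product formula at an arbitrary evaluation point:

import jax.numpy as jnp
from jax.scipy.stats import multivariate_normal as mvn

# assumed 2D example parameters
a, A = jnp.array([0., 1.]), jnp.array([[2., .3], [.3, 1.]])
b, B = jnp.array([1., -1.]), jnp.array([[1., 0.], [0., .5]])

C = jnp.linalg.inv(jnp.linalg.inv(A) + jnp.linalg.inv(B))
c = C @ (jnp.linalg.inv(A) @ a + jnp.linalg.inv(B) @ b)

x = jnp.array([.5, .5])  # arbitrary evaluation point
lhs = mvn.pdf(x, a, A) * mvn.pdf(x, b, B)
rhs = mvn.pdf(x, c, C) * mvn.pdf(a, b, A + B)
print(jnp.allclose(lhs, rhs))  # True

The class below packages these operations (conditioning on a linear observation and sampling) behind a small dataclass: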

import dataclasses
import functools

import jax
import jax.scipy.linalg  # make jax.scipy.linalg available for the Cholesky solves below
from jax import numpy as jnp

@dataclasses.dataclass
class Gaussian:
    """Gaussian distribution with mean mu and covariance Sigma."""

    mu: jnp.ndarray  # shape (D,)
    Sigma: jnp.ndarray  # shape (D, D)

    @functools.cached_property
    def L(self): 
        """Cholesky decomposition of the covariance matrix"""
        return jnp.linalg.cholesky(self.Sigma)  # lower-triangular L with Sigma = L @ L.T

    @functools.cached_property
    def L_factor(self):
        """Cholesky factorization of the covariance matrix"""
        return jax.scipy.linalg.cho_factor(self.Sigma, lower=True)

    def condition(self, A, y, Lambda):
        """Condition on a linear observation y = A x + eps, eps ~ N(0, Lambda).

           A: observation matrix, shape (N, D)
           y: observation, shape (N,)
           Lambda: observation noise covariance, shape (N, N)"""
        Gram = A @ self.Sigma @ A.T + Lambda
        L = jax.scipy.linalg.cho_factor(Gram, lower=True)
        mu = self.mu + self.Sigma @ A.T @ jax.scipy.linalg.cho_solve(
            L,y-A @ self.mu
            )
        Sigma = self.Sigma - self.Sigma @ A.T @ jax.scipy.linalg.cho_solve(
            L, A @ self.Sigma
            )

        return Gaussian(mu=mu, Sigma=Sigma)

    def sample(self, key, num_samples = 1):
        """Sample from the Gaussian"""
        return jax.random.multivariate_normal(
            key, mean=self.mu, cov=self.Sigma, shape=(num_samples,),
            method="svd",
        )  # "svd" factorizes Sigma robustly, even when it is only positive semi-definite
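
A possible usage sketch of the class above (the prior, observation matrix, and noise level are made up): condition a 2D Gaussian prior on a single noisy linear observation and sample from the resulting posterior.

# uses the Gaussian class and imports defined above
key = jax.random.PRNGKey(0)

prior = Gaussian(mu=jnp.zeros(2), Sigma=jnp.eye(2))  # assumed prior
A = jnp.array([[1.0, 0.5]])                          # observation matrix, shape (1, 2)
y = jnp.array([1.2])                                 # observed value
Lambda = jnp.array([[0.1]])                          # observation noise covariance

posterior = prior.condition(A, y, Lambda)            # Gaussian inference = linear algebra
samples = posterior.sample(key, num_samples=5)       # shape (5, 2)
print(posterior.mu, posterior.Sigma, samples.shape)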