
Diffusion Probabilistic Models: From Ink Drops to AI

Dec 27, 2024
Diffusion Models, Probabilistic Models, Deep Learning

Imagine dropping ink into water. As time passes, the ink gradually diffuses, spreading throughout the water until it becomes completely dispersed. Now, imagine reversing this process — starting with the dispersed ink and reconstructing the original drop. This is the fundamental intuition behind diffusion probabilistic models.

Formally, diffusion models are a class of generative models that mimic this diffusion-and-reversal process: they take an image, gradually add noise until it becomes pure noise, and then learn to reverse the process to reconstruct the original image.

Forward and Backward Diffusion Process

[Figure: The Forward and Reverse Diffusion Processes]

At the heart of diffusion models lie two core processes:

  • The Forward Diffusion Process: Gradually adds noise to a data sample, such as an image, transforming it into pure noise.
  • The Reverse Diffusion Process: Reconstructs the original data by removing the added noise step by step.

Let’s dive deeper into both processes, including some math behind them.

The Forward Diffusion Process

The forward diffusion process models the gradual corruption of data using a Markov chain, where the probability of the next state depends only on the current state. Mathematically, this process is described as:

[Figure: The Forward Diffusion Process]

q(xₜ | xₜ₋₁) = N(xₜ ; √(1-βₜ)xₜ₋₁, βₜI)

Where:

  • q is the transition probability
  • xₜ is the current state
  • xₜ₋₁ is the previous state
  • βₜ is the variance of the noise added at step t (the noise schedule)
  • I is the identity matrix
  • ϵ (used below) is Gaussian noise sampled from N(0, I)

Starting with a data sample x₀, the process repeats this transition until the sample becomes pure noise at time T. Using the reparameterization trick, a single noising step can be written as:

xₜ = √(1-βₜ)xₜ₋₁ + √(βₜ)ϵ

Writing αₜ = 1-βₜ, the same step becomes:

xₜ = √(αₜ)xₜ₋₁ + √(1-αₜ)ϵ

Interesting fact: one can compute the distribution of the data at time step t directly from the original sample x₀, without stepping through the intermediate states:

xₜ = √(ᾱₜ)x₀ + √(1-ᾱₜ)ϵ

where:

  • αₜ = 1 - βₜ
  • ᾱₜ = ∏(αᵢ), i=1 to t

Representation in the normal form:

mean => √(ᾱₜ)x₀

variance => from the expression xₜ = √(ᾱₜ)x₀ + √(1-ᾱₜ)ϵ, the variance of xₜ, conditioned on x₀ is:

Var(xₜ|x₀) = Var(√(ᾱₜ)x₀ + √(1-ᾱₜ)ϵ),

since x₀ is deterministic, the variance is contributed solely by the noise term:

Var(xₜ|x₀) = (1-ᾱₜ)Var(ϵ)

Var(xₜ|x₀) = (1-ᾱₜ)I

Finally,

q(xₜ | x₀) = N(xₜ ; √(ᾱₜ)x₀, (1-ᾱₜ)I)

This is the distribution of the data at any time step t, given only the initial sample x₀, in the forward diffusion process.

Code Snippet for the Forward Process:

import torch

def forward(x_0, t, betas=torch.linspace(1e-4, 0.02, 1000)):
    # betas: the variance schedule (here the standard DDPM linear schedule)
    noise = torch.randn_like(x_0)                     # ϵ ~ N(0, I)
    alphas = 1 - betas
    alpha_hat = torch.cumprod(alphas, dim=0)          # ᾱₜ = ∏ αᵢ
    alpha_hat_t = alpha_hat.gather(-1, t).reshape(-1, 1, 1, 1)

    mean = alpha_hat_t.sqrt() * x_0                   # √(ᾱₜ) x₀
    std = torch.sqrt(1 - alpha_hat_t)                 # √(1-ᾱₜ)
    x_t = mean + std * noise                          # sample from q(xₜ | x₀)

    return x_t, noise
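
A quick usage sketch of the function above (the batch shape and time steps are illustrative assumptions, not values from the post): at small t, xₜ stays close to x₀, while at large t the signal coefficient √(ᾱₜ) is nearly zero and xₜ is essentially pure noise.

# Illustrative only: random tensors stand in for a batch of four images.
x_0 = torch.randn(4, 3, 32, 32)          # hypothetical batch, assumed scaled to roughly [-1, 1]
t = torch.tensor([0, 250, 500, 999])     # a different time step for each image

x_t, noise = forward(x_0, t)

# Sanity check: the signal coefficient √(ᾱₜ) decays toward 0 as t grows,
# so samples at large t are dominated by noise.
alpha_hat = torch.cumprod(1 - torch.linspace(1e-4, 0.02, 1000), dim=0)
print(alpha_hat[t].sqrt())               # ≈ 1.0 at t = 0, close to 0 at t = 999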

The Reverse Diffusion Process

To reverse the forward process, we need to remove the added noise step by step. The reverse process attempts to recover the original data by gradually denoising, using a model with parameters θ to predict p(xₜ₋₁ | xₜ). It is modeled as:

[Figure: The Reverse Diffusion Process]

pθ(xₜ₋₁ | xₜ) = N(xₜ₋₁ ; μθ(xₜ, t), σₜ²I)

where:

  • μθ(xₜ, t) is the predicted mean, parameterized by a neural network
  • σₜ² is the variance at time step t
  • θ represents the model parameters

The mean μθ(xₜ, t) can be parameterized as:

μθ(xₜ, t) = 1/√(αₜ) * (xₜ - (1-αₜ)/√(1-ᾱₜ) * ϵθ(xₜ, t))

where:

  • ϵθ(xₜ, t) is the predicted noise at time step t
  • αₜ and ᾱₜ are the same parameters from the forward process
  • The neural network learns to predict the noise component ϵθ(xₜ, t)

However, directly estimating xₜ₋₁ from xₜ is challenging, because the noise added during the forward process introduces ambiguity. To address this, the model conditions on the original data x₀.

[Figure: Given xₜ alone, it is hard to tell which xₜ₋₁ it came from]

Why do we need x₀?

Imagine trying to reconstruct a blurred image. If the image was blurred 100 times, it is nearly impossible to guess what it looked like at step 99 without knowing the original image. Similarly, in the reverse diffusion process, conditioning on x₀ helps the model estimate intermediate states more accurately.

[Figure: Given x₀, it becomes much easier to find xₜ₋₁ from xₜ]

Code Snippet for the Backward Process:

def backward(x_t, t, model, betas=torch.linspace(1e-4, 0.02, 1000)):
    alphas = 1 - betas
    alphas_cumprod = torch.cumprod(alphas, dim=0)

    betas_t = get_index_from_list(betas, t, x_t.shape)
    sqrt_one_minus_alphas_cumprod_t = get_index_from_list(torch.sqrt(1.0 - alphas_cumprod), t, x_t.shape)
    sqrt_recip_alphas_t = get_index_from_list(torch.sqrt(1.0 / alphas), t, x_t.shape)

    # predict the noise component ϵθ(xₜ, t)
    predicted_noise = model(x_t, t)

    # mean of the reverse step: 1/√(αₜ) * (xₜ - βₜ/√(1-ᾱₜ) * ϵθ(xₜ, t))
    mean = sqrt_recip_alphas_t * (x_t - betas_t * predicted_noise / sqrt_one_minus_alphas_cumprod_t)

    # fixed (not learned) variance: σₜ² = βₜ, a common simple choice
    posterior_variance_t = betas_t

    if torch.all(t > 0):
        noise = torch.randn_like(x_t)
        return mean + torch.sqrt(posterior_variance_t) * noise
    else:
        # no noise is added at t = 0
        return mean


def get_index_from_list(values, t, x_shape):
    # Gather values[t] for a batch of time steps and reshape so it broadcasts over x_shape.
    batch_size = t.shape[0]
    result = values.gather(-1, t.cpu())
    return result.reshape(batch_size, *((1,) * (len(x_shape) - 1))).to(t.device)
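
To generate new images, the reverse step is applied repeatedly, starting from pure Gaussian noise. Below is a minimal sampling-loop sketch, assuming the backward function above and an already trained noise-prediction model; the sample name, the image shape, and T are illustrative assumptions rather than code from the repository.

@torch.no_grad()
def sample(model, shape=(1, 3, 32, 32), T=1000, betas=torch.linspace(1e-4, 0.02, 1000)):
    # Start from pure Gaussian noise x_T ~ N(0, I) and denoise step by step.
    x_t = torch.randn(shape)
    for step in reversed(range(T)):
        t = torch.full((shape[0],), step, dtype=torch.long)
        x_t = backward(x_t, t, model, betas=betas)
    return x_t   # approximate sample from the learned data distribution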

Maximizing Likelihood

The primary objective in reverse diffusion is to maximize the likelihood p(x₀) of the model generating the original data, which is equivalent to minimizing the negative log-likelihood -log(p(x₀)).

Using the properties of KL Divergence (which is always non-negative), we have:

-log(p(x₀)) ≤ -log(p(x₀)) + KL(q(x₁:ᴛ | x₀) || p(x₁:ᴛ | x₀))

This inequality forms the basis for training diffusion models: minimizing the right-hand side pushes the learned reverse process to closely match the forward process, reducing the reconstruction error.

Loss Function and Optimization

-log(p(x₀)) ≤ -log(p(x₀)) + KL(q(x₁:ᴛ | x₀) || p(x₁:ᴛ | x₀))   (1)

The right-hand side is commonly called the Variational Lower Bound (VLB) objective: minimizing it is equivalent to maximizing a lower bound on log(p(x₀)).

Taking the last term of equation (1), and leaving the expectation over q implicit here and in the steps below:

KL(q(x₁:ᴛ | x₀) || p(x₁:ᴛ | x₀)) = log(q(x₁:ᴛ | x₀) / p(x₁:ᴛ | x₀))   (2)

p(x₁:ᴛ | x₀) = p(x₀ | x₁:ᴛ) * p(x₁:ᴛ) / p(x₀)   (using Bayes' rule)

p(x₁:ᴛ | x₀) = p(x₀:ᴛ) / p(x₀)

Substituting this back into equation (2):

KL(q(x₁:ᴛ | x₀) || p(x₁:ᴛ | x₀)) = log(q(x₁:ᴛ | x₀) / (p(x₀:ᴛ) / p(x₀)))

KL(q(x₁:ᴛ | x₀) || p(x₁:ᴛ | x₀)) = log(q(x₁:ᴛ | x₀) / p(x₀:ᴛ)) + log(p(x₀))

Substituting this back into equation (1), the -log(p(x₀)) and +log(p(x₀)) terms on the right cancel:

-log(p(x₀)) ≤ log(q(x₁:ᴛ | x₀) / p(x₀:ᴛ))

Expanding q(x₁:ᴛ | x₀) = ∏ₜ₌₁ᵀ q(xₜ | xₜ₋₁) and p(x₀:ᴛ) = p(xᴛ) ∏ₜ₌₁ᵀ p(xₜ₋₁ | xₜ):

-log(p(x₀)) ≤ log[∏ₜ₌₁ᵀ q(xₜ | xₜ₋₁) / (p(xᴛ) ∏ₜ₌₁ᵀ p(xₜ₋₁ | xₜ))]

-log(p(x₀)) ≤ -log(p(xᴛ)) + log[∏ₜ₌₁ᵀ q(xₜ | xₜ₋₁) / ∏ₜ₌₁ᵀ p(xₜ₋₁ | xₜ)]

-log(p(x₀)) ≤ -log(p(xᴛ)) + ∑ₜ₌₂ᵀ log[q(xₜ | xₜ₋₁) / p(xₜ₋₁ | xₜ)] + log(q(x₁|x₀) / p(x₀|x₁))   (3)

By Bayes' rule, q(xₜ | xₜ₋₁) = q(xₜ₋₁ | xₜ) * q(xₜ) / q(xₜ₋₁).

As mentioned earlier, we need x₀ to find xₜ₋₁ given xₜ, so we additionally condition each term on x₀ (valid because the forward process is Markovian):

q(xₜ | xₜ₋₁) = q(xₜ | xₜ₋₁, x₀) = q(xₜ₋₁ | xₜ, x₀) * q(xₜ | x₀) / q(xₜ₋₁ | x₀)

Substituting this back into equation (3):

-log(p(x₀)) ≤ -log(p(xᴛ)) + ∑ₜ₌₂ᵀ log[q(xₜ₋₁ | xₜ, x₀) * q(xₜ | x₀) / (p(xₜ₋₁ | xₜ) * q(xₜ₋₁ | x₀))] + log(q(x₁|x₀) / p(x₀|x₁))

The q(xₜ | x₀) / q(xₜ₋₁ | x₀) factors telescope to q(xᴛ | x₀) / q(x₁ | x₀), and the log(q(x₁|x₀)) terms cancel, leaving:

-log(p(x₀)) ≤ log(q(xᴛ|x₀) / p(xᴛ)) + ∑ₜ₌₂ᵀ log[q(xₜ₋₁ | xₜ, x₀) / p(xₜ₋₁ | xₜ)] - log(p(x₀|x₁))

Taking the (implicit) expectation over q turns each log-ratio into a KL divergence:

-log(p(x₀)) ≤ KL(q(xᴛ|x₀) || p(xᴛ)) + ∑ₜ₌₂ᵀ KL(q(xₜ₋₁ | xₜ, x₀) || p(xₜ₋₁ | xₜ)) - log(p(x₀|x₁))

Simplifying the Loss Function

The loss function is typically separated into three terms:

  • Reconstruction Loss: Measures the error in reconstructing x₀ from x₁: => Eq(x₁:ᴛ|x₀)[−log pθ(x₀|x₁)]
  • KL Divergence Term: Quantifies the discrepancy between the forward-process posterior q and the reverse process pθ at each time step t (see the sketch after this list): => ∑ₜ₌₂ᵀ Eq(x₁:ᴛ|x₀)[KL(q(xₜ₋₁ | xₜ, x₀) || p(xₜ₋₁ | xₜ))]
  • Prior Matching Term: Encourages the final noisy state xᴛ to match the prior distribution p(xᴛ), usually a standard Gaussian: => KL(q(xᴛ|x₀) || p(xᴛ))
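
The KL divergence term compares two Gaussians, so it has a closed form. As a hedged sketch (these are the standard DDPM posterior formulas, quoted rather than derived in this post, and forward_posterior is just an illustrative helper name), q(xₜ₋₁ | xₜ, x₀) is itself Gaussian, and its mean and variance can be computed as:

def forward_posterior(x_0, x_t, t, betas=torch.linspace(1e-4, 0.02, 1000)):
    # Closed-form parameters of q(xₜ₋₁ | xₜ, x₀), the target of the KL term.
    # Standard DDPM expressions (assumption: quoted, not derived above):
    #   mean = √(ᾱₜ₋₁) βₜ / (1-ᾱₜ) * x₀ + √(αₜ) (1-ᾱₜ₋₁) / (1-ᾱₜ) * xₜ
    #   var  = (1-ᾱₜ₋₁) / (1-ᾱₜ) * βₜ
    alphas = 1 - betas
    alpha_bar = torch.cumprod(alphas, dim=0)
    alpha_bar_prev = torch.cat([torch.ones(1), alpha_bar[:-1]])  # ᾱ₀ := 1

    beta_t = get_index_from_list(betas, t, x_t.shape)
    alpha_t = get_index_from_list(alphas, t, x_t.shape)
    alpha_bar_t = get_index_from_list(alpha_bar, t, x_t.shape)
    alpha_bar_prev_t = get_index_from_list(alpha_bar_prev, t, x_t.shape)

    mean = ((alpha_bar_prev_t.sqrt() * beta_t / (1 - alpha_bar_t)) * x_0
            + (alpha_t.sqrt() * (1 - alpha_bar_prev_t) / (1 - alpha_bar_t)) * x_t)
    variance = (1 - alpha_bar_prev_t) / (1 - alpha_bar_t) * beta_t
    return mean, variance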

Loss Function in Practice

To simplify the training process, practitioners usually parameterize the reverse process pθ(xₜ₋₁ | xₜ) with a neural network and use the reparameterization trick to predict the added noise ϵ. The resulting simplified loss is:

L(simple) = Et,x₀,ϵ[||ϵ - ϵθ(xₜ, t)||²]

The nice part is that, after all the long derivations above, the final loss is simply the squared difference between the actual noise added during forward diffusion (ϵ) and the noise predicted by the model at time step t (ϵθ(xₜ, t)).

This loss ensures that the model learns to accurately predict the noise added at each time step, allowing for effective denoising in the reverse process.
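
Below is a minimal training-step sketch of this simplified objective, assuming the forward function defined earlier and a noise-prediction network model(xₜ, t); the train_step name and the optimizer handling are illustrative assumptions rather than code from the repository.

def train_step(model, optimizer, x_0, T=1000):
    # Sample a random time step for each image in the batch.
    t = torch.randint(0, T, (x_0.shape[0],))

    # Forward-diffuse x₀ to xₜ and keep the noise ϵ that was added.
    x_t, noise = forward(x_0, t)

    # L(simple): mean squared error between the true and the predicted noise.
    predicted_noise = model(x_t, t)
    loss = torch.nn.functional.mse_loss(predicted_noise, noise)

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return loss.item()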

This bridges the mathematical underpinnings and the practical implementation of diffusion probabilistic models.

Check out my code implementation:

GitHub repository https://github.com/arjuuuuunnnnn/Diffusion