Imagine dropping ink into water. As time passes, the ink gradually diffuses, spreading throughout the water until it becomes completely dispersed. Now, imagine reversing this process — starting with the dispersed ink and reconstructing the original drop. This is the fundamental intuition behind diffusion probabilistic models.
Formally, diffusion models are a class of generative models that mimic this diffusion-and-reversal process: they take an image, gradually add noise to it until it becomes pure noise, and then learn to reverse the process to reconstruct the original image.
The Forward and Reverse Diffusion Processes
At the heart of diffusion models lie two core processes:
- The forward diffusion process, which gradually corrupts data by adding noise until nothing but noise remains.
- The reverse diffusion process, which learns to remove that noise step by step to recover the data.
Let's dive deeper into both processes, including some of the math behind them.
The Forward Diffusion Process
The forward diffusion process models the gradual corruption of data using a Markov chain, where the probability of the next state depends only on the current state. Mathematically, each step is described as:
q(xₜ | xₜ₋₁) = N(xₜ ; √(1-βₜ)xₜ₋₁, βₜI)
Where:
- xₜ and xₜ₋₁ are the (increasingly noisy) samples at time steps t and t-1
- βₜ ∈ (0, 1) is the variance schedule controlling how much noise is added at step t
- N(x ; μ, Σ) denotes a Gaussian distribution over x with mean μ and covariance Σ, and I is the identity matrix
Starting with a data sample x₀, the process continues until the sample becomes pure noise at time T.
A single noising step can equivalently be written as a sample:
xₜ = √(1-βₜ)xₜ₋₁ + √(βₜ)ϵ, where ϵ ~ N(0, I)
Defining αₜ = 1-βₜ, the update at any arbitrary time step t becomes:
xₜ = √(αₜ)xₜ₋₁ + √(1-αₜ)ϵ
Interesting fact: one can compute the distribution of the data at time step t directly from the original sample x₀, without stepping through every intermediate state (a quick sketch of why follows below):
xₜ = √(ᾱₜ)x₀ + √(1-ᾱₜ)ϵ
where ᾱₜ = ∏ₛ₌₁ᵗ αₛ is the cumulative product of the αₛ values up to step t.
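To see why, here is a quick sketch: substitute the single-step update into itself for two consecutive steps,
xₜ = √(αₜ)xₜ₋₁ + √(1-αₜ)ϵ₁
   = √(αₜ)(√(αₜ₋₁)xₜ₋₂ + √(1-αₜ₋₁)ϵ₂) + √(1-αₜ)ϵ₁
   = √(αₜαₜ₋₁)xₜ₋₂ + √(1-αₜαₜ₋₁)ϵ̄
The last line uses the fact that a sum of independent zero-mean Gaussians is again Gaussian with the variances adding: αₜ(1-αₜ₋₁) + (1-αₜ) = 1-αₜαₜ₋₁. Repeating this substitution all the way down to x₀ produces the ᾱₜ expression above.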
Writing this in Gaussian (normal) form:
mean => √(ᾱₜ)x₀
variance => from the expression xₜ = √(ᾱₜ)x₀ + √(1-ᾱₜ)ϵ, the variance of xₜ conditioned on x₀ is:
Var(xₜ|x₀) = Var(√(ᾱₜ)x₀ + √(1-ᾱₜ)ϵ),
since x₀ is fixed when we condition on it, the variance is contributed solely by the noise term:
Var(xₜ|x₀) = (1-ᾱₜ)Var(ϵ)
Var(xₜ|x₀) = (1-ᾱₜ)I
Finally,
q(xₜ | x₀) = N(xₜ ; √(ᾱₜ)x₀, (1-ᾱₜ)I)
This is the distribution of the data at any time step t given the initial sample x₀ in the forward diffusion process.
Code Snippet for the Forward Process:
import torch

def forward(x0, t, betas=torch.linspace(1e-4, 0.02, 1000)):
    # betas: the noise schedule; a standard DDPM-style linear schedule is used here,
    # and the β values must stay strictly inside (0, 1)
    noise = torch.randn_like(x0)                       # ϵ ~ N(0, I) (Gaussian, not uniform)
    alphas = 1 - betas                                 # αₜ = 1 - βₜ
    alpha_hat = torch.cumprod(alphas, dim=0)           # ᾱₜ = ∏ αₛ
    alpha_hat_t = alpha_hat.gather(-1, t).reshape(-1, 1, 1, 1)
    mean = alpha_hat_t.sqrt() * x0                     # √(ᾱₜ) x₀
    noise_term = torch.sqrt(1 - alpha_hat_t) * noise   # √(1-ᾱₜ) ϵ
    x_t = mean + noise_term                            # a sample from q(xₜ | x₀)
    return x_t, noise
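A quick usage sketch (the image batch and timesteps below are made-up placeholders, assuming inputs of shape (batch, channels, height, width)):
images = torch.randn(8, 3, 32, 32)            # stand-in for a batch of normalized images
t = torch.randint(0, 1000, (8,))              # a random timestep for each image
noisy_images, added_noise = forward(images, t)
print(noisy_images.shape, added_noise.shape)  # both torch.Size([8, 3, 32, 32])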
The Reverse Diffusion Process
To reverse the forward process, we need to remove the noise gradually, step by step. This process attempts to recover the original data by denoising: a model with parameters θ predicts pθ(xₜ₋₁ | xₜ), which is modeled as:
pθ(xₜ₋₁ | xₜ) = N(xₜ₋₁ ; μθ(xₜ, t), σₜ²I)
where:
- μθ(xₜ, t) is the mean of the reverse step, predicted by a neural network with parameters θ
- σₜ² is the variance at step t, which in the basic DDPM setup is fixed (for example to βₜ) rather than learned
The mean μθ(xₜ, t) can be parameterized as:
μθ(xₜ, t) = (1/√(αₜ)) * (xₜ - ((1-αₜ)/√(1-ᾱₜ)) * ϵθ(xₜ, t))
where ϵθ(xₜ, t) is the noise predicted by the network at time step t.
However, directly estimating xₜ₋₁ from xₜ alone is challenging, because the noise added during the forward process introduces ambiguity. To address this, the model conditions on the original data x₀.
It's hard to predict which xₜ₋₁ a given xₜ came from.
Why do we need x₀?
Imagine trying to reconstruct a blurred image. If the image was blurred 100 times, it's nearly impossible to guess what it looked like at step 99 without knowing the original image. Similarly, in the reverse diffusion process, conditioning on x₀ helps the model estimate intermediate states more accurately.
In other words, if we have x₀, it becomes much easier to find xₜ₋₁ given xₜ (the closed-form posterior is shown below).
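For reference, conditioning on x₀ gives this reverse-step posterior a known Gaussian closed form (stated here without re-deriving every step, following the standard DDPM derivation):
q(xₜ₋₁ | xₜ, x₀) = N(xₜ₋₁ ; μ̃ₜ(xₜ, x₀), β̃ₜI)
μ̃ₜ(xₜ, x₀) = (√(ᾱₜ₋₁)βₜ / (1-ᾱₜ)) x₀ + (√(αₜ)(1-ᾱₜ₋₁) / (1-ᾱₜ)) xₜ
β̃ₜ = ((1-ᾱₜ₋₁) / (1-ᾱₜ)) βₜ
The learned reverse model pθ(xₜ₋₁ | xₜ) is trained to match this posterior, which is exactly what the KL terms in the loss derived below measure.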
Code Snippet for the Backward Process:
def backward(x_t, t, model, betas=torch.linspace(1e-4, 0.02, 1000)):
    # betas: the same noise schedule used in the forward process;
    # β values must stay strictly inside (0, 1) so that 1/αₜ is finite
    alphas = 1 - betas
    alphas_cumprod = torch.cumprod(alphas, dim=0)
    betas_t = get_index_from_list(betas, t, x_t.shape)
    sqrt_one_minus_alphas_cumprod_t = get_index_from_list(torch.sqrt(1.0 - alphas_cumprod), t, x_t.shape)
    sqrt_recip_alphas_t = get_index_from_list(torch.sqrt(1.0 / alphas), t, x_t.shape)
    # predict the noise ϵθ(xₜ, t)
    predicted_noise = model(x_t, t)
    # mean of the reverse step: μθ = (1/√(αₜ)) * (xₜ - βₜ/√(1-ᾱₜ) * ϵθ)
    mean = sqrt_recip_alphas_t * (x_t - betas_t * predicted_noise / sqrt_one_minus_alphas_cumprod_t)
    # the variance is not learned; it is fixed to βₜ, the forward-process variance
    posterior_variance_t = betas_t
    if torch.all(t > 0):
        # add fresh noise at every step except the last
        noise = torch.randn_like(x_t)
        return mean + torch.sqrt(posterior_variance_t) * noise
    else:
        # no noise added at t = 0
        return mean
def get_index_from_list(values, t, x_shape):
    # pick out values[t] for each element of the batch and reshape the result
    # so it broadcasts against a tensor of shape x_shape
    batch_size = t.shape[0]
    result = values.gather(-1, t.cpu())
    return result.reshape(batch_size, *((1,) * (len(x_shape) - 1))).to(t.device)
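Putting the pieces together, generation starts from pure noise and repeatedly applies the reverse step. A minimal sampling-loop sketch (model is a hypothetical trained noise-prediction network; the shapes and step count are illustrative):
@torch.no_grad()
def sample(model, shape=(1, 3, 32, 32), T=1000):
    # start from pure Gaussian noise and denoise step by step, from t = T-1 down to 0
    x = torch.randn(shape)
    for step in reversed(range(T)):
        t = torch.full((shape[0],), step, dtype=torch.long)
        x = backward(x, t, model)
    return x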
The primary objective in reverse diffusion is to maximize the likelihood p(x₀) of the model generating the original data, which is equivalent to minimizing the negative log-likelihood -log(p(x₀)).
Using the properties of KL divergence (which is always non-negative), we have:
-log(p(x₀)) ≤ -log(p(x₀)) + KL(q(x₁:ᴛ | x₀) || p(x₁:ᴛ | x₀))    (1)
This inequality forms the basis for training diffusion models to reverse the noise-addition process effectively: minimizing the RHS forces the learned reverse process to closely match the forward process, which in turn minimizes the reconstruction error. (In what follows, the expectation over q is left implicit to keep the notation light.)
This RHS is called the Variational Lower Bound (VLB) objective.
Taking the last term of equation (1):
KL(q(x₁:ᴛ | x₀) || p(x₁:ᴛ | x₀)) = log(q(x₁:ᴛ | x₀) / p(x₁:ᴛ | x₀))    (2)
By Bayes' rule,
p(x₁:ᴛ | x₀) = p(x₀ | x₁:ᴛ) * p(x₁:ᴛ) / p(x₀)
p(x₁:ᴛ | x₀) = p(x₀:ᴛ) / p(x₀)
Substituting this back into equation (2):
KL(q(x₁:ᴛ | x₀) || p(x₁:ᴛ | x₀)) = log(q(x₁:ᴛ | x₀) / (p(x₀:ᴛ) / p(x₀)))
KL(q(x₁:ᴛ | x₀) || p(x₁:ᴛ | x₀)) = log(q(x₁:ᴛ | x₀) / p(x₀:ᴛ)) + log(p(x₀))
Substituting this back into equation (1) and simplifying (the -log(p(x₀)) and +log(p(x₀)) terms on the right-hand side cancel):
-log(p(x₀)) ≤ log(q(x₁:ᴛ | x₀) / p(x₀:ᴛ))
Expanding the forward chain q(x₁:ᴛ | x₀) = ∏ₜ₌₁ᵀ q(xₜ | xₜ₋₁) and the reverse chain p(x₀:ᴛ) = p(xᴛ) ∏ₜ₌₁ᵀ p(xₜ₋₁ | xₜ):
-log(p(x₀)) ≤ log[∏ₜ₌₁ᵀ q(xₜ | xₜ₋₁) / (p(xᴛ) ∏ₜ₌₁ᵀ p(xₜ₋₁ | xₜ))]
-log(p(x₀)) ≤ -log(p(xᴛ)) + log[∏ₜ₌₁ᵀ q(xₜ | xₜ₋₁) / ∏ₜ₌₁ᵀ p(xₜ₋₁ | xₜ)]
Separating out the t = 1 factor from the products:
-log(p(x₀)) ≤ -log(p(xᴛ)) + ∑ₜ₌₂ᵀ log[q(xₜ | xₜ₋₁) / p(xₜ₋₁ | xₜ)] + log(q(x₁ | x₀) / p(x₀ | x₁))    (3)
By Bayes' rule, q(xₜ | xₜ₋₁) = q(xₜ₋₁ | xₜ) * q(xₜ) / q(xₜ₋₁), but as mentioned earlier we need x₀ to make the reverse conditional tractable. Since the forward process is Markov, q(xₜ | xₜ₋₁) = q(xₜ | xₜ₋₁, x₀), and applying Bayes' rule conditioned on x₀ gives:
q(xₜ | xₜ₋₁) = q(xₜ₋₁ | xₜ, x₀) * q(xₜ | x₀) / q(xₜ₋₁ | x₀)
Substituting this back into equation (3):
-log(p(x₀)) ≤ -log(p(xᴛ)) + ∑ₜ₌₂ᵀ log[q(xₜ₋₁ | xₜ, x₀) * q(xₜ | x₀) / (p(xₜ₋₁ | xₜ) * q(xₜ₋₁ | x₀))] + log(q(x₁ | x₀) / p(x₀ | x₁))
The ratio q(xₜ | x₀) / q(xₜ₋₁ | x₀) telescopes across the sum to q(xᴛ | x₀) / q(x₁ | x₀), and the remaining q(x₁ | x₀) cancels against the numerator of the last term, leaving:
-log(p(x₀)) ≤ log(q(xᴛ | x₀) / p(xᴛ)) + ∑ₜ₌₂ᵀ log[q(xₜ₋₁ | xₜ, x₀) / p(xₜ₋₁ | xₜ)] - log(p(x₀ | x₁))
Writing each log-ratio as a KL divergence (taking the expectation over q):
-log(p(x₀)) ≤ KL(q(xᴛ | x₀) || p(xᴛ)) + ∑ₜ₌₂ᵀ KL(q(xₜ₋₁ | xₜ, x₀) || p(xₜ₋₁ | xₜ)) - log(p(x₀ | x₁))
The loss function is typically separated into three terms:
- Lᴛ: the prior matching term KL(q(xᴛ | x₀) || p(xᴛ)), which has no learnable parameters and is effectively constant
- Lₜ₋₁: the denoising matching terms KL(q(xₜ₋₁ | xₜ, x₀) || p(xₜ₋₁ | xₜ)) for t = 2, …, T, which train the reverse model
- L₀: the reconstruction term -log(p(x₀ | x₁))
To simplify the training process, practitioners usually parameterize the reverse process pθ(xₜ₋₁ | xₜ) with a neural network and use the reparameterization trick to predict the added noise ϵ. The resulting simplified loss is:
L(simple) = E[ ||ϵ - ϵθ(xₜ, t)||² ], with the expectation taken over t, x₀, and ϵ
The fun part: after deriving all of the long equations above, the final loss is simply the squared difference between the actual noise added during forward diffusion (ϵ) and the noise predicted by the model at time step t (ϵθ(xₜ, t)).
This loss ensures that the model learns to accurately predict the noise added at each time step, allowing for effective denoising in the reverse process.
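As a rough illustration, a single training step under this simplified loss could look like the sketch below (the model, optimizer, and data batch are hypothetical placeholders; T = 1000 matches the noise schedule used in the snippets above):
def train_step(model, optimizer, x0, T=1000):
    # sample a random timestep per image, noise the batch, and regress the added noise
    t = torch.randint(0, T, (x0.shape[0],))
    x_t, noise = forward(x0, t)                  # forward diffusion from earlier
    predicted_noise = model(x_t, t)              # ϵθ(xₜ, t)
    loss = torch.nn.functional.mse_loss(predicted_noise, noise)  # ||ϵ - ϵθ||²
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return loss.item()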
This bridges the mathematical underpinnings and the practical implementation of diffusion probabilistic models.
Check out my code implementation:
GitHub repository https://github.com/arjuuuuunnnnn/Diffusion