Denoising Diffusion Implicit Models
Problems with DDPMs:
Ideas:
Define a family \(\mathcal{Q}\) of (inference) distributions, indexed by a vector \(\sigma\in\mathbb{R}^T_{\ge0}\):
\[q_\sigma(x_{1:T}\vert x_0):=q_\sigma(x_T\vert x_0)\prod_{t=2}^T q_\sigma(x_{t-1}\vert x_t,x_0)\]where
\[\begin{aligned} q_\sigma(x_T\vert x_0)&=\mathcal{N}(\sqrt{\alpha_T}x_0,(1-\alpha_T)I),\\ q_\sigma(x_{t-1}\vert x_t, x_0)&=\mathcal{N}(\sqrt{\alpha_{t-1}}x_0+\sqrt{1-\alpha_{t-1}-\sigma_t^2}\cdot\frac{x_t-\sqrt{\alpha_t}x_0}{\sqrt{1-\alpha_t}},\sigma_t^2I)\quad\text{for all } t>1. \end{aligned}\]Remarks:
Define a trainable generative process
\[p_\theta(x_{0:T}):=p_\theta(x_T)\prod_{t=1}^Tp_\theta^{(t)}(x_{t-1}\vert x_t),\]where each \(p_\theta^{(t)}(x_{t-1}\vert x_t)\) leverages knowledge of \(q_\sigma(x_{t-1}\vert x_t, x_0)\).
Make a prediction of the corresponding \(X_0\): the model \(\epsilon_\theta^{(t)}(x_t)\) predicts \(\epsilon_t\) from \(X_t\), without knowing \(X_0\). From it we can form the denoised observation, which is a prediction of \(X_0\) given \(X_t\):
\[f_\theta^{(t)}(x_t):=\frac{1}{\sqrt{\alpha_t}}(x_t-\sqrt{1-\alpha_t}\epsilon_\theta^{(t)}(x_t)).\]Use the prediction to sample \(X_{t-1}\) from the reverse conditional distribution \(q_\sigma(x_{t-1}\vert x_t,x_0)\): we can define the generative process with a fixed prior \(p_\theta(x_T)=\mathcal{N}(0,I)\) and
\[p_\theta^{(t)}(x_{t-1}\vert x_t)=\left\{ \begin{aligned} &\mathcal{N}(f_\theta^{(1)}(x_1),\sigma_1^2 I)\quad&\text{if }t=1,\\ &q_\sigma(x_{t-1}\vert x_t,f_\theta^{(t)}(x_t))\quad&\text{if }t>1, \end{aligned} \right.\]where Gaussian noise is added in the \(t=1\) case to ensure that the generative process is supported everywhere.
Remarks: This generative process is essentially the same as DDPM's, with some minor differences.
The parameters \(\theta\) are optimized via the variational inference objective
\[J_\sigma(\epsilon_\theta):=\mathbb{E}_{q_\sigma}[\log q_\sigma(X_{1:T}\vert X_0)-\log p_\theta(X_{0:T})]\]In comparison, DDPM optimizes the following objective:
\[L_\gamma(\epsilon_\theta):=\sum_{t=1}^T\gamma_t \mathbb{E}_{X_0,\epsilon_t}\left[\|\epsilon_\theta^{(t)}(\sqrt{\alpha_t}X_0+\sqrt{1-\alpha_t}\epsilon_t)-\epsilon_t\|_2^2\right],\]where \(\gamma\in\mathbb{R}^T_{>0}\) is a vector of positive coefficients that depends on \(\alpha_{1:T}\). In DDPM, the objective with \(\gamma=\mathbf{1}\) is optimized instead to maximize the generation performance of the trained model.
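As a concrete reference point, below is a minimal sketch of the \(\gamma=\mathbf{1}\) ("simple") objective. The names `eps_model` (the noise-prediction network, called as `eps_model(x_t, t)`) and `alphas` (a tensor holding \(\alpha_t\) in this note's notation) are assumptions for illustration, not part of the original text.

```python
import torch

def simple_loss(eps_model, x0, alphas):
    """L_gamma with gamma = 1: predict the noise injected at a random timestep.

    eps_model(x_t, t) is a hypothetical noise-prediction network;
    alphas[t] plays the role of alpha_t in this note (a cumulative product in DDPM notation).
    """
    b = x0.shape[0]
    t = torch.randint(0, len(alphas), (b,), device=x0.device)           # one random timestep per sample
    a_t = alphas.to(x0.device)[t].view(-1, *([1] * (x0.dim() - 1)))     # broadcast alpha_t over data dims
    eps = torch.randn_like(x0)                                          # epsilon_t ~ N(0, I)
    x_t = a_t.sqrt() * x0 + (1 - a_t).sqrt() * eps                      # sample from q(x_t | x_0)
    return ((eps_model(x_t, t) - eps) ** 2).mean()                      # ||eps_theta(x_t) - eps||^2, gamma = 1
```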
Theorem 1. For all \(\sigma\in\mathbb{R}^T_{>0}\), there exists \(\gamma\in\mathbb{R}^T_{>0}\) and \(C\in\mathbb{R}\), such that \(J_\sigma=L_\gamma+C\). (see Appendix for the proof)
Discussion:
With \(L_1\) as the objective (\(\sigma\) does not appear in the loss), we are not only learning a generative process for the Markovian inference process considered in DDPM, but also generative processes for the many non-Markovian forward processes parametrized by \(\sigma\) described above.
We can therefore reuse pre-trained DDPM models as solutions to the new objectives, and focus on finding, by changing \(\sigma\), a generative process that better suits our needs for producing samples.
Generate a sample \(x_{t-1}\) from a sample \(x_t\):
\[x_{t-1}= \sqrt{\alpha_{t-1}}\underbrace{\left(\frac{x_t-\sqrt{1-\alpha_t}\epsilon_\theta^{(t)}(x_t)}{\sqrt{\alpha_t}}\right)}_{\text{predicted }x_0} + \underbrace{\sqrt{1-\alpha_{t-1}-\sigma_t^2}\cdot\epsilon_\theta^{(t)}(x_t)}_{\text{direction pointing to }x_t} + \underbrace{\sigma_t\epsilon_t}_\text{random noise},\]where \(\epsilon_t\sim\mathcal{N}(0, I)\). Different choices of \(\sigma\) result in different generative processes, all while using the same model \(\epsilon_\theta\), so re-training the model is unnecessary.
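A minimal sketch of one step of this update, under the same assumed names as before (`eps_model`, `alphas`): setting `sigma_t = 0` gives the deterministic DDIM step, while the DDPM choice of \(\sigma_t\) recovers the original stochastic sampler.

```python
import torch

def generalized_step(eps_model, x_t, t, t_prev, alphas, sigma_t):
    """x_t -> x_{t-1} via predicted x_0, the 'direction pointing to x_t', and optional noise."""
    a_t, a_prev = alphas[t], alphas[t_prev]
    eps = eps_model(x_t, t)                                   # epsilon_theta^(t)(x_t)
    x0_pred = (x_t - (1 - a_t).sqrt() * eps) / a_t.sqrt()     # predicted x_0
    dir_xt = (1 - a_prev - sigma_t ** 2).sqrt() * eps         # direction pointing to x_t
    noise = sigma_t * torch.randn_like(x_t) if sigma_t > 0 else 0.0
    return a_prev.sqrt() * x0_pred + dir_xt + noise
```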
The generative process is considered an approximation to the reverse process, and therefore they should have the same number of time steps \(T\).
However, as \(L_1\) does not depend on the specific forward procedure as long as \(q_\sigma(x_t\vert x_0)\) is fixed, we may also consider forward processes with lengths smaller than \(T\), which accelerates the corresponding generative processes without having to train a different model.
Details can be found in the Appendix
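Below is a sketch of such an accelerated sampler, again with the hypothetical `eps_model` and `alphas`. The \(\eta\) knob used in the experiments below is assumed here to scale the DDPM-style value of \(\sigma\), interpolating between the deterministic DDIM sampler (\(\eta=0\)) and DDPM-like stochasticity (\(\eta=1\)); this parameterization is an assumption of the sketch.

```python
import torch

@torch.no_grad()
def ddim_sample(eps_model, shape, alphas, tau, eta=0.0):
    """Sample with S = len(tau) model evaluations instead of T.

    tau is an increasing subsequence of [1, ..., T] (indices into alphas).
    eta = 0 gives the deterministic DDIM sampler; eta = 1 uses a DDPM-style sigma.
    """
    x = torch.randn(shape)                                    # x_T ~ N(0, I)
    steps = list(reversed(tau))
    for i, t in enumerate(steps):
        a_t = alphas[t]
        a_prev = alphas[steps[i + 1]] if i + 1 < len(steps) else torch.tensor(1.0)
        eps = eps_model(x, t)
        x0_pred = (x - (1 - a_t).sqrt() * eps) / a_t.sqrt()   # predicted x_0
        # sigma interpolates between 0 (DDIM) and the DDPM choice as eta goes 0 -> 1
        sigma = eta * ((1 - a_prev) / (1 - a_t)).sqrt() * (1 - a_t / a_prev).sqrt()
        x = a_prev.sqrt() * x0_pred + (1 - a_prev - sigma ** 2).sqrt() * eps
        if eta > 0:
            x = x + sigma * torch.randn_like(x)
    return x
```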
Insight:
The DDIM iterate (i.e., \(\sigma_t=0\)):
\[x_{t-1}= \sqrt{\alpha_{t-1}}\left(\frac{x_t-\sqrt{1-\alpha_t}\epsilon_\theta^{(t)}(x_t)}{\sqrt{\alpha_t}}\right) + \sqrt{1-\alpha_{t-1}}\cdot\epsilon_\theta^{(t)}(x_t)\]can be rewritten as
\[\frac{x_{t-\Delta t}}{\sqrt{\alpha_{t-\Delta t}}}=\frac{x_t}{\sqrt{\alpha_t}}+\left(\sqrt{\frac{1-\alpha_{t-\Delta t}}{\alpha_{t-\Delta t}}}-\sqrt{\frac{1-\alpha_t}{\alpha_t}}\right)\epsilon_\theta^{(t)}(x_t)\]We can reparameterize \(\sqrt{(1-\alpha)/\alpha}\) as \(\omega\) and \(x/\sqrt{\alpha}\) as \(\bar{x}\). When \(\Delta t\rightarrow 0\), \(\omega\) and \(\bar{x}\) become functions of \(t\), where \(\omega\) is continuous and increasing with \(\omega(0)=0\). The above iteration can then be treated as an Euler method for the following ODE:
\[\text{d}\bar{x}(t)=\epsilon^{(t)}_\theta\left(\frac{\bar{x}(t)}{\sqrt{\omega^2+1}}\right)\text{d}\omega(t),\]where the initial condition is \(\bar{x}(T)=x(T)/\sqrt{\alpha(T)}\sim\mathcal{N}(0,(1/\alpha(T))I)\). Since \(\alpha(T)\approx 0\), the variance \(1/\alpha(T)\) is very large.
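To make the correspondence concrete, here is a sketch of one Euler step written in the \((\omega,\bar x)\) variables (same assumed names as above); it is algebraically identical to the \(\sigma_t=0\) update.

```python
def euler_step(eps_model, x_t, t, t_prev, alphas):
    """One Euler step of d x_bar = eps_theta(x_bar / sqrt(omega^2 + 1)) d omega."""
    a_t, a_prev = alphas[t], alphas[t_prev]
    omega_t = ((1 - a_t) / a_t).sqrt()             # omega(t) = sqrt((1 - alpha_t) / alpha_t)
    omega_prev = ((1 - a_prev) / a_prev).sqrt()    # omega(t - dt)
    x_bar = x_t / a_t.sqrt()                       # x_bar(t) = x_t / sqrt(alpha_t)
    # note: x_bar / sqrt(omega^2 + 1) = x_t, so the model is evaluated at x_t itself
    x_bar_prev = x_bar + (omega_prev - omega_t) * eps_model(x_t, t)
    return a_prev.sqrt() * x_bar_prev              # map back: x_{t-dt} = sqrt(alpha_{t-dt}) * x_bar(t-dt)
```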
Key results:
Setup:
Vary the number of timesteps used to generate a sample (\(S=\text{dim}(\tau)\)) and the stochasticity of the process (\(\eta\)), and examine the tradeoff between sample quality and computational cost.
Results:
For DDIM, the generative process is deterministic, and \(x_0\) would depend only on the initial state \(x_T\).
Compare generated images under different generative trajectories (i.e., different \(\tau\)) while starting from the same initial \(x_T\).
Results:
Since the high-level features of the DDIM sample are encoded by \(x_T\), it might be used for semantic interpolation.
This is different from the interpolation procedure in DDPM, where the same \(x_T\) would lead to highly diverse \(x_0\) due to the stochastic generative process.
DDIM is able to control the generated images on a high level directly through the latent variables, which DDPMs cannot.
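A sketch of such interpolation (the helper names are hypothetical): interpolate between two latents \(x_T\) and decode each interpolant with the deterministic (\(\eta=0\)) sampler. Spherical interpolation is a common choice for Gaussian latents and is assumed here.

```python
import torch

def slerp(z0, z1, w):
    """Spherical interpolation between two latents (a common choice for Gaussian noise)."""
    theta = torch.acos((z0 * z1).sum() / (z0.norm() * z1.norm()))
    return (torch.sin((1 - w) * theta) * z0 + torch.sin(w * theta) * z1) / torch.sin(theta)

def interpolate(decode, z0, z1, n=8):
    """decode(x_T) is any deterministic DDIM decoder (eta = 0), e.g. the sampler
    sketched earlier, started from a fixed x_T instead of fresh noise."""
    return [decode(slerp(z0, z1, w)) for w in torch.linspace(0.0, 1.0, n)]
```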
As DDIM is the Euler integration for a particular ODE, it should be able to encode from \(x_0\) to \(x_T\) (reverse of the ODE) and reconstruct \(x_0\) from the resulting \(x_T\) (forward of the ODE).
Results: DDIMs have lower reconstruction error for larger \(S\) and have properties similar to Neural ODEs and normalizing flows. The same cannot be said for DDPMs due to their stochastic nature.
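A sketch of the encoding direction under the same assumptions as before: run the deterministic (\(\sigma=0\)) update with \(t\) increasing; decoding the resulting \(x_T\) with the \(\eta=0\) sampler should approximately reconstruct \(x_0\), more accurately for larger \(S\).

```python
import torch

@torch.no_grad()
def ddim_encode(eps_model, x0, alphas, tau):
    """Map x_0 to a latent x_T by running the sigma = 0 update in the increasing-t direction."""
    x, a_cur = x0, torch.tensor(1.0)               # start at t = 0, where alpha = 1 (no noise)
    for t in tau:                                  # tau increasing, tau[-1] = T
        a_next = alphas[t]
        eps = eps_model(x, t)                      # noise prediction at the current state
        x0_pred = (x - (1 - a_cur).sqrt() * eps) / a_cur.sqrt()
        x = a_next.sqrt() * x0_pred + (1 - a_next).sqrt() * eps
        a_cur = a_next
    return x                                       # approximate latent x_T encoding x_0
```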
| meaning | DDPM | DDIM (this work) |
| --- | --- | --- |
| diffusion rate | \(\beta_t\) | \(1-\alpha_t/\alpha_{t-1}\) |
| 1 - diffusion rate | \(\alpha_t\) | \(\alpha_t/\alpha_{t-1}\) |
| product of (1 - diffusion rate) | \(\overline{\alpha}_t\) | \(\alpha_t\) |
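In code, converting a DDPM schedule into this note's notation is a one-liner; the linear \(\beta\) schedule below is only an example value, not prescribed by the text.

```python
import torch

# Per the table: this note's alpha_t is DDPM's cumulative product alpha_bar_t.
betas = torch.linspace(1e-4, 2e-2, 1000)       # example DDPM beta schedule (diffusion rates)
alphas = torch.cumprod(1.0 - betas, dim=0)     # DDIM alpha_t = prod_s (1 - beta_s) = DDPM alpha_bar_t
```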
The materials in this section are from Pattern Recognition and Machine Learning (Bishop, 2006) Section 2.3.3.
Given a marginal Gaussian distribution for \(x\) and a conditional Gaussian distribution for \(y\) given \(x\) in the form
\[\begin{aligned} p(x)&=\mathcal{N}(x;\mu,\Lambda^{-1})\\ p(y\vert x)&=\mathcal{N}(y; Ax+b,L^{-1}) \\ \end{aligned}\]The marginal distribution of \(y\) and the conditional distribution of \(x\) given \(y\) are given by
\[\begin{aligned} p(y) &= \mathcal{N}(y; A\mu+b, L^{-1}+A\Lambda^{-1}A^T)\\ p(x|y) &= \mathcal{N}(x; \Sigma\{A^TL(y-b)+\Lambda\mu\}, \Sigma) \end{aligned}\]where \(\Sigma=(\Lambda + A^TLA)^{-1}\).
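A quick Monte Carlo sanity check of the marginal formula for \(p(y)\); all matrices and vectors below are arbitrary made-up values used only for illustration.

```python
import numpy as np

rng = np.random.default_rng(0)

# Made-up parameters for p(x) = N(mu, Lambda^{-1}) and p(y|x) = N(Ax + b, L^{-1})
mu = np.array([1.0, -2.0])
Lam = np.array([[2.0, 0.3], [0.3, 1.0]])     # precision matrix of x
A = np.array([[1.0, 0.5], [-0.2, 2.0]])
b = np.array([0.5, 0.0])
L = np.array([[4.0, 0.0], [0.0, 3.0]])       # precision matrix of y given x

# Draw x ~ p(x), then y = A x + b + noise with noise ~ N(0, L^{-1})
x = rng.multivariate_normal(mu, np.linalg.inv(Lam), size=200_000)
y = x @ A.T + b + rng.multivariate_normal(np.zeros(2), np.linalg.inv(L), size=len(x))

print(y.mean(axis=0), A @ mu + b)                                     # empirical vs. predicted mean
print(np.cov(y.T), np.linalg.inv(L) + A @ np.linalg.inv(Lam) @ A.T)   # empirical vs. predicted covariance
```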
The core of the inference distribution \(q_\sigma\) is the conditional distribution of \(X_{t-1}\) given \(X_t\) and \(X_0\), i.e.,
\[q_\sigma(x_{t-1}\vert x_t, x_0)=\mathcal{N}(\tilde{\mu}_t(x_t,x_0),\sigma_t^2I),\]where \(\tilde\mu_t\) is the mean function. Assume it takes a linear form, i.e., \(\tilde\mu_t(x_t, x_0)=ax_t+bx_0\), where \(a\) and \(b\) are constants to be determined. We want the proposed joint distribution to match the "marginals" of the original DM. Specifically, given \(q_\sigma(x_t\vert x_0)=\mathcal{N}(\sqrt{\alpha_t} x_0, (1-\alpha_t)I)\), we want \(q_\sigma(x_{t-1}\vert x_0)=\mathcal{N}(\sqrt{\alpha_{t-1}} x_0, (1-\alpha_{t-1})I)\). We can compute \(q_\sigma(x_{t-1}\vert x_0)\) from \(q_\sigma(x_{t}\vert x_0)\) and \(q_\sigma(x_{t-1}\vert x_t,x_0)\) using the Gaussian identities above.
\[q_\sigma(x_{t-1}\vert x_0)=\mathcal{N}(a\sqrt{\alpha_t}x_0+bx_0,[\sigma_t^2+(1-\alpha_{t})a^2]I)\]We solve the following equations to match the mean and the variance: \(\begin{aligned} a\sqrt{\alpha_t}+b&=\sqrt{\alpha_{t-1}}\\ \sigma_t^2+(1-\alpha_{t})a^2&=1-\alpha_{t-1} \end{aligned}\)
which gives
\[\begin{aligned} a&=\frac{\sqrt{1-\alpha_{t-1}-\sigma_t^2}}{\sqrt{1-\alpha_{t}}}\\ b&=\sqrt{\alpha_{t-1}}-\frac{\sqrt{\alpha_t}\sqrt{1-\alpha_{t-1}-\sigma_t^2}}{\sqrt{1-\alpha_{t}}} \end{aligned}\]Therefore,
\[\tilde\mu_t=a x_t+bx_0=\sqrt{\alpha_{t-1}}x_0+\sqrt{1-\alpha_{t-1}-\sigma_t^2}\cdot\frac{x_t-\sqrt{\alpha_t}x_0}{\sqrt{1-\alpha_{t}}}\]and
\[q_\sigma(x_{t-1}\vert x_t, x_0)=\mathcal{N}(\sqrt{\alpha_{t-1}}x_0+\sqrt{1-\alpha_{t-1}-\sigma_t^2}\cdot\frac{x_t-\sqrt{\alpha_t}x_0}{\sqrt{1-\alpha_t}},\sigma_t^2I)\]In comparison, DDPM uses different mean and variance for \(q(x_{t-1}\vert x_t, x_0)\):
\[\begin{aligned} q(x_{t-1}\vert x_t, x_0)&=\mathcal{N}\left(\frac{\sqrt{\alpha_{t-1}}}{1-\alpha_{t}}\left(1-\frac{\alpha_t}{\alpha_{t-1}}\right)x_0 + \frac{\sqrt{\alpha_t}(1-\alpha_{t-1})}{\sqrt{\alpha_{t-1}}(1-\alpha_t)}x_t, \frac{1-\alpha_{t-1}}{1-\alpha_t}(1-\frac{\alpha_t}{\alpha_{t-1}})I\right) \\ &=\mathcal{N}\left(\sqrt{\alpha_{t-1}}x_0+\frac{\sqrt{\alpha_t}(1-\alpha_{t-1})}{\sqrt{\alpha_{t-1}}\sqrt{1-\alpha_t}}\cdot\frac{x_t-\sqrt{\alpha_t}x_0}{\sqrt{1-\alpha_t}},\frac{1-\alpha_{t-1}}{1-\alpha_t}(1-\frac{\alpha_t}{\alpha_{t-1}})I \right) \end{aligned}\]If we set
\[\sigma_t^2=\frac{1-\alpha_{t-1}}{1-\alpha_t}\left(1-\frac{\alpha_t}{\alpha_{t-1}}\right),\]then \(q_\sigma(x_{t-1}\vert x_t, x_0)=q(x_{t-1}\vert x_t, x_0)\) and the model becomes DDPM.
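The algebra above is easy to verify symbolically; a small sympy check (a sketch, not part of the original text) confirms that the derived \(a\) and \(b\) satisfy both matching conditions.

```python
import sympy as sp

a_t, a_prev, sig = sp.symbols('alpha_t alpha_prev sigma_t', positive=True)

# a and b as derived above
a = sp.sqrt(1 - a_prev - sig**2) / sp.sqrt(1 - a_t)
b = sp.sqrt(a_prev) - sp.sqrt(a_t) * sp.sqrt(1 - a_prev - sig**2) / sp.sqrt(1 - a_t)

# Both matching conditions should simplify to 0
print(sp.simplify(a * sp.sqrt(a_t) + b - sp.sqrt(a_prev)))        # mean condition
print(sp.simplify(sig**2 + (1 - a_t) * a**2 - (1 - a_prev)))      # variance condition
```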
For all \(\sigma\in\mathbb{R}^T_{>0}\), there exists \(\gamma\in\mathbb{R}^T_{>0}\) and \(C\in\mathbb{R}\), such that \(J_\sigma=L_\gamma+C\).
Proof:
We follow the derivation in DDPM (where \(\equiv\) denotes "equal up to a value that does not depend on \(\theta\), but may depend on \(q_\sigma\)"):
\[\begin{aligned} J_\sigma(\epsilon_\theta)&:=\mathbb{E}_{q_\sigma}[\log q_\sigma(X_{1:T}\vert X_0)-\log p_\theta(X_{0:T})] \\ &\equiv\mathbb{E}_{q_\sigma}\left[\sum_{t=2}^T D_\text{KL}(q_\sigma(x_{t-1}|X_t,X_0)||p_\theta^{(t)}(x_{t-1}|X_t)) -\log p_\theta^{(1)}(X_0|X_1)\right]\\ \end{aligned}\]For \(t>1\):
\[\begin{aligned} \mathbb{E}_{q_\sigma}\left[ D_\text{KL}(q_\sigma(x_{t-1}|X_t,X_0)||p_\theta^{(t)}(x_{t-1}|X_t))\right]&=\mathbb{E}_{X_0,X_t}\left[ D_\text{KL}(q_\sigma(x_{t-1}|X_t,X_0)||q_\sigma(x_{t-1}|X_t,f_\theta^{(t)}(X_t)))\right]\\ &\equiv\mathbb{E}_{X_0,X_t}\left[\frac{\|\tilde{\mu}_t(X_t,X_0)-\tilde{\mu}_t(X_t,f_\theta^{(t)}(X_t))\|_2^2}{2\sigma_t^2}\right]\\ &=\mathbb{E}_{X_0,X_t}\left[\frac{b_t^2}{2\sigma_t^2}\|X_0-f_\theta^{(t)}(X_t)\|_2^2\right]\\ &=\mathbb{E}_{X_0,\epsilon}\left[\frac{b_t^2(1-\alpha_t)}{2\sigma_t^2\alpha_t}\|\epsilon-\epsilon_\theta^{(t)}(X_t)\|_2^2\right]\\ \end{aligned}\]For \(t=1\):
\[\begin{aligned} \mathbb{E}_{q_\sigma}\left[ -\log p_\theta^{(1)}(X_0|X_1)\right]&\equiv\mathbb{E}_{X_0,X_1}\left[\frac{1}{2\sigma_1^2}\|X_0-f_\theta^{(1)}(X_1)\|_2^2\right]\\ &=\mathbb{E}_{X_0,\epsilon}\left[\frac{(1-\alpha_1)}{2\sigma_1^2\alpha_1}\|\epsilon-\epsilon_\theta^{(1)}(X_1)\|_2^2\right]\\ \end{aligned}\]Choosing \(\gamma_1=(1-\alpha_1)/(2\sigma_1^2\alpha_1)\) and \(\gamma_t=(1-\alpha_t)b_t^2/(2\sigma_t^2\alpha_t)\) for \(t>1\), we have \(J_\sigma(\epsilon_\theta)\equiv L_\gamma(\epsilon_\theta)\).
The inference process in the accelerated case is given by
\[q_{\sigma,\tau}(x_{1:T}\vert x_0)=q_{\sigma, \tau}(x_{\tau_S}\vert x_0)\prod_{i=2}^S q_{\sigma, \tau}(x_{\tau_{i-1}}\vert x_{\tau_i}, x_0)\prod_{t\in\overline\tau}q_{\sigma, \tau}(x_t|x_0),\]where \(\tau\) is a sub-sequence of \([1,\dots, T]\) of length \(S\) with \(\tau_S=T\), and \(\overline\tau:=\{1,\dots, T\}\backslash \tau\). In other words, in the graphical model \(\{X_{\tau_i}\}_{i=1}^S\) and \(X_0\) form a chain, whereas \(\{X_t\}_{t\in\overline\tau}\) and \(X_0\) form a star graph.
Define:
\[\begin{aligned} q_{\sigma,\tau}(x_t\vert x_0)&=\mathcal{N}(\sqrt{\alpha_t}x_0,(1-\alpha_t)I)\quad\forall t\in\overline\tau\cup\{T\}\\ q_{\sigma,\tau}(x_{\tau_{i-1}}\vert x_{\tau_i},x_0)&=\mathcal{N}(\sqrt{\alpha_{\tau_{i-1}}}x_0+\sqrt{1-\alpha_{\tau_{i-1}}-\sigma_{\tau_i}^2}\cdot\frac{x_{\tau_i}-\sqrt{\alpha_{\tau_i}}x_0}{\sqrt{1-\alpha_{\tau_i}}},\sigma_{\tau_i}^2I),\quad 2\le i\le S \end{aligned}\]where the coefficients are chosen such that:
\[q_{\sigma,\tau}(x_{\tau_i}|x_0)=\mathcal{N}(\sqrt{\alpha_{\tau_i}}x_0,(1-\alpha_{\tau_i})I)\quad 1\le i\le S,\]i.e., the “marginals” match.
The corresponding “generative process” is defined as:
\[p_\theta(x_{0:T}):= \underbrace{p_\theta(x_T)\prod_{i=1}^Sp_{\theta}^{(\tau_i)}(x_{\tau_{i-1}}\vert x_{\tau_i})}_\text{used to produce samples} \times \underbrace{\prod_{t\in\overline\tau}p_\theta^{(t)}(x_0\vert x_t)}_\text{used in the objective},\]where only part of the model is actually used to produce samples (define \(\tau_0:=0\)). The conditionals are:
\[\begin{aligned} p_\theta^{(\tau_{i})}(x_{\tau_{i-1}}\vert x_{\tau_i})&=q_{\sigma,\tau}(x_{\tau_{i-1}}\vert x_{\tau_i}, f_\theta^{(\tau_i)}(x_{\tau_i}))\quad\text{if }i\in\{2,\dots, S\},\\ p_\theta^{(t)}(x_0\vert x_t)&=\mathcal{N}(f_\theta^{(t)}(x_{t}),\sigma_{t}^2I)\quad\text{if }t\in\overline\tau\cup\{\tau_1\}, \end{aligned}\]which leverages \(q_{\sigma,\tau}(x_{\tau_{i-1}}\vert x_{\tau_i}, x_0)\) from the inference process.
The resulting variational objective becomes
\[\begin{aligned} J_{\sigma,\tau}(\epsilon_\theta)&=\mathbb{E}_{q_{\sigma,\tau}}[\log q_{\sigma,\tau}(X_{1:T}\vert X_0)-\log p_\theta(X_{0:T})]\\ &\equiv\mathbb{E}_{q_{\sigma,\tau}}\left[\sum_{i=2}^S D_\text{KL}(q_{\sigma,\tau}(x_{\tau_{i-1}}|X_{\tau_i},X_0)||p_\theta^{(\tau_i)}(x_{\tau_{i-1}}|X_{\tau_i})) -\log p_\theta^{(\tau_1)}(X_0|X_{\tau_1})\right.\\ &\qquad+\left. \sum_{t\in\overline\tau} -\log p_\theta^{(t)}(X_0|X_{t}) \right]\\ \end{aligned}\]A similar argument to the proof of Theorem 1 shows that \(J_{\sigma,\tau}\) can also be converted to an objective of the form \(L_\gamma\).