\[ \def \vec#1{{\boldsymbol{#1}}} \def \mat#1{{\mathbf{#1}}} \def \argmax#1{\underset{#1}{\operatorname{argmax}}} \def \argmin#1{\underset{#1}{\operatorname{argmin}}} \]
\[ \def \vx{{\vec x}} \def \vtheta{{\vec \theta}} \def \mI{{\mat I}} \def \mZero{{\mat 0}} \def \mSigma{{\mat \Sigma}} \def \E{{\mathbb E}} \]
1 前言
擴散模型的出現讓我眼前一亮。它的原理與GAN截然不同,生成質量的天花板似乎更高,而且背後的原理啃起來也很費勁。
這次分享下我閲讀Understanding Diffusion Models: A Unified Perspective這篇文章的筆記。這篇文章嘗試從一個統一的視角分析擴散模型,介紹了擴散模型、VAE、基於分數的生成模型之間的聯繫。文章的數學推理很詳細,基本沒有跳步,很適合數學基礎一般的同學。
本文基本上是原論文的大致翻譯和概括。筆記難免存在一些信息上的簡略,感興趣的讀者可以閱讀原文。
2 背景
2.1 什麽是生成式模型
給定一個樣本\(\vec x\),生成式模型的目標是學習樣本的真實分佈\(p(\vec x)\).
這裏\(\vec x\)可以是圖像、語音、文本等。
一旦\(p(\vec x)\)被學習,我們就可以用\(p(\vec x)\)來生成新的樣本。
2.2 生成式模型的分類
- GAN(Generative Adversarial Networks)模型通過對抗的方式,使生成模型產生與真實樣本難以區分的樣本。GAN屬於隱式生成模型(implicit generative model),不直接建模數據的概率分布。
- 基於似然度的模型(likelihood-based model),通過最大化似然度來學習樣本的分佈。屬於顯式生成模型(explicit generative model)。常見方法包括:自回歸模型、VAE(Variational Autoencoder)模型、normalizing flow等。
- energy-based model,EBM,基於能量的模型。EBM定義一個能量函數,輸入為樣本\(\vec x\),輸出為一個標量能量值。模型的目標是讓真實樣本的能量值低、假樣本(或無關樣本)的能量值高。
- score-based generative model 使用神經網絡模型學習energy-based model的分數。
diffusion model既可以視爲likelihood-based,也可以視爲score-based。
2.3 柏拉圖的洞窟寓言
在生成式模型中,常常認爲\(\vec x\)是從隱變量\(\vec z\)生成的。
柏拉圖洞窟寓言中,一些奴隸被用鎖鏈拘束,終身囚禁于一個山洞中,面朝洞内。他們只能看見自己在洞壁的影子。那麽他們看到的二維影像(\(\vec x\)),就是從他們無法看見的三維事物(\(\vec z\))生成的。
類似的,在生成式模型中,樣本可能是從某個高級表示產生的。奴隸們雖然只能看見影子,但他們可以努力推理三維空間可能的樣子。我們也可以嘗試近似出我們觀測樣本的高級表示。
但是這個類比不恰當的地方在於,生成式模型通常從低維預測高維,學習高維樣本的低維表示(即一種壓縮),與洞窟寓言是相反的。這是因爲如果沒有很强的先驗,那麽從低维樣本學習高維表示將是徒勞的。
3 證據下界
樣本和隱變量構成聯合分佈\(p(\vec x, \vec z)\)。在基於似然度的方法中,我們希望用模型最大化所有觀測樣本\(\vec x\)的似然度\(p(\vec x)\)。
我們可以利用證據下界(Evidence Lower Bound,ELBO): \[ \mathbb E_{q_\phi(\vec z|\vec x)}\left[\log \frac{p(\vec x, \vec z)}{q_\phi(\vec z|\vec x)}\right] \tag{1}\] 我們稱\(\log p(\vec x)\)為證據(的大小),證據下界和證據的關係是: \[ \log p(\vec x) \geq \mathbb E _{q_\phi(\vec z|\vec x)} \left[\log \frac{p(\vec x, \vec z)}{q_\phi(\vec z| \vec x)}\right] \] 其中\(q_\phi(\vec z| \vec x)\)就是需要我們優化的將觀測樣本\(\vec x\)映射到隱變量\(\vec z\)的編碼器,是對真實後驗\(p(\vec z|\vec x)\)的近似,而\(\phi\)是其參數。
後面我們將會看到,我們通過優化參數\(\phi\),最大化公式 1,就能獲得取得真實樣本的數據分佈,并能從中采樣。
現在讓我們理清爲什麽要最大化ELBO。 \[ \begin{aligned} \log p(\vec x) &= \log \int p(\vec x, \vec z) d\vec z & \text{對聯合分佈的邊緣化} \\ &= \log \int \frac{p(\vec x, \vec z)q_\phi(\vec z|\vec x)}{q_\phi(\vec z|\vec x)} dz & \text{分子分母乘以同一個數} \\ &= \log \mathbb E _{q_\phi(\vec z|\vec x)} \left[\frac{p(\vec x, \vec z)}{q_\phi(\vec z|\vec x)}\right] & \text{期望的定義} \\ &\geq \mathbb E _{q_\phi(\vec z|\vec x)}\left[ \log \frac{p(\vec x, \vec z)}{q_\phi(\vec z|\vec x)} \right] & \text{Jensen不等式} \end{aligned} \tag{2}\] 以上是一種利用Jensen不等式的推導方式。還有一種證明方式,稍顯冗長,但是能提供更多爲何使用ELBO的直覺。
\[ \begin{aligned} \log p(\vec x) &= \log p(\vec x) \int q_\phi(\vec z|\vec x) dz & \text{乘以}1 = \int q_\phi(\vec z| \vec x)d\vec z\\ &= \int q_\phi(\vec z|\vec x)(\log p(\vec x)) d\vec z & \log p(\vec x)\text{移到積分符號後} \\ &= \mathbb E _{q_\phi(\vec z|\vec x)} \left[\log p(\vec x)\right] & \text{期望的定義}\\ &= \mathbb E _{q_\phi(\vec z|\vec x)}\left[\log\frac{p(\vec x, \vec z)}{p(\vec z|\vec x)}\right] & \text{鏈式法則}\\ &= \mathbb E _{q_\phi(\vec z|\vec x)} \left[\log\frac{p(\vec x, \vec z)q_\phi(\vec z|\vec x)}{p(\vec z|\vec x)q_\phi(\vec z|\vec x)}\right] & \text{分子分母同乘以一個數}\\ &= \mathbb E _{q_\phi(\vec z|\vec x)} \left[\log\frac{p(\vec x,\vec z)}{q_\phi(\vec z|\vec x)}\right] + \mathbb E_{q_\phi(\vec z|\vec x)}\left[\log\frac{q_\phi(\vec z|\vec x)}{p(\vec z|\vec x)}\right] & \text{期望的拆分}\\ &= \mathbb E _{q_\phi(\vec z|\vec x)} \left[\log \frac{p(\vec x, \vec z)}{q_\phi(\vec z|\vec x)}\right] + D_\text{KL} (q_\phi(\vec z|\vec x)\Vert p(\vec z|\vec x)) & \text{KL散度的定義} \\ &\geq \mathbb E _{q_\phi(\vec z|\vec x)} \left[\log \frac{p(\vec x, \vec z)}{q_\phi(\vec z|\vec x)}\right] & \text{KL散度非負} \end{aligned} \tag{3}\]
雖然公式 2和公式 3都證明了ELBO是證據下界,但是對於公式 3的理解是更關鍵的,即:
- 兩者的差值恰好為\(q_\phi(\vec z|\vec x)\)和\(p(\vec z|\vec x)\)的KL散度
- 公式 3的左邊實際上是一個常數,這個常數等於ELBO與KL散度的和。
- 進一步的,最大化ELBO的過程也就是最小化\(q_\phi(\vec z|\vec x)\)和\(p(\vec z|\vec x)\)的KL散度的過程。由於我們不知道\(p(\vec z|\vec x)\)的真值,所以我們無法直接最小化KL散度,但是ELBO允許我們間接實現這項優化。
- 一旦訓練完畢,ELBO可以用於估計生成樣本的似然度。因爲ELBO的優化目標是逼近\(\log p(\vec x)\)。
4 變分自動編碼器
變分自動編碼器(Variational Autoencoder,VAE)的名字中有“變分”兩個字,是因爲它的目標是從所有可能的後驗分佈中尋找最優的\(q_phi(\vec z|\vec x)\)。它被稱爲編碼器,是因爲它具備傳統編碼器模型的特質,即它嘗試將輸入數據壓縮為低維向量然後試圖再還原為原來的輸入。
VAE的目標是直接最大化ELBO。爲了更清楚地分析VAE,我們繼續拆解分析ELBO: \[ \begin{aligned} \mathbb E_{q_\phi(\vec z|\vec x)} \left[\log\frac{p(\vec x, \vec z)}{q_\phi(\vec z|\vec x)}\right] &= \mathbb E_{q_\phi(\vec z|\vec x)} \left[\log\frac{p_\theta(\vec x|\vec z)p(\vec z)}{q_\phi(\vec z|\vec x)}\right] & \text{鏈式法則}\\ &= \mathbb E_{q_\phi(\vec z|\vec x)}\left[\log p_\theta(\vec x|\vec z)\right] + \mathbb E_{q_\phi(\vec z|\vec x)}\left[\log \frac{p(\vec z)}{q_\phi(\vec z|\vec x)}\right] & \text{期望的拆分}\\ &= \underbrace{\mathbb E_{q_\phi(\vec z|\vec x)}\left[\log p_\theta(\vec x|\vec z)\right]}_{重構項} - \underbrace{D_\text{KL}(q_\phi(\vec z|\vec x)\Vert p(\vec z))}_{先驗匹配項} & \text{KL散度的定義} \end{aligned} \tag{4}\] 在這個公式中,\(q_\phi(\vec z|\vec x)\)被稱爲編碼器(encoder),\(p_\theta(\vec x|\vec z)\)被稱爲解碼器(decoder)。
公式 4被拆分成兩項,其中重構項度量了學習到的樣本分佈是否建模了真正的分佈,而先驗匹配項學習到的隱變量是否服從預設的先驗分佈。
在VAE模型中,encoder通常選擇將隱變量設置爲服從具有對角協方差矩陣的多元高斯分佈,將先驗分佈設置爲標準正態分佈: \[ \begin{aligned} q_\phi(\vec z|\vec x) &= \mathcal N(\vec z; \vec \mu_{\vec\phi}(\vec x), \vec \sigma^2_{\vec \phi}(\vec x) \vec I)\\ p(\vec z) &= \mathcal N(\vec z; \vec 0, \vec I) \end{aligned} \]
在這樣的設置下,先驗匹配項中的KL散度是可以解析計算的,重構項可以通過蒙特卡洛估計方法近似。優化目標可以重寫為: \[ \argmax{\vec \phi, \vec \theta}{\mathbb E _{q_\phi(\vec z|\vec x)}}\left[\log p_{\vec \theta}(\vec x|\vec z)\right] - D_{KL}(q_{\vec \phi}(\vec z|\vec x)\Vert p(\vec z)) \approx \argmax{\vec \phi, \vec \theta} \sum_{l=1}^L \log p_{\vec \theta}(\vec x|\vec z^{(l)}) - D_{KL}(q_{\vec \phi}(\vec z|\vec x) \Vert p(\vec z)) \]
其中\(\left\{\vec z^{(l)}\right\}_{l=1}^L\)是對於每個數據集中的樣本\(\vec x\)從\(q_\phi(\vec z|\vec x)\)中采樣的\(L\)個隱變量。這其中隨機采樣的過程通常是不可微的,幸運的是我們可以采用以下的“重參數化”技巧: \[ \vec z = \vec \mu_\pi(\vec x) + \vec \sigma_\phi(\vec x) \odot \vec \epsilon, \text{with} \vec\epsilon \sim \mathcal N(\vec \epsilon; \vec 0, \vec I) \] 其中\(\odot\)表示逐元素乘積。得益於重參數化方法,我們可以實現損失函數對\(\vec \phi\)和\(\vec \theta\)的求導。
5 分層的變分自動編碼器
分層變分自動編碼器(Hierarchical Variational Autoencoder,HVAE)將VAE推廣到具有多層隱變量的情形。每一層的隱變量都以前一層隱變量為條件生成。論文討論了HVAE的特例,馬爾科夫HVAE(Markovian HVAE,MHVAE)。在MHVAE中,\(\vec z_t\)的生成只依賴\(\vec z_{t+1}\),而不用考慮\(\vec z_{t+2}\)。MHVAE的聯合分佈和後驗分佈是: \[ p(\vec x, \vec z_{1:T})=p(\vec z_T)p_{\vec \theta}(\vec x|\vec z_1)\prod_{t=2}^T p_{\vec \theta}(\vec z_{t-1}|\vec z_t) \tag{5}\] \[ q_{\vec \phi} (\vec z_{1:T}\vert \vec x)=q_\phi(\vec z_1|\vec x)\prod_{t=2}^T q_{\vec \phi}(\vec z_t\vert z_{t-1}) \tag{6}\] 這時ELBO公式可以寫為 \[ \begin{aligned} \log p(\vec x) &= \log \int p(\vec x, \vec z_{1:T}) d \vec z_{1:T} & \text{對聯合分佈的邊緣化}\\ &= \log \int \frac{p(\vec x, \vec z_{1:T})q_{\vec \phi}(\vec z_{1:T}\vert \vec x)}{q_{\vec \phi}(\vec z_{1:T}\vert \vec x)} d\vec z_{1:T} & \text{分子分母乘以同一個數} \\ &= \log \mathbb E_{q_{\vec \phi}(\vec z_{1:T}\vert \vec x)} \left[\frac{p(\vec x, \vec z_{1:T})}{q_{\vec \phi}(\vec z_{1:T}\vert \vec x)}\right] &\text{期望的定義} \\ &\geq \mathbb E _{q_{\vec \phi}(\vec z_{1:T}\vert \vec x)} \left[\log \frac{p(\vec x, \vec z_{1:T})}{q_{\vec \phi}(\vec z_{1:T} \vert \vec x)}\right] & \text{Jensen不等式} \end{aligned} \] 然後再將聯合分佈公式 5和後驗分佈公式 6代入得到 \[ \mathbb E _{q_{\vec \phi}(\vec z_{1:T}\vert \vec x)} \left[\log \frac{p(\vec x, \vec z_{1:T})}{q_{\vec \phi}(\vec z_{1:T} \vert \vec x)}\right] = \mathbb E_{q_{\vec \phi}(\vec z_{q:T}\vert \vec x)} \left[ \log \frac{ p(\vec z_T)p_{\vec \theta}(\vec x|\vec z_1)\prod_{t=2}^T p_{\vec \theta}(\vec z_{t-1}|\vec z_t) }{ q_\phi(\vec z_1|\vec x)\prod_{t=2}^T q_{\vec \phi}(\vec z_t\vert z_{t-1}) } \right] \]
後面我們將會看到,在討論變分擴散模型時,這個目標函數將可以分解為更多可解釋的部件。
6 變分擴散模型
變分擴散模型(Variational Diffusion Model,VDM)可以視為施加了以下三個條件的MHVAE:
- 隱變量的維度大小總是和樣本數據的維度大小一樣
- 每一步的encoder不是可學習的,而是預定義的線性高斯模型
- 每一步的encoder的參數隨著步驟\(t\)變化,使得最後一層的隱變量最終服從標準正態分佈。
因為有第一條限制,所以我們可以引入符號\(\vec x_t\)同時表示隱變量和原始數據。當\(t=0\)時,它代表原始數據,當\(t\in[1, T]\)時,它代表第\(t\)層的隱變量。
於是VDM的後驗分佈可以重寫為: \[ q(\vec x_{1:T}\vert \vec x_0) = \prod_{t=1}^Tq(\vec x_t \vert x_{t-1}) \] 與普通的MHVAE不同,VDM的encoder不是可學習的(條件2),而是被通常被定義為均值\(\mu_t(\vec x_t)=\sqrt{\alpha_t} \vec x_{t-1}\),協方差矩陣\(\Sigma_t(\vec x_t) = (1 - \alpha_t)\mat I\). 這裡\(\alpha_t\)是隨層級\(t\)變化的參數,可以是預設的,或者是可學習的。
編碼器的模型可以記為: \[ q(\vec x_t|\vec x_{t-1}) =\mathcal N(\vec x_t;\sqrt{a_t}\vec x_{t-1}, (1- \alpha_t) \mat I) \] 根據條件3,經過若干層這樣的編碼器,最終的\(p(\vec x_T)\)將會服從標準正態分佈。於是MHVAE的聯合分佈公式 5變為 \[ p(\vec x_{0:T}) = p(\vec x_T)\prod_{t=1}^T p_{\vec \theta}(\vec x_{t-1}|\vec x_t) \] 其中 \[ p(\vec x_T) = \mathcal N(\vec x_T; \mat 0, \mat I) \] 總的來說,VDM假設的幾個條件要求每一步“編碼”都是持續增加噪聲,直到數據成為完全的高斯噪聲的過程。
在VDM中,\(q(\vec x_t|\vec x_{t-1})\)是完全固定的,因此我們關係的就只剩\(p_{\vec \theta}(\vec x_{t-1}|\vec x_t)\)的學習了。
如果你已經完成了一個VDM的訓練,那麼從模型中採樣圖片的過程就是從\(p(\vec x_T)\)中隨機採樣一個數據,然後迭代地運行\(p_{\vec \theta}(\vec x_{t-1}|\vec x_t)\),直到生成\(\vec x_0\).
與其它HVAE相似,VDM同樣可以用ELBO作為優化目標: \[ \begin{aligned} \log p(\vec x) &= \log \int p(\vec x_{0:T} d\vec x_{1:T}) &\text{對聯合分佈的邊緣化}\\ &= \log \int \frac{p(\vec x_{0:T})q(\vec x_{1:T}|\vec x_0)}{q(\vec x_{1:T}|\vec x_0)} d\vec x_{1:T} & \text{分子分母乘以同一個數} \\ &= \log \E _{q(\vec x_{1:T}|\vec x_0)} \left[\frac{p(\vec x_{0:T})}{q(\vec x_{1:T}|\vec x_0)}\right] &\text{期望的定義}\\ &\geq \mathbb E_{q(\vec x_{1:T}|\vec x_0)}\left[\log \frac{p(\vec x_{0:T})}{q(\vec x_{1:T}|\vec x_0)}\right] & \text{Jensen不等式}\\ &= \mathbb E_{q(\vec x_{1:T}|\vec x_0)}\left[\log\frac{p(\vec x_T)\prod_{t=1}^Tp_{\vec \theta}(\vec x_{t-1}|\vec x_t)}{\prod_{t=1}^T q(\vec x_t|\vec x_{t-1})}\right] & \text{鏈式法則、馬爾科夫性質}\\ &= \mathbb E_{q(\vec x_{1:T}|\vec x_0)} \left[\log\frac{p(\vec x_T)p_{\vec \theta}(\vec x_0|\vec x_1)\prod_{t=2}^Tp_{\vec\theta}(\vec x_{t-1}|\vec x_t)}{q(\vec x_T|\vec x_{T-1})\prod_{t=1}^{T-1}q(\vec x_t|\vec x_{t-1})}\right] \\ &= \mathbb E_{q(\vec x_{1:T}|\vec x_0)} \left[\log\frac{p(\vec x_T)p_{\vec \theta}(\vec x_0|\vec x_1)\prod_{t=1}^{T-1}p_{\vec\theta}(\vec x_t|\vec x_{t+1})}{q(\vec x_T|\vec x_{T-1})\prod_{t=1}^{T-1}q(\vec x_t|\vec x_{t-1})}\right] \\ &= \mathbb E_{q(\vec x_{1:T}|\vec x_0)} \left[\log\frac{p(\vec x_T)p_{\vec \theta}(\vec x_0|\vec x_1)}{q(\vec x_T|\vec x_{T-1})}\right]+ \mathbb E_{q(\vec x_{1:T}|\vec x_0)} \left[\log\prod_{t=1}^{T-1}\frac{p_{\vec \theta}(\vec x_t|\vec x_{t+1})}{q(\vec x_t|\vec x_{t-1})}\right] & \text{數學期望有線性性質} \\ &= \mathbb E_{q(\vec x_{1:T}|\vec x_0)} \left[\log p_{\vec\theta}(\vec x_0|\vec x_1)\right]+ \mathbb E_{q(\vec x_{1:T}|\vec x_0)} \left[\log\frac{p(\vec x_T)}{q(\vec x_T | \vec x_{T-1})}\right] + \mathbb E_{q(\vec x_{1:T}|\vec x_0)} \left[\sum_{t=1}^{T-1}\log \frac{p_{\vec\theta}(\vec x_t|\vec x_{t+1})}{q(\vec x_t|\vec x_{t-1})}\right] \\ &= \mathbb E_{q(\vec x_{1:T}|\vec x_0)} \left[\log p_{\vec\theta}(\vec x_0|\vec x_1)\right]+ \mathbb E_{q(\vec x_{1:T}|\vec x_0)} \left[\log\frac{p(\vec x_T)}{q(\vec x_T | \vec x_{T-1})}\right] + \sum_{t=1}^{T-1}\mathbb E_{q(\vec x_{1:T}|\vec x_0)}\left[\log\frac{p_\vtheta(\vec x_t|\vec x_{t+1})}{q(\vx_t|\vx_{t-1})}\right] \\ &= \E_{q(\vx_1|\vx_0)} \left[\log p_{\vec\theta}(\vec x_0|\vec x_1)\right] + \E_{q(\vx_{T-1},\vx_T|\vx_0)}\left[\log\frac{p(\vec x_T)}{q(\vec x_T | \vec x_{T-1})}\right] + \sum_{t=1}^{T-1}\E_{q(\vx_{t-1}, \vx_t,\vx_{t+1}|\vx_0)} \left[\log\frac{p_\vtheta(\vec x_t|\vec x_{t+1})}{q(\vx_t|\vx_{t-1})}\right] \\ & \text{按數學期望的定義展開} \\ &= \underbrace{\E_{q(\vx_1|\vx_0)} \left[\log p_{\vec\theta}(\vec x_0|\vec x_1)\right]}_{重構項} + \int \int \log \frac{p(\vx_T)}{q(\vx_T|\vx_{T-1})} q(\vx_{T-1}, \vx_T|\vx_0)d\vx_{T}d\vx_{T-1} + \sum_{t=1}^{T-1}\E_{q(\vx_{t-1}, \vx_t,\vx_{t+1}|\vx_0)} \left[\log\frac{p_\vtheta(\vec x_t|\vec x_{t+1})}{q(\vx_t|\vx_{t-1})}\right] \\ &= \underbrace{\E_{q(\vx_1|\vx_0)} \left[\log p_{\vec\theta}(\vec x_0|\vec x_1)\right]}_{重構項} + \int \left( \int \log \frac{p(\vx_T)}{q(\vx_T|\vx_{T-1})} q(\vx_{T}|\vx_{T-1})d\vx_{T}\right)q(\vx_{T-1}|\vx_0)d\vx_{T-1} + \sum_{t=1}^{T-1}\E_{q(\vx_{t-1}, \vx_t,\vx_{t+1}|\vx_0)} \left[\log\frac{p_\vtheta(\vec x_t|\vec x_{t+1})}{q(\vx_t|\vx_{t-1})}\right] \\ & \text{應用KL散度的定義} \\ &= \underbrace{\E_{q(\vx_1|\vx_0)} \left[\log p_{\vec\theta}(\vec x_0|\vec x_1)\right]}_{重構項} + \int \Big(- D_\text{KL}\left(q(\vx_T\vert \vx_{T-1})\Vert p(\vx_T)\right)\Big)q(\vx_{T-1}|\vx_0)d\vx_{T-1} + \sum_{t=1}^{T-1}\E_{q(\vx_{t-1}, \vx_t,\vx_{t+1}|\vx_0)} \left[\log\frac{p_\vtheta(\vec x_t|\vec x_{t+1})}{q(\vx_t|\vx_{t-1})}\right] \\ & \text{期望的定義} \\ &= \underbrace{\E_{q(\vx_1|\vx_0)} \left[\log p_{\vec\theta}(\vec x_0|\vec x_1)\right]}_{重構項} - \mathbb E_{q(\vx_{T-1}|\vx_0)} \left[ D_\text{KL}\left(q(\vx_T\vert \vx_{T-1})\Vert p(\vx_T)\right)\right] + \sum_{t=1}^{T-1}\E_{q(\vx_{t-1}, \vx_t,\vx_{t+1}|\vx_0)} \left[\log\frac{p_\vtheta(\vec x_t|\vec x_{t+1})}{q(\vx_t|\vx_{t-1})}\right] \\ & \text{對於最後一項也是同理}\\ &= \underbrace{\E_{q(\vx_1|\vx_0)} \left[\log p_{\vec\theta}(\vec x_0|\vec x_1)\right]}_{重構項} - \underbrace{\mathbb E_{q(\vx_{T-1}|\vx_0)} \left[ D_\text{KL}\left(q(\vx_T\vert \vx_{T-1})\Vert p(\vx_T)\right)\right]}_{先驗匹配項} - \sum_{t=1}^{T-1}\underbrace{\E_{q(\vx_{t-1},\vx_{t+1}|\vx_0)} \left[D_\text{KL}(q(\vx_t|\vx_{t-1})\Vert p_\vtheta (\vx_t|\vx_{t+1}))\right]}_{一致性約束項} \\ \end{aligned} \]
可以看到ELBO可以拆解為三項:
- 重構項要求用第一層隱變量恢復原始數據的似然度最大化;
- 先驗匹配項要求最終的隱變量接近高斯先驗分佈。這一項其實不會體現在損失函數裡,因為這一步沒有任何可訓練參數。實踐中我們會通過一些設計和參數選擇,使得最後一層隱變量接近高斯先驗;
- 一致性約束項要求隱變量在前向加噪和逆向去噪的兩個方向產生的隱變量分佈一致。
那麼VDM下,ELBO的所有項都可以用蒙特卡洛採樣的方式優化。但是注意到最後一項涉及到兩個隨機變量的採樣\(\{\vx_{t-1}, \vx_{t+1}\}\),這可能導致估計結果的方差偏大。為了優化這個問題,我們可以考慮重新推導ELBO,寫出一個一次只需採樣一個隨機變量的版本。
注意到由於馬爾科夫鏈的性質,有 \[ q(\vx_t|\vx_{t-1}, \vx_0) = \frac{ q(\vx_{t-1}|\vx_t,\vx_0) q(\vx_t|\vx_0) } { q(\vx_{t-1}|\vx_0) }, \] 利用這個式子,我們可以重新推導ELBO: \[ \begin{aligned} \log p(\vx) &\geq \E _{q(\vx_{1:T}|\vx_0)}\left[\log\frac{p(\vx_{0:T})}{q(\vx_{1:T}|\vx_0)}\right] \\ &= \mathbb E_{q(\vec x_{1:T}|\vec x_0)}\left[\log\frac{p(\vec x_T)\prod_{t=1}^Tp_{\vec \theta}(\vec x_{t-1}|\vec x_t)}{\prod_{t=1}^T q(\vec x_t|\vec x_{t-1})}\right] & \text{鏈式法則、馬爾科夫性質}\\ &= \mathbb E_{q(\vec x_{1:T}|\vec x_0)}\left[\log \frac{p(\vx_T)p_\vtheta (\vx_0|\vx_1)\prod_{t=2}^T p_\vtheta (\vx_{t-1}|\vx_t)}{q(\vx_1|\vx_0)\prod_{t=2}^Tq(\vx_t|\vx_{t-1})}\right] \\ &= \mathbb E_{q(\vec x_{1:T}|\vec x_0)}\left[\log \frac{p(\vx_T)p_\vtheta (\vx_0|\vx_1)\prod_{t=2}^T p_\vtheta (\vx_{t-1}|\vx_t)}{q(\vx_1|\vx_0)\prod_{t=2}^Tq(\vx_t|\vx_{t-1}, \vx_0)}\right] & \text{應用馬爾科夫鏈的性質} \\ &= \mathbb E_{q(\vec x_{1:T}|\vec x_0)}\left[\log \frac{p(\vx_T)p_\vtheta (\vx_0|\vx_1)}{q(\vx_1|\vx_0)} + \log \prod_{t=2}^T\frac{ p_\vtheta (\vx_{t-1}|\vx_t)}{q(\vx_t|\vx_{t-1}, \vx_0)}\right] \\ &= \mathbb E_{q(\vec x_{1:T}|\vec x_0)}\left[\log \frac{p(\vx_T)p_\vtheta (\vx_0|\vx_1)}{q(\vx_1|\vx_0)} + \log \prod_{t=2}^T\frac{ p_\vtheta (\vx_{t-1}|\vx_t)}{\frac{q(\vx_{t-1}|\vx_t,\vx_0)\textcolor{red}{q(\vx_t|\vx_0)}}{\textcolor{red}{q(\vx_{t-1}|\vx_0)}}}\right] & \text{貝葉斯定理}\\ & \text{紅色部分在連乘的相鄰項間被抵消} \\ &=\E _{q(\vx_{1:T}|\vx_0)}\left[\log \frac{p(\vx_T)p_\vtheta(\vx_0|\vx_1)}{\cancel{q(\vx_1|\vx_0)}} + \log \frac{\cancel{q(\vx_1|\vx_0)}}{q(\vx_T|\vx_0)}+\log\prod_{t=2}^T \frac{p_\vtheta(\vx_{t-1}|\vx_t)}{q(\vx_{t-1}|\vx_t, \vx_0)}\right] \\ &=\E _{q(\vx_{1:T}|\vx_0)}\left[\log \frac{p(\vx_T)p_\vtheta(\vx_0|\vx_1)}{q(\vx_T|\vx_0)}+\sum_{t=2}^T \log \frac{p_\vtheta(\vx_{t-1}|\vx_t)}{q(\vx_{t-1}|\vx_t, \vx_0)}\right] \\ &=\E _{q(\vx_{1:T}|\vx_0)}[\log p_\vtheta(\vx_0|\vx_1)] + \E _{q(\vx_{1:T}|\vx_0)}\left[ \log\frac{p(\vx_T)}{q(\vx_T|\vx_0)} \right] + \sum_{t=2}^T\E _{q(\vx_{1:T}|\vx_0)} \left[\log\frac{p_\vtheta(\vx_{t-1}|\vx_t)}{q(\vx_{t-1}|\vx_t, \vx_0)}\right] \\ &=\E_{q(\vx_1|\vx_0)}[\log p_\vtheta(\vx_0|\vx_1)] + \E_{q(\vx_T|\vx_0)}\left[\log\frac{p(\vx_T)}{q(\vx_T|\vx_0)}\right] + \sum_{t=2}^T \E_{q(\vx_t, \vx_{t-1}|\vx_0)}\left[\log\frac{p_\vtheta(\vx_{t-1}|\vx_t)}{q(\vx_{t-1}|\vx_t, \vx_0)}\right] \\ &=\underbrace{\E_{q(\vx_1|\vx_0)}[\log p_\vtheta(\vx_0|\vx_1)]}_{重構項} - \underbrace{D_\text{KL}(q(\vx_T|\vx_0)\Vert p(\vx_T))}_{先驗匹配項} - \sum_{t=2}^T \underbrace{\E_{q(\vx_t|\vx_0)}[D_\text{KL}(q(\vx_{t-1} | \vx_t, \vx_0)\Vert p_\vtheta (\vx_{t-1}|\vx_t))]}_{去噪匹配項} \end{aligned} \tag{7}\] 經過這樣的推導,我們將ELBO拆解為3項:
- 重構項:可以用蒙特卡洛方法採樣估計
- 先驗匹配項:不包含可訓練參數。由於VDM的假設,這一項的值為0
- 去噪匹配項:在這一項中,\(q(\vx_{t-1}|\vx_t, \vx_0)\)為ground-truth信號,它定義了已知真實樣本\(\vx_0\)時,從\(\vx_t\)到\(\vx_{t-1}\)的去噪過程是怎麼樣的。我們用\(p_\vtheta(\vx_{t-1}|\vx_t)\)去近似它。
注意到由於我們沒有應用馬爾科夫性質以外的假設,所以其實以上推導對於任意MHVAE都是適用的。也適用於\(T=1\)時的MHVAE,這時MHVAE退化成普通的VAE。
ELBO中,\(D_\text{KL}(q(\vx_{t-1} | \vx_t, \vx_0)\Vert p_\vtheta (\vx_{t-1}|\vx_t))\)是優化的重點。在一般的MHVAE中,由於encoder可能是可學習的任意的函數,這個優化很難實現。而在VDM中,我們將encoder設置為固定的線性高斯模型,應用重參數化技巧,encoder過程可以重寫為: \[ \vx_t = \sqrt{\alpha_t} \vx_{t-1} + \sqrt{1 - \alpha_t} \vec\epsilon, \text{其中}\vec\epsilon \sim \mathcal N(\vec\epsilon; \mZero, \mI) \] 這是一個遞歸的過程。那麼對於任意的\(t\),\(\vx_t\sim q(\vx_t|\vx_0)\)可以寫為(式子中所有的\(\vec \epsilon\)都獨立同分佈地採樣於\(\mathcal N(\vec \epsilon;\mZero, \mI)\)): \[ \begin{aligned} \vx_t &= \sqrt{\alpha_t} \vx_{t-1} + \sqrt{1 - \alpha_t} \vec\epsilon^*_{t-1} \\ &= \sqrt{\alpha_t}(\sqrt{\alpha_{t-1}}\vx_{t-2} + \sqrt{1 - \alpha_{t-1}}\vec\epsilon^*_{t-2}) + \sqrt{1 - \alpha_t} \vec\epsilon^*_{t-1}\\ &= \sqrt{\alpha_t\alpha_{t-1}} \vx_{t-2} + \sqrt{\alpha_t - \alpha_t \alpha_{t-1}} \vec\epsilon^*_{t-2} + \sqrt{1 - \alpha_t}\vec\epsilon^*_{t-1} \\ &\text{兩個高斯變量相加} \\ &= \sqrt{\alpha_t\alpha_{t-1}} \vx_{t-2} + \sqrt{\sqrt{\alpha_t - \alpha_t \alpha_{t-1}}^2 + \sqrt{1 - \alpha_t}^2} \vec\epsilon_{t-2} \\ &= \sqrt{\alpha_t\alpha_{t-1}} \vx_{t-2} + \sqrt{\alpha_t - \alpha_t\alpha_{t-1} + 1 - \alpha_t}\vec\epsilon_{t-2} \\ &= \sqrt{\alpha_t \alpha_{t-1}} \vx_{t-2} + \sqrt{1 - \alpha_t\alpha_{t-1}} \vec\epsilon_{t-2} \\ &= \dots \\ &= \sqrt{\prod_{i=1}^t \alpha_i} \vx_0 + \sqrt{1 - \prod_{i=1}^t\alpha_i}\vec\epsilon_0 \\ &= \sqrt{\bar{\alpha}_t} \vx_0 + \sqrt{1 - \bar{\alpha}_t} \vec\epsilon_0 \\ &\sim \mathcal N(\vx_t; \sqrt{\bar\alpha_t} \vx_0, (1 - \bar\alpha_t)\mI) \end{aligned} \tag{8}\] 這裡應用了“兩個獨立高斯隨機變量之和仍服從高斯分佈,其均值為原分佈均值之和,方差為原分佈方差之和”這個知識點。
這樣,我們就發現\(q(\vx_{t}|\vx_0)\)也服從高斯分佈,知道了它的表達式。我們也可由此得知\(q(\vx_{t-1}|\vx_0)\)的表達式。
接下來的工作就是得到\(q(\vx_{t-1}|\vx_t, \vx_0)\): \[ \begin{aligned} q(\vx_{t-1}|\vx_t, \vx_0) &= \frac{q(\vx_t|\vx_{t-1}, \vx_0) q(\vx_{t-1}|\vx_0)}{q(\vx_t|\vx_0)} \\ &= \frac{\mathcal N(\vx_t; \sqrt{\alpha_t}\vx_{t-1}, (1 - \alpha_t)\mI)\mathcal N(\vx_{t-1}; \sqrt{\bar{\alpha}_{t-1}}\vx_0, (1 - \bar\alpha_{t-1})\mI)}{\mathcal N(\vx_t;\sqrt{\bar\alpha_t}\vx_0, (1 - \bar\alpha_t)\mI)} \\ &\propto \exp\left\{-\left[ \frac{(\vx_t - \sqrt{\alpha_t}\vx_{t-1})^2}{2(1 - \alpha_t)} + \frac{(\vx_{t-1} - \sqrt{\bar\alpha_{t-1}}\vx_0)^2}{2(1 - \bar\alpha_{t-1})} - \frac{(\vx_t - \sqrt{\bar\alpha_t}\vx_0)^2}{2(1 - \bar\alpha_t)} \right] \right\} \\ &= \exp\left\{-\frac{1}{2}\left[ \frac{(\vx_t - \sqrt{\alpha_t}\vx_{t-1})^2}{1 - \alpha_t} + \frac{(\vx_{t-1} - \sqrt{\bar\alpha_{t-1}}\vx_0)^2}{1 - \bar\alpha_{t-1}} - \frac{(\vx_t - \sqrt{\bar\alpha_t}\vx_0)^2}{1 - \bar\alpha_t} \right] \right\} \\ & \text{把與}\vx_t,\vx_0\text{有關的常數項摘出來} \\ &= \exp\left\{ -\frac{1}{2}\left[\frac{-2\sqrt{\alpha_t} \vx_t \vx_{t-1} + \alpha_t \vx^2_{t-1}}{1 - \alpha_t} + \frac{\vx^2_{t-1} - 2\sqrt{\bar\alpha_{t-1}}\vx_{t-1}\vx_0}{1 - \bar\alpha_{t-1}} + C(\vx_t, \vx_0)\right] \right\} \\ &\propto \exp\left\{-\frac{1}{2} \left[ -\frac{2\sqrt{\alpha_t}\vx_t\vx_{t-1}}{1 - \alpha_t} + \frac{\alpha_t\vx^2_{t-1}}{1 - \alpha_t} + \frac{\vx_{t-1}^2}{1 - \bar\alpha_{t-1}} - \frac{2\sqrt{\bar\alpha_{t-1}}\vx_{t-1}\vx_0}{1 - \bar\alpha_{t-1}} \right]\right\} \\ &= \exp\left\{-\frac{1}{2} \left[(\frac{\alpha_t}{1 - \alpha_t} + \frac{1}{1 - \bar\alpha_{t-1}})\vx^2_{t-1} - 2 \left(\frac{\sqrt{\alpha_t}\vx_t}{1 - \alpha_t} + \frac{\sqrt{\bar\alpha_{t-1}}\vx_0}{1 - \bar\alpha_{t-1}}\right)\vx_{t-1}\right] \right\}\\ &= \exp\left\{-\frac{1}{2} \left[\frac{\alpha_t(1 - \bar\alpha_{t-1}) + 1 - \alpha_t}{(1 - \alpha_t)(1 - \bar\alpha_{t-1})}\vx^2_{t-1} - 2 \left(\frac{\sqrt{\alpha_t}\vx_t}{1 - \alpha_t} + \frac{\sqrt{\bar\alpha_{t-1}}\vx_0}{1 - \bar\alpha_{t-1}}\right)\vx_{t-1}\right] \right\}\\ &\text{注意}\alpha_t \bar\alpha_{t-1} = \bar\alpha_t \\ &= \exp\left\{-\frac{1}{2} \left[\frac{1 - \bar\alpha_t}{(1 - \alpha_t)(1 - \bar\alpha_{t-1})}\vx^2_{t-1} - 2 \left(\frac{\sqrt{\alpha_t}\vx_t}{1 - \alpha_t} + \frac{\sqrt{\bar\alpha_{t-1}}\vx_0}{1 - \bar\alpha_{t-1}}\right)\vx_{t-1}\right] \right\}\\ &= \exp\left\{ -\frac{1}{2} \left(\frac{1 - \bar\alpha_t}{(1 - \alpha_t)(1 - \bar\alpha_{t-1})}\right)\left[\vx_{t-1}^2 - 2 \frac{ \left(\frac{\sqrt{\alpha_t}\vx_t}{1 - \alpha_t} + \frac{\sqrt{\bar\alpha_{t-1}}\vx_0}{1 - \bar\alpha_{t-1}}\right)}{\frac{1 - \bar\alpha_t}{(1 - \alpha_t)(1 - \bar\alpha_{t-1})}}\vx_{t-1}\right] \right\}\\ &= \exp\left\{ -\frac{1}{2} \left(\frac{1 - \bar\alpha_t}{(1 - \alpha_t)(1 - \bar\alpha_{t-1})}\right)\left[\vx_{t-1}^2 - 2 \frac{ \left(\frac{\sqrt{\alpha_t}\vx_t}{1 - \alpha_t} + \frac{\sqrt{\bar\alpha_{t-1}}\vx_0}{1 - \bar\alpha_{t-1}}\right)(1 - \alpha_t)(1 - \bar\alpha_{t-1})}{1 - \bar\alpha_t}\vx_{t-1}\right] \right\}\\ &= \exp\left\{ -\frac{1}{2} \left(\frac{1}{\frac{(1 - \alpha_t)(1 - \bar\alpha_{t-1})}{1 - \bar\alpha_t}}\right)\left[\vx_{t-1}^2 - 2 \frac{\sqrt{\alpha_t}(1-\bar\alpha_{t-1})\vx_t + \sqrt{\bar\alpha_{t-1}}(1 - \alpha_t)\vx_0}{1 - \bar\alpha_t}\vx_{t-1}\right] \right\}\\ &\propto \mathcal N(\vx_{t-1}; \underbrace{\frac{\sqrt{\alpha_t}(1 - \bar\alpha_{t-1})\vx_t + \sqrt{\bar\alpha_{t-1}}(1 - \alpha_t)\vx_0}{1 - \bar\alpha_t}}_{\vec\mu_q(\vx_t, \vx_0)},\underbrace{\frac{(1 - \alpha_t)(1 - \bar\alpha_{t-1})}{1 - \bar\alpha_t}\mI}_{\mathbf\Sigma_q(t)}) \end{aligned} \tag{9}\]
於是我們可以看到\(\vx_{t-1}\sim q(\vx_{t-1}\vert\vx_t,\vx_0)\)服從正態分佈。\(\vec\mu_q(\vx_t, \vx_0)\)是\(\vx_t,\vx_0\)的函數,而\(\mathbf \Sigma_q(t)\)是\(\alpha\)參數的函數。
設\(\mathbf \Sigma_q(t)=\sigma^2_q(t)\mI\),其中 \[ \sigma^2_q(t)= \frac{(1 - \alpha_t)(1 - \bar\alpha_{t-1})}{1 - \bar\alpha_t} \tag{10}\]
我們要讓\(p_\vtheta(\vx_{t-1}|\vx_t)\)盡可能接近\(q(\vx_{t-1}|\vx_t,\vx_0)\)。既然\(q(\vx_{t-1}|\vx_t,\vx_0)\)是高斯分佈,我們也可以用高斯分佈建模\(p_\vtheta(\vx_{t-1}|\vx_t)\)。不妨讓它的協方差矩陣也為\(\mathbf \Sigma_q(t)=\sigma^2_q(t)\mI\),均值則設爲\(\vec\mu_\vtheta(\vx_t, t)\),是\(\vx_t\)的函數,但不依賴\(\vx_0\),畢竟decoder無法獲得\(\vx_0\)的真值。
要讓兩個高斯分佈相近,需要考慮它們的KL散度: \[ D_{\text{KL}} (\mathcal N(\vx;\vec\mu_\vx, \mathbf\Sigma_x)\Vert \mathcal N(\vec y;\vec\mu_y,\mathbf \Sigma_y)) = \frac{1}{2}\left[ \log \frac{|\mathbf \Sigma_y|}{|\mathbf \Sigma_x|} - d + \text{tr}(\mathbf \Sigma_y^{-1}\mathbf \Sigma_x) + (\vec\mu_y-\vec\mu_x)^T\mathbf \Sigma_y^{-1}(\vec\mu_y-\vec\mu_x) \right] \] 這裏\(d\)是數據的維度。
因爲我們將協方差矩陣設置為相同的,因此上式將只與均值的差有關。 \[ \begin{aligned} &\argmin{\vtheta} D_\text{KL}(q(\vx_{t-1}|\vx_t,\vx_0)\Vert p_\vtheta(\vx_{t-1}|\vx_t)) \\ =~& \argmin{\vtheta} D_\text{KL}(\mathcal N(\vx_{t-1};\vec\mu_q,\mathbf\Sigma_q(t))\Vert\mathcal N(\vx_{t-1}; \vec\mu_\vtheta, \mathbf\Sigma_q(t))) \\ =~&\argmin{\vtheta} \frac{1}{2} \left[\log\frac{|\mSigma_q(t)|}{|\mSigma_q(t)|} - d + \text{tr}(\mSigma_q(t)^{-1}\mSigma_q(t)) + (\vec\mu_\vtheta -\vec\mu_q)^T\mSigma_q(t)^{-1}(\vec\mu_\vtheta -\vec\mu_q)\right]\\ =~& \argmin{\vtheta} \frac{1}{2}[\log 1 - d + d+ (\vec\mu_\vtheta - \vec\mu_q)^T\mSigma_q(t)^{-1}(\vec\mu_\vtheta -\vec\mu_q)]\\ =~&\argmin{\vtheta} \frac{1}{2}\left[ (\vec\mu_\vtheta - \vec\mu_q)^T\mSigma_q(t)^{-1}(\vec\mu_\vtheta -\vec\mu_q) \right]\\ =~&\argmin{\vtheta} \frac{1}{2} \left[(\vec\mu_\vtheta - \vec\mu_q)^T(\sigma_q^2(t)\mI)^{-1}(\vec\mu_\vtheta -\vec\mu_q) \right] \\ =~&\argmin{\vtheta} \frac{1}{2\sigma^2(t)}\left[\left\Vert \vec\mu_\vtheta - \vec\mu_q \right\Vert_2^2\right] \end{aligned} \] 上面的式子中,\(\vec\mu_\vtheta\)是\(\vec\mu_q(\vec x_t, \vx_0)\)的縮寫, \[ \vec\mu_q(\vx_t,\vx_0)=\frac{\sqrt{\alpha_t}(1 - \bar\alpha_{t-1})\vx_t + \sqrt{\bar\alpha_{t-1}}(1 - \alpha_t)\vx_0}{1 - \bar\alpha_t} \]
而\(\vec\mu_\vtheta\)是\(\vec\mu_\vtheta (\vx_t, t)\)的縮寫。爲了讓\(\vec\mu_\vtheta\)趨近\(\vec\mu_q\),可以這樣設計: \[ \vec\mu_\vtheta(\vx_t, t) = \frac{\sqrt{\alpha_t}(1 - \bar\alpha_{t-1})\vx_t + \sqrt{\bar\alpha_{t-1}}(1 - \alpha_t)\hat \vx_\vtheta(\vx_t, t)}{1 - \bar\alpha_t} \] 其中\(\vx_\vtheta(\vx_t, t)\)是基於\(\vx_t\)和\(t\)對\(\vx_0\)做出的預測。綜合以上推理,優化問題變爲: \[ \begin{aligned} &\argmin \vtheta D_\text{KL} (q(\vx_{t-1}|\vx_t, \vx_0) \Vert p_\vtheta(\vx_{t-1}|\vx_t)) \\ =~&\argmin \vtheta D_\text{KL} (\mathcal N(\vx_{t-1}; \vec\mu_q, \mSigma_q(t)) \Vert \mathcal N(\vx_{t-1}; \vec\mu_\vtheta, \mSigma_q(t)))\\ =~&\argmin \vtheta \frac{1}{2\sigma_q^2(t)}\left[\left\Vert {\sqrt{\alpha_t}(1 - \bar\alpha_{t-1})\vx_t + \sqrt{\bar\alpha_{t-1}}(1 - \alpha_t)\hat\vx_\vtheta(\vx_t, t) \over 1 - \bar\alpha_t} - {\sqrt{\alpha_t} (1 - \bar\alpha_{t-1})\vx_t + \sqrt{\bar\alpha_{t-1}}(1 - \alpha_t)\vx_0\over 1 - \bar\alpha_t}\right\Vert_2^2\right] \\ =~&\argmin \vtheta \frac{1}{2\sigma_q^2(t)} \left[\left\Vert {\sqrt{\bar\alpha_{t-1}}(1 - \alpha_t)\hat\vx_\vtheta(\vx_t, t) \over 1 - \bar\alpha_t} - {\sqrt{\bar\alpha_{t-1}(1 - \alpha_t)\vx_0 \over 1 - \bar\alpha_t}} \right\Vert_2^2 \right]\\ =~&\argmin\vtheta \frac{1}{2\sigma_q^2(t)} \left[\left\Vert {\sqrt{\bar\alpha_{t-1}}(1 - \alpha_t)\over 1 - \bar\alpha_t}(\hat\vx_\vtheta(\vx_t, t) - \vx_0) \right\Vert_2^2\right]\\ =~&\argmin\vtheta \frac{1}{2\sigma_q^2(t)} {{\bar\alpha_{t-1}}(1 - \alpha_t)^2\over (1 - \bar\alpha_t)^2} \left[\left\Vert \hat\vx_\vtheta(\vx_t, t) - \vx_0 \right\Vert_2^2\right] \end{aligned} \tag{11}\] 於是,VDM的優化問題可以歸結爲用一個神經網絡從帶噪聲的圖像中恢復原始圖像。
對公式 7中的求和項的優化可以近似爲對在所有時間步\(t\)上如下期望值的優化: \[ \argmin\vtheta \E _{t\sim U\left\{2, T\right\}}\left[ \E_{q(\vx_t|\vx_0)} \left[ D_\text{KL}(q(\vx_{t-1}|\vx_t, \vx_0)\Vert p_\vtheta(\vx_{t-1}|\vx_t)) \right] \right] \]
7 學習噪聲的參數
本節討論影響VDM噪聲的參數\(\alpha_t\)要如何學習得到。比較容易想到的辦法是使用以\(\vec\eta\)為參數的模型\(\hat \alpha_{\vec\eta}(t)\)作預測。這樣做是低效的,因爲推理時,你需要在每一步\(t\)都預測對應的\(\bar\alpha_t\)。當然你可以提前把計算結果存下來。但是下文將介紹另一種方法。
將公式 10帶入公式 11,我們得到: \[ \begin{aligned} \frac{1}{2\sigma^2_q(t)}\frac{\bar\alpha_{t-1}(1 - \alpha_t)^2}{(1 - \bar\alpha_t)^2}\left[\left\Vert \hat\vx_\vtheta(\vx_t, t) - \vx_0\right\Vert_2^2\right] &= \frac{1}{2\frac{(1 - \alpha_t)(1- \bar\alpha_{t-1})}{1 - \bar\alpha_t}} \frac{\bar\alpha_{t-1}(1 - \alpha_t)^2}{(1 - \bar\alpha_t)^2} \left[\left\Vert \hat\vx_\vtheta(\vx_t, t) - \vx_0\right\Vert_2^2\right] \\ &= \frac{1}{2} \frac{1 - \bar\alpha_t}{(1 - \alpha_t)(1 - \bar\alpha_{t-1})} \frac{\bar\alpha_{t-1}(1 - \alpha_t)^2}{(1 - \bar\alpha_t)^2} \left[\Vert \hat\vx_\vtheta(\vx_t, t) - \vx_0\Vert_2^2\right] \\ &= \frac{1}{2} \frac{\bar\alpha_{t-1}(1 - \alpha_t)}{(1 - \bar\alpha_{t-1})(1 - \bar\alpha_t)} \left[\Vert \hat\vx_\vtheta(\vx_t, t) - \vx_0\Vert_2^2\right]\\ &= \frac{1}{2} \frac{\bar\alpha_{t-1} - \bar\alpha_t}{(1 - \bar\alpha_{t-1})(1 - \bar\alpha_t)} \left[\Vert \hat\vx_\vtheta(\vx_t, t) - \vx_0\Vert_2^2\right] \\ &= \frac{1}{2} \frac{\bar\alpha_{t-1} -\bar\alpha_{t-1}\bar\alpha_t +\bar\alpha_{t-1}\bar\alpha_t - \bar\alpha_t}{(1 - \bar\alpha_{t-1})(1 - \bar\alpha_t)} \left[\Vert \hat\vx_\vtheta(\vx_t, t) - \vx_0\Vert_2^2\right] \\ &= \frac{1}{2} \frac{\bar\alpha_{t-1}(1 - \bar\alpha_t) - \bar\alpha_t(1 - \bar\alpha_{t-1})}{(1 - \bar\alpha_{t-1})(1 - \bar\alpha_t)} \left[\Vert \hat\vx_\vtheta(\vx_t, t) - \vx_0\Vert_2^2\right] \\ &= \frac{1}{2} \left(\frac{\bar\alpha_{t-1}}{1 - \bar\alpha_{t-1}} - \frac{\bar\alpha_t}{1 - \bar\alpha_t}\right) \left[\Vert \hat\vx_\vtheta(\vx_t, t) - \vx_0\Vert_2^2\right] \\ \end{aligned} \tag{12}\]
回想起\(q(\vx_t|\vx_0)\)是形為\(\mathcal N(\vx_t; \sqrt{\bar\alpha_t}\vx_0, (1 - \bar\alpha_t)\mI)\),根據SNR(信噪比)的定義,\(\text{SNR} = \frac{\mu^2}{\sigma^2}\),時間步\(t\)的SNR為: \[ \text{SNR}(t) = \frac{\bar\alpha_t}{1 - \bar\alpha_t} \] 那麽公式 12可以進一步簡化爲: \[ \frac{1}{2\sigma^2_q(t)}\frac{\bar\alpha_{t-1}(1 - \alpha_t)^2}{(1 - \bar\alpha_t)^2}\left[\left\Vert \hat\vx_\vtheta(\vx_t, t) - \vx_0\right\Vert_2^2\right] = \frac{1}{2}(\text{SNR}(t-1)-\text{SNR}(t)) \left[\Vert \hat\vx_\vtheta(\vx_t, t) - \vx_0\Vert_2^2\right] \] 在VDM中,SNR應該隨著時間步\(t\)增加而增加,因爲\(\vx_t\)會隨著\(t\)增加,從原圖逐漸變成標準正態分佈。
那麽不妨將SNR函數設計爲 \[ \text{SNR}(t) = \frac{\bar\alpha_t}{1 - \bar\alpha_t}= \exp(-\omega_{\vec\eta}(t)) \] 所以 \[ \bar\alpha_t = \text{sigmoid}(-\omega_\vec\eta(t)) \] 其中\(\vec\eta\)是可學習的模型參數。
8 VDM的三種等效形式
如前文所述,VDM可以設計爲從\(\vx_t\)預測\(\vx_0\)的模型。但是,VDM還有兩種其它等效形式。
重寫公式 8的重參數化技巧,得到: \[ \vx_0 = \frac{\vx_t - \sqrt{1 - \bar\alpha_t}\vec\epsilon_0}{\sqrt{\bar\alpha_t}} \] 代入公式 9中得到的\(\vec\mu_q(\vx_t, \vx_0)\),得到 \[ \begin{aligned} \vec\mu_q(\vx_t, \vx_0) &= \frac{\sqrt{\alpha_t}(1 - \bar\alpha_{t-1})\vx_t + \sqrt{\bar\alpha_{t-1}}(1 - \alpha_t)\vx_0}{1 - \bar\alpha_t} \\ &= \frac{\sqrt{\alpha_t}(1 - \bar\alpha_{t-1})\vx_t + \sqrt{\bar\alpha_{t-1}}(1 - \alpha_t) \frac{\vx_t - \sqrt{1 - \bar\alpha_t}\vec\epsilon_0}{\sqrt{\bar\alpha_t}}}{1 - \bar\alpha_t} \\ &= \frac{\sqrt{\alpha_t}(1 - \bar\alpha_{t-1})\vx_t + (1 - \alpha_t) \frac{\vx_t - \sqrt{1 - \bar\alpha_t}\vec\epsilon_0}{\sqrt{\alpha_t}}}{1 - \bar\alpha_t} \\ &= \frac{\sqrt{\alpha_t}(1 - \bar\alpha_{t-1})\vx_t}{1 - \bar\alpha_t} + \frac{(1 - \alpha_t)\vx_t}{(1 - \bar\alpha_t)\sqrt{\alpha_t}} - \frac{(1 - \alpha_t)\sqrt{1 - \bar\alpha_t}\vec\epsilon_0}{(1 - \bar\alpha_t)\sqrt{\alpha_t}} \\ &= \left(\frac{\sqrt{\alpha_t}(1 - \bar\alpha_{t-1})}{1 - \bar\alpha_t} + \frac{(1 - \alpha_t)}{(1 - \bar\alpha_t)\sqrt{\alpha_t}}\right)\vx_t - \frac{(1 - \alpha_t)\sqrt{1 - \bar\alpha_t}}{(1 - \bar\alpha_t)\sqrt{\alpha_t}} \vec\epsilon_0\\ &= \left(\frac{\alpha_t(1 - \bar\alpha_{t-1})}{(1 - \bar\alpha_t)\sqrt{\alpha_t}} + \frac{(1 - \alpha_t)}{(1 - \bar\alpha_t)\sqrt{\alpha_t}}\right)\vx_t - \frac{1 - \alpha_t}{\sqrt{1 - \bar\alpha_t}\sqrt{\alpha_t}} \vec\epsilon_0\\ &= \frac{1 - \bar\alpha_t}{(1 - \bar\alpha_t)\sqrt{\alpha_t}}\vx_t - \frac{1 - \alpha_t}{\sqrt{1 - \bar\alpha_t}\sqrt{\alpha_t}} \vec\epsilon_0\\ &= \frac{1}{\sqrt{\alpha_t}}\vx_t - \frac{1 - \alpha_t}{\sqrt{1 - \bar\alpha_t}\sqrt{\alpha_t}} \vec\epsilon_0\\ \end{aligned} \] 因此,另一種\(\vec\mu_\vtheta(\vx_t, t)\)的等效設計是: \[ \vec\mu_\vtheta(\vx_t, t) = \frac{1}{\sqrt{\alpha_t}} \vx_t - \frac{1-\alpha_t}{\sqrt{1 - \bar\alpha_t}\sqrt{\alpha_t}}\hat{\vec\epsilon}_\vtheta(\vx_t, t) \] 對應的,優化問題就變爲: \[ \begin{aligned} &\argmin \vtheta D_\text{KL} (q(\vx_{t-1}|\vx_t, \vx_0)\Vert p_\vtheta(\vx_{t-1}|\vx_t)) \\ =~&\argmin\vtheta D_\text{KL}(\mathcal N(\vx_{t-1};\vec\mu_q, \mat\Sigma_q(t))\Vert \mathcal N(\vx_{t-1};\vec \mu_\vtheta, \mat\Sigma_q(t)))\\ =~&\argmin\vtheta \frac{1}{2\sigma_q^2(t)}\left[\left\Vert \frac{1}{\sqrt{\alpha_t}}\vx_t - \frac{1 - \alpha_t}{\sqrt{1 - \bar\alpha_t}\sqrt{\alpha_t}} \hat{\vec\epsilon}_\vtheta(\vx_t, t) - \frac{1}{\sqrt{\alpha_t}}\vx_t + \frac{1 - \alpha_t}{\sqrt{1 - \bar\alpha_t}\sqrt{\alpha_t}}\vec\epsilon_0\right\Vert_2^2\right] \\ =~&\argmin\vtheta \frac{1}{2\sigma_q^2(t)}\left[\left\Vert \frac{1 - \alpha_t}{\sqrt{1 - \bar\alpha_t}\sqrt\alpha_t}\vec\epsilon_0 - \frac{1 - \alpha_t}{\sqrt{1 - \bar\alpha_t}\sqrt{\alpha_t}}\hat{\vec\epsilon}_\vtheta(\vx_t, t)\right\Vert_2^2\right] \\ =~&\argmin\vtheta \frac{1}{2\sigma_q^2(t)}\left[\left\Vert \frac{1 - \alpha_t}{\sqrt{1 - \bar\alpha_t}\sqrt\alpha_t}(\vec\epsilon_0 - \hat{\vec\epsilon}_\vtheta(\vx_t, t))\right\Vert_2^2\right] \\ =~&\argmin\vtheta \frac{1}{2\sigma_q^2(t)}\frac{(1 - \alpha_t)^2}{(1 - \bar\alpha_t)\alpha_t}\left[\left\Vert \vec\epsilon_0 - \hat{\vec\epsilon}_\vtheta(\vx_t, t)\right\Vert_2^2\right] \\ \end{aligned} \] 式子中\(\hat{\vec\epsilon}_\vtheta\)是用於預測噪聲\(\vec\epsilon_0\sim\mathcal N(\vec\epsilon; \vec 0, \mI)\),從而將\(\vec x_t\)恢復為\(\vec x_0\)的模型。
這樣,我們看到理論上預測噪聲和預測原始圖像在理論上是等價的。但是許多工作表明實際上預測噪聲效果更好。
第三種等價形式的推導要用到特威迪公式(Tweedies’s Formula)。對於高斯變量\(\vec z\sim \mathcal N(\vec z; \vec \mu_z, \mat \Sigma_z)\),特威迪公式表明: \[ \E[\vec\mu_{\vec z}|\vec z] = \vec z + \Sigma_z\nabla_{\vec z}\log p(\vec z) \] 已知\(q(\vec x_t|\vec x_0)=\mathcal N(\vec x_t; \sqrt{\bar\alpha_t} \vec x_0, (1 - \bar\alpha_t)\mI)\),那麽,根據特威迪公式有 \[ \E[\vec\mu_{\vx_t}|\vx_t] = \vx_t + (1 - \bar\alpha_t)\nabla_{\vx_t}\log p(\vx_t) \] 后面为了方便,将\(\nabla_{\vx_t} \log p(\vx_t)\)简写为\(\nabla \log p(\vx_t)\). 已知\(\vec\mu_{\vx_t}=\sqrt{\bar\alpha_t} \vx_0\),因此 \[ \begin{aligned} \sqrt{\bar\alpha_t}\vx_0 = \vx_t + (1 - \bar\alpha_t)\nabla\log p(\vx_t)\\ \therefore \vx_0 = \frac{\vx_t + (1 - \bar\alpha_t)\nabla\log p(\vx_t)}{\sqrt{\bar\alpha_t}} \end{aligned} \] 再次将\(\vx_0\)代入ground-truth的去噪过程公式 9,得到: \[ \begin{aligned} \vec\mu_q(\vx_t, \vx_0) &= \frac{\sqrt{\alpha_t}(1 - \bar\alpha_{t-1})\vx_t + \sqrt{\bar\alpha_{t-1}}(1 - \alpha_t)\vx_0}{1 - \bar\alpha_t}\\ &= \frac{\sqrt{\alpha_t}(1 - \bar\alpha_{t-1})\vx_t + \sqrt{\bar\alpha_{t-1}}(1 - \alpha_t)\frac{\vx_t + (1 - \bar\alpha_t)\nabla\log p(\vx_t)}{\sqrt{\bar\alpha_t}}}{1 - \bar\alpha_t}\\ &= \frac{\sqrt{\alpha_t}(1 - \bar\alpha_{t-1})\vx_t +(1 - \alpha_t)\frac{\vx_t + (1 - \bar\alpha_t)\nabla\log p(\vx_t)}{\sqrt{\alpha_t}}}{1 - \bar\alpha_t}\\ &= \frac{\sqrt{\alpha_t}(1 - \bar\alpha_{t-1})\vx_t }{1 - \bar\alpha_t} + \frac{(1 - \alpha_t)\vx_t}{(1 - \bar\alpha_t)\sqrt{\alpha_t}} + \frac{(1 - \alpha_t)(1 - \bar\alpha_t)\nabla\log p(\vx_t)}{(1 - \bar\alpha_t)\sqrt{\alpha_t}}\\ &= \left( \frac{\sqrt{\alpha_t}(1 - \bar\alpha_{t-1}) }{1 - \bar\alpha_t} + \frac{(1 - \alpha_t)}{(1 - \bar\alpha_t)\sqrt{\alpha_t}} \right)\vx_t + \frac{1 - \alpha_t}{\sqrt{\alpha_t}}\nabla\log p(\vx_t)\\ &= \left( \frac{\alpha_t(1 - \bar\alpha_{t-1}) }{(1 - \bar\alpha_t)\sqrt{\alpha_t}} + \frac{(1 - \alpha_t)}{(1 - \bar\alpha_t)\sqrt{\alpha_t}} \right)\vx_t + \frac{1 - \alpha_t}{\sqrt{\alpha_t}}\nabla\log p(\vx_t)\\ &= \frac{1 - \bar\alpha_t}{(1 - \bar\alpha_t)\sqrt{\alpha_t}} \vx_t + \frac{1 - \alpha_t}{\sqrt{\alpha_t}}\nabla\log p(\vx_t)\\ &= \frac{1}{\sqrt{\alpha_t}} \vx_t + \frac{1 - \alpha_t}{\sqrt{\alpha_t}}\nabla\log p(\vx_t) \\ \end{aligned} \] 再一次,类似的将\(\vec\mu_\vtheta(\vx_t, t)\)设计为: \[ \vec\mu_\vtheta(\vx_t, t)=\frac{1}{\sqrt{\alpha_t}} \vx_t + \frac{1 - \alpha_t}{\sqrt{\alpha_t}} \vec s_\vtheta(\vx_t, t) \] 对应的优化过程变成: \[ \begin{aligned} &\argmin \vtheta D_\text{KL} (q(\vx_{t-1}|\vx_t, \vx_0)\Vert p_\vtheta(\vx_{t-1}\vert \vx_t))\\ =~& \argmin\vtheta D_\text{KL}(\mathcal N(\vx_{t-1}; \vec \mu_q, \mSigma_q(t))\Vert \mathcal N(\vx_{t-1};\vec\mu_\vtheta, \mSigma_a(t))) \\ =~& \argmin\vtheta \frac{1}{2\sigma_q^2(t)} \left[\left\Vert \frac{1}{\sqrt{\alpha_t}} \vx_t + \frac{1 - \alpha_t}{\sqrt{\alpha_t}} \vec s_\vtheta(\vx_t, t) - \frac{1}{\sqrt{\alpha_t}} \vx_t - \frac{1 - \alpha_t}{\sqrt{\alpha_t}}\nabla\log p(\vx_t) \right\Vert_2^2\right] \\ =~& \argmin\vtheta \frac{1}{2\sigma_q^2(t)} \left[\left\Vert \frac{1 - \alpha_t}{\sqrt{\alpha_t}} \left(\vec s_\vtheta(\vx_t, t) - \nabla\log p(\vx_t) \right) \right\Vert_2^2\right] \\ =~& \argmin\vtheta \frac{1}{2\sigma_q^2(t)} \frac{(1 - \alpha_t)^2}{\alpha_t} \left[\Vert \vec s_\vtheta(\vx_t, t) - \nabla\log p(\vx_t) \Vert_2^2\right] \end{aligned} \tag{13}\] 式子中\(\vec s_\vtheta(\vx_t, t)\)是一个预测分数函数(score function)\(\nabla \log p(\vx_t)\)的模型,本質上是在時間步\(t\)對於\(\vx_t\)的梯度。
分數函數\(\nabla \log p(\vx_t)\)與噪聲\(\vec\epsilon_0\)顯然很相似。不難看出 \[ \begin{aligned} & \vx_0 = \frac{\vx_t + (1 - \bar\alpha_t)\nabla\log p(\vx_t)}{\sqrt{\bar\alpha_t}} = \frac{\vx_t - \sqrt{1 - \bar\alpha_t}\vec\epsilon_0}{\sqrt{\bar\alpha_t}}\\ & \therefore \nabla \log p(\vx_t) = -\frac{1}{\sqrt{1 - \bar\alpha_t}}\vec\epsilon_0 \end{aligned} \] 結果顯示分數函數和噪聲的區別在於一個由\(t\)決定的係數。直覺上,噪聲是施加于原圖像使其變得 隨機的過程,而我們證明了分數函數通過建模反方向的噪聲來恢復圖像。
總的來看,我們得到了三種等價的優化目標:直接預測\(\vx_0\), 預測噪聲\(\vec \epsilon_0\),預測分數\(\nabla \log p(\vx_t)\).
9 基於分數的生成式模型
我們展示了VDM可以通過優化\(\vec s_\vtheta(\vx_t, t)\),預測\(\nabla \log p(\vx_t)\)來實現。然而這個基於特威迪公式的推導沒有充分展示出設計背後的信息。分數函數(score function)到底是什麼呢。為什麼它值得我們去建模呢。幸運的是我們可以學習另一類生成式模型,基於分數的生成模型(Score-based Generative Models),了解分數函數的意義。我們會看到VDM可以解釋為一種基於分數的生成模型。
在此之前我們先大概了解一下基於能量的模型(energy-based model)。首先,任意概率分佈可以重寫為這樣的形式: \[ p_\vtheta(x) = \frac{1}{Z_\vtheta}e^{-f_\vtheta(\vx)}, \] 其中\(f_\vtheta(\vx)\)就是能量函數。\(Z_\vtheta\)是用於使得\(\int p_\vtheta(\vx)d\vx=1\)成立的係數。\(Z_\vtheta=\int e^{-f_\vtheta(\vx)}d\vx\)並不一定容易計算。如果\(f_\vtheta(\vx)\)很複雜,\(Z_\vtheta\)就容易變得不可解。
一種避免建模\(Z_\vtheta\)的方法是使用\(\vec s_\vtheta(\vx)\)來學習分數函數\(\nabla \log p(\vx)\)。注意到, \[ \begin{aligned} \nabla_\vx\log p_\vtheta(\vx) &= \nabla_\vx \log(\frac{1}{Z_\vtheta}e^{-f_\vtheta(\vx)}) \\ &= \nabla_\vx\log\frac{1}{Z_\vtheta} + \nabla_\vx \log e^{-f_\vtheta(\vx)}\\ &=-\nabla_\vx f_\vtheta(\vx) \\ &\approx \vec s_\vtheta(\vx) \end{aligned} \]
因此\(\nabla_\vx \log p_\vtheta(\vx)\)可以用一個神經網絡近似。分數函數可以通過最小化費雪散度(Fisher Divergence)學習: \[ \mathbb E_{p(x)}\left[ \Vert \vec s_\vtheta(\vx) - \nabla \log p(\vx) \Vert_2^2 \right] \tag{14}\] 分數函數表示的是樣本\(\vx\)在數據空間中往哪個方向移動能最大化其對數似然。
一旦學到了這樣的分數函數,我們就可以用如下的過程來生成樣本: \[ \vx_{i+1}\leftarrow \vx_i + c \nabla \log p(\vx_i) + \sqrt{2c}\vec\epsilon, ~ i=0,1,\dots,K, \] 其中\(\vx_0\)是空間中隨機採樣的一點,噪聲\(\vec\epsilon\sim\mathcal N(\vec\epsilon;\mZero,\mI)\)的作用是防止採樣總是收斂於同一模式。這個採樣過程被稱為朗之萬動力學。
目標函數公式 14需要我們得到真實的分數函數,但在建模複雜分佈(比如真實圖像)的時候是做不到的。score matching技術允許我們在不知道真實分佈的情況下最小化費雪散度。
這就是基於分數的生成式模型的原理。但是原始的基於分數的生成模型有幾個問題:
- 如果數據是高維空間中的低維流形時,分數函數不是良定義的。如果一個點落在低維流形外,這個點的概率為0,對數函數在此無定義。而自然圖像就被認為是一種高維空間中的低維流形。
- 學習到的分數函數在數據的低密度區域可能不準確。因為我們訓練公式 14時,對於見得少或者沒見過的數據,模型收不到很多監督信號。而採樣卻是從空間中的隨機點開始的,不準確的分數函數將導致採樣結果落在非最優點
- 朗之萬動力學採樣不支持對混合分佈的採樣。例如對於 \[ p(\vx) = c_1 p_1(\vx) + c_2 p_2(\vx) \] 從特定位置初始化的採樣點,可能會以均等的機會落到兩個分佈中,即使\(c_1 != c_2\).
而以上幾個缺點可以同時用VDM的方法解決——往數據裡加不同大小的噪聲:
- 高斯噪聲的加入將使得空間中每一點的概率都不為0.
- 高斯噪聲的增大使得空間中每一點在訓練中被採樣到的幾率變得更加均勻。流形的低密度區能得到更好的訓練。
- 逐步加強的高斯噪聲能形成一種“中間態”的分佈,允許我們的採樣能夠遵循分佈的混合係數。具體的,我們可以定義不同時間步下不同的噪聲等級\(\left\{\sigma_t\right\}_{t=1}^T\),並定義每個時間步\(t\)的數據分佈 \[ p_{\sigma_t}(\vx_t) = \int p(\vx) \mathcal N(\vx_t; \vx, \sigma_t^2\mI) d\vx \]
神經網絡\(\vec s_\vtheta(\vx, t)\)將學習不同時間步下的分數函數: \[ \argmin \vtheta \sum_{t=1}^T \lambda(t)\mathbb E_{p_{\sigma_t}(\vx_t)} \left[ \Vert \vec s_\vtheta(\vx, t) - \nabla \log p_{\sigma_t}(\vx_t) \Vert_2^2 \right], \tag{15}\] 其中\(\lambda(t)\geq 0\)是權重係數。在噪聲大的時候,上面的目標函數使模型能夠學習不同分佈模式的比例;在噪聲小的時候,模式逐漸分離,分數函數更精準地學到每個模式的細節。在採樣時,對於每個\(t=T, T-1, \dots, 2, 1\),我們先從高噪音模式開始,然後逐漸降低噪音,直到樣本收斂於某個具體模式。
注意到公式 15和公式 13的形式一致。至此,我們建立起了VDM和基於分數的生成模型之間的聯繫。
從基於分數的生成模型的視角出發,我們還可以發現當MHVAE的時間步數\(T\rightarrow \infty\)時,相當於將離散的隨機過程變為連續的隨機過程,這時可以用SDE(stochastic differential equation)來描述這個過程,而採樣可以通過求逆向的SDE來完成。
10 引導信息
前文只討論了\(p(x)\). 但我們有時會希望用引導信息控制生成的圖像,需要條件概率\(p(\vx|y)\). 這是圖像超分辨率、文生圖模型的基石。
一種自然的方式是在每一時間步加上條件信息,將公式 \[ p(\vx_{0:T}) = p(\vx_T)\prod_{t=1}^T p_\vtheta(\vx_{t-1}|\vx_t) \] 轉變為 \[ p(\vx_{0:T}|y) = p(\vx_T)\prod_{t=1}^T p_\vtheta(\vx_{t-1}|\vx_t, y) \] 然後我們可以預測\(\hat \vx_\vtheta(\vx_t, t, y)\approx \vx_0\),或者\(\hat{\vec\epsilon}_\vtheta(\vx_t, t, y)\approx \vec\epsilon_0\),或者\(\vec s_\vtheta(\vx_t, t, y)\approx \nabla \log p(\vx_t|y)\),從而構造一個VDM。
這種方法可能的缺點是,模型可能忽略或者不充分重視條件信息。
為了解決這個問題,可以使用一些引導技巧。兩種常見的引導技巧包括分類器引導和免分類器引導。
10.1 分類器引導
讓我們從基於分數的生成器的視角來看,我們的目標是學習\(\nabla \log p(\vx_t|y)\). 根據貝葉斯公式, \[ \begin{aligned} \nabla \log p(\vx_t|y) &= \nabla \log \left(\frac{p(\vx_t)p(y|\vx_t)}{p(y)} \right)\\ &= \nabla \log p(\vx_t) + \nabla \log p(y|\vx_t) - \nabla \log p(y) \\ &= \underbrace{\nabla \log p(\vx_t)}_{無條件的分數}+ \underbrace{\nabla \log p(y|\vx_t)}_{對抗梯度} \end{aligned} \tag{16}\] 前面說過\(\nabla\)是\(\nabla_{\vx_t}\)的簡寫,所以\(\nabla \log p(y)=0\).
根據推導的結果,一個由類別為條件的生成模型可以分解為一個無條件生成模型,搭配一個分類器模型。其中分類器用於提供“對抗梯度”,將採樣過程引導到對應類別。
為了更精細地控制採樣,我們可以加一個係數控制對抗梯度的強度,像這樣: \[ \nabla \log p(\vx_t|y) = \nabla \log p(\vx_t) + \gamma \nabla \log p(y|\vx_t) \tag{17}\]
這個引導方法的缺點在於它需要一個額外的分類器。一般的分類器還不行,這個分類器還得適應帶噪聲的輸入。
10.2 無需分類器的引導技巧
為了推導出免分類器的引導方法,可以重新整理公式 16,得到 \[ \nabla \log p(y|\vx_t) = \nabla \log p(\vx_t | y) - \nabla \log p(\vx_t), \] 將其代入公式 17,得到 \[ \begin{aligned} \nabla \log p(\vx_t|y) &= \nabla \log p(\vx_t) + \gamma(\nabla \log p(\vx_t|y) - \nabla \log p(\vx_t)) \\ &= \nabla \log p(\vx_t) + \gamma \nabla \log p(\vx_t| y) - \gamma \nabla \log p(\vx_t)\\ &= \underbrace{\gamma\nabla \log p(\vx_t|y)}_{條件分數} + \underbrace{(1 - \gamma)\nabla\log p(\vx_t)}_{無條件分數} \end{aligned} \] 同樣可以獲得一個使用係數控制梯度方向的方法。實踐中,可以用同一個模型同時學習無條件生成和條件生成,然後在推理時,用上面的式子控制梯度方向。
11 總結
- VAE模型是MHVAE的一個特例
- 介紹了VDM的三種等效優化目標:
- 預測原始圖像
- 預測噪聲
- 預測分數函數
- VDM有以下缺點,值得進一步思考
- VDM沒有遵循或者模擬人類通常生成數據的方式
- VDM不產生可解釋的隱變量。而VAE則有可能產生一些有意義的隱變量。
- 隱變量和原始數據的尺寸被限定為相同的,因此無法學到壓縮的、有意義的隱變量。
- 採樣過程比較昂貴,需要採樣多步。