大家好,生成式大模型的熱度,從今年3月開始已經燃了一個多季度了。在這個季度中,相信大家肯定看過很多AI產生的有趣內容,比如著名的抓捕川普現場與監獄風云 [現在來看不僅是畫得像而已],AI換聲孫燕姿等。這背后都用到了一個強大的模型:Diffusion Model。所以,在這個系列中,我們將從原理到源碼,從基石DDPM到DALLE2,Imagen與Stable Diffusion,通過詳細的圖例和解說,和大家一起來了解擴散模型的奧秘。同時,也會穿插對經典的GAN,VAE等模型的解讀,敬請期待~
本篇將和大家一起解讀擴散模型的基石:DDPM(Denoising Diffusion Probalistic Models)。擴散模型的研究并不始于DDPM,但DDPM的成功對擴散模型的發展起到至關重要的作用。在這個系列里我們也會看到,后續一連串效果驚艷的模型,都是在DDPM的框架上迭代改進而來。所以,我把DDPM放在這個系列的第一篇進行講解。
初讀DDPM論文的朋友,可能有以下兩個痛點:
論文花極大篇幅講數學推導,可是我看不懂。
論文沒有給出模型架構圖和詳細的訓練解說,而這是我最關心的部分。
針對這些痛點,DDPM系列將會出如下三篇文章:
DDPM(模型架構篇):也就是本篇文章。在閱讀源碼的基礎上,本文繪制了詳細的DDPM模型架構圖,同時附上關于模型運作流程的詳細解說。本文不涉及數學知識,直觀幫助大家了解DDPM怎么用,為什么好用。
DDPM(人人都能看懂的數學推理篇):DDPM的數學推理可能是很多讀者頭疼的部分。我嘗試跳出原始論文的推導順序和思路,從更符合大家思維模式的角度入手,把整個推理流程串成一條完整的邏輯線。同樣,我也會配上大量的圖例,方便大家理解數學公式。如果你不擅長數學推導,這篇文章可以幫助你從直覺上了解DDPM的數學有效性;如果你更關注推導細節,這篇文章中也有詳細的推導中間步驟。
DDPM(源碼解讀篇):在前兩篇的基礎上,我們將配合模型架構圖,一起閱讀DDPM源碼,并實操跑一次,觀測訓練過程里的中間結果。
本文目錄如下:
一、DDPM在做一件什么事
二、DDPM訓練流程:Diffusion/Denoise Process
三、DDPM的training與sampling
四、圖解DDPM核心模型架構:UNet
五、文生圖模型的一般公式
六、參考
最后,Megatron源碼解讀的系列沒有停更!只是兩個系列穿插來寫。
一、DDPM在做一件什么事
假設你想做一個以文生圖的模型,你的目的是給一段文字,再隨便給一張圖(比如一張噪聲),這個模型能幫你產出符合文字描述的逼真圖片,例如:
文字描述就像是一個指引(guidance),幫助模型去產生更符合語義信息的圖片。但是,畢竟語義學習是復雜的。我們能不能先退一步,先讓模型擁有產生逼真圖片的能力?
比如說,你給模型喂一堆cyperpunk風格的圖片,讓模型學會cyberpunk風格的分布信息,然后喂給模型一個隨機噪音,就能讓模型產生一張逼真的cyberpunk照片。或者給模型喂一堆人臉圖片,讓模型產生一張逼真的人臉。同樣,我們也能選擇給訓練好的模型喂帶點信息的圖片,比如一張夾雜噪音的人臉,讓模型幫我們去噪。
具備了產出逼真圖片的能力,模型才可能在下一步中去學習語義信息(guidance),進一步產生符合人類意圖的圖片。而DDPM的本質作用,就是學習訓練數據的分布,產出盡可能符合訓練數據分布的真實圖片。所以,它也成為后續文生圖類擴散模型框架的基石。
二、DDPM訓練流程
理解DDPM的目的,及其對后續文生圖的模型的影響,現在我們可以更好來理解DDPM的訓練過程了。總體來說,DDPM的訓練過程分為兩步:
Diffusion Process (又被稱為Forward Process)
Denoise Process(又被稱為Reverse Process)
前面說過,DDPM的目的是要去學習訓練數據的分布,然后產出和訓練數據分布相似的圖片。那怎么“迫使”模型去學習呢?
一個簡單的想法是,我拿一張干凈的圖,每一步(timestep)都往上加一點噪音,然后在每一步里,我都讓模型去找到加噪前圖片的樣子,也就是讓模型學會去噪。這樣訓練完畢后,我再塞給模型一個純噪聲,它不就能一步步幫我還原出原始圖片的分布了嗎?
一步步加噪的過程,就被稱為Diffusion Process;一步步去噪的過程,就被稱為Denoise Process。我們來詳細看這兩步。
2.1 Diffusion Process
Diffusion Process的命名受到熱力學中分子擴散的啟發:分子從高濃度區域擴散至低濃度區域,直至整個系統處于平衡。加噪過程也是同理,每次往圖片上增加一些噪聲,直至圖片變為一個純噪聲為止。整個過程如下:
如圖所示,我們進行了1000步的加噪,每一步我們都往圖片上加入一個高斯分布的噪聲,直到圖片變為一個純高斯分布的噪聲。
我們記:
:總步數
:每一步產生的圖片。其中為原始圖片,為純高斯噪聲
:為每一步添加的高斯噪聲
:在條件下的概率分布。如果你覺得抽象,可以理解成已知,求
那么根據以上流程圖,我們有:
根據公式,為了知道,需要sample好多次噪聲,感覺不太方便,能不能更簡化一些呢?
重參數
我們知道隨著步數的增加,圖片中原始信息含量越少,噪聲越多,我們可以分別給原始圖片和噪聲一個權重來計算:
:一系列常數,類似于超參數,隨著的增加越來越小。
則此時的計算可以設計成:
現在,我們只需要sample一次噪聲,就可以直接得到了。
接下來,我們再深入一些,其實并不是我們直接設定的超參數,它是根據其它超參數推導而來,這個“其它超參數”指:
:一系列常數,是我們直接設定的超參數,隨著T的增加越來越大
則和的關系為:
這樣從原始加噪到加噪,再到加噪,使得轉換成的過程,就被稱為重參數(Reparameterization)。我們會在這個系列的下一篇(數學推導篇)中進一步探索這樣做的目的和可行性。在本篇中,大家只需要從直覺上理解它的作用方式即可。
2.2 Denoise Process
Denoise Process的過程與Diffusion Process剛好相反:給定,讓模型能把它還原到。在上文中我們曾用這個符號來表示加噪過程,這里我們用來表示去噪過程。由于加噪過程只是按照設定好的超參數進行前向加噪,本身不經過模型。但去噪過程是真正訓練并使用模型的過程。所以更進一步,我們用來表示去噪過程,其中表示模型參數,即:
:用來表示Diffusion Process
:用來表示Denoise Process。
講完符號表示,我們來具體看去噪模型做了什么事。如下圖所示,從第T個timestep開始,模型的輸入為與當前timestep 。模型中蘊含一個噪聲預測器(UNet),它會根據當前的輸入預測出噪聲,然后,將當前圖片減去預測出來的噪聲,就可以得到去噪后的圖片。重復這個過程,直到還原出原始圖片為止:
你可能想問:
為什么我們的輸入中要包含time_step?
為什么通過預測噪聲的方式,就能讓模型學得訓練數據的分布,進而產生逼真的圖片?
第二個問題的答案我們同樣放在下一篇(數學推理篇)中進行詳解。而對于第一個問題,由于模型每一步的去噪都用的是同一個模型,所以我們必須告訴模型,現在進行的是哪一步去噪。因此我們要引入timestep。timestep的表達方法類似于Transformer中的位置編碼(可以參考這篇文章),將一個常數轉換為一個向量,再和我們的輸入圖片進行相加。
注意到,UNet模型是DDPM的核心架構,我們將關于它的介紹放在本文的第四部分。
到這里為止,如果不考慮整個算法在數學上的有效性,我們已經能從直覺上理解擴散模型的運作流程了。那么,我們就可以對它的訓練和推理過程來做進一步總結了。
三、DDPM的Training與Sampling過程
3.1 DDPM Training
上圖給出了DDPM論文中對訓練步驟的概述,我們來詳細解讀它。
前面說過,DDPM模型訓練的目的,就是給定time_step和輸入圖片,結合這兩者去預測圖片中的噪聲。
我們知道,在重參數的表達下,第t個時刻的輸入圖片可以表示為:
也就是說,第t個時刻sample出的噪聲,就是我們的噪聲真值。而我們預測出來的噪聲為:
,其中為模型參數,表示預測出的噪聲和模型相關。那么易得出我們的loss為:
我們只需要最小化該loss即可。
由于不管對任何輸入數據,不管對它的任何一步,模型在每一步做的都是去預測一個來自高斯分布的噪聲。因此,整個訓練過程可以設置為:
從訓練數據中,抽樣出一條(即)
隨機抽樣出一個timestep。(即)
隨機抽樣出一個噪聲(即)
計算:
計算梯度,更新模型,重復上面過程,直至收斂
上面演示的是單條數據計算loss的過程,當然,整個過程也可以在batch范圍內做,batch中單條數據計算loss的方法不變。
3.2 DDPM的Sampling
當DDPM訓練好之后,我們要怎么用它,怎么評估它的效果呢?
對于訓練好的模型,我們從最后一個時刻(T)開始,傳入一個純噪聲(或者是一張加了噪聲的圖片),逐步去噪。根據x_tx_{t-1}的關系(上圖的前半部分)。而圖中一項,則不是直接推導而來的,是我們為了增加推理中的隨機性,而額外增添的一項。可以類比于GPT中為了增加回答的多樣性,不是選擇概率最大的那個token,而是在topN中再引入方法進行隨機選擇。
關于和關系的詳細推導,我們也放在數學推理篇中做解釋。
通過上述方式產生的,我們可以計算它和真實圖片分布之間的相似度(FID score:Frechet Inception Distance score)來評估圖片的逼真性。在DDPM論文中,還做了一些有趣的實驗,例如通過“插值(interpolation)”方法,先對兩張任意的真實圖片做Diffusion過程,然后分別給它們的diffusion結果附不同的權重(),將兩者diffusion結果加權相加后,再做Denoise流程,就可以得到一張很有意思的"混合人臉":
到目前為止,我們已經把整個DDPM的核心運作方法講完了。接下來,我們來看DDPM用于預測噪聲的核心模型:UNet,到底長成什么樣。我在學習DDPM的過程中,在網上幾乎找不到關于DDPM UNet的詳細模型解說,或者一張清晰的架構圖,這給我在源碼閱讀過程中增加了難度。所以在讀完源碼并進行實操訓練后,我干脆自己畫一張出來,也借此幫助自己更好理解DDPM。
四、DDPM中的Unet架構
UNet模型最早提出時,是用于解決醫療影響診斷問題的。總體上說,它分成兩個部分:
Encoder
Decoder
在Encoder部分中,UNet模型會逐步壓縮圖片的大小;在Decoder部分中,則會逐步還原圖片的大小。同時在Encoder和Deocder間,還會使用“殘差連接”,確保Decoder部分在推理和還原圖片信息時,不會丟失掉之前步驟的信息。整體過程示意圖如下,因為壓縮再放大的過程形似"U"字,因此被稱為UNet:
那么DDPM中的UNet,到底長什么樣子呢?我們假設輸入為一張32323大小的圖片,來看一下DDPM UNet運作的完整流程:
如圖,左半邊為UNet的Encoder部分,右半邊為UNet的Deocder部分,最下面為MiddleBlock。我們以從上往下數第二行來分析UNet的運作流程。
在Encoder部分的第二行,輸入是一個16*16*64的圖片,它是由上一行最右側32*32*64的圖片壓縮而來(DownSample)。對于這張16*16*64大小的圖片,在引入time_embedding后,讓它們一起過一層DownBlock,得到大小為16*16*128的圖片。再引入time_embedding,再過一次DownBlock,得到大小同樣為16*16*128的圖片。對該圖片做DowSample,就可以得到第三層的輸入,也就是大小為8*8*128的圖片。由此不難知道,同層間只做channel上的變化,不同層間做圖片的壓縮處理。至于每一層channel怎么變,層間size如何調整,就取決于實際訓練中對模型的設定了。Decoder層也是同理。其余的信息可以參見圖片,這里不再贅述。
我們再詳細來看右下角箭頭所表示的那些模型部分,具體架構長什么樣:
4.1 DownBlock和UpBlock
如果你曾在學習DDPM的過程中,困惑time_embedding要如何與圖片相加,Attention要在哪里做,那么這張圖可以幫助你解答這些困惑。TimeEmbedding層采用和Transformer一致的三角函數位置編碼,將常數轉變為向量。Attention層則是沿著channel維度將圖片拆分為token,做完attention后再重新組裝成圖片(注意Attention層不是必須的,是可選的,可以根據需要選擇要不要上attention)。
你可能想問:一定要沿著channel方向拆分圖片為token嗎?我可以選擇VIT那樣以patch維度拆分token,節省計算量嗎?當然沒問題,你可以做各種實驗,這只是提供DDPM對圖片做attention的一種方法。
4.2 DownSample和UpSample
這個模塊很簡單,就是壓縮(Conv)和放大(ConvT)圖片的過程。對ConvT原理不熟悉的朋友們,可以參考這篇文章(https://blog.csdn.net/sinat_29957455/article/details/85558870)。
4.3 MiddleBlock
和DownBlock與UpBlock的過程相似,不再贅述。
到這一步,我們就把DDPM的模型核心給講完啦。在第三篇源碼解讀中,我們會結合這些架構圖,來一起閱讀DDPM training和sampling代碼。
五、文生圖模型的一般公式
講完了DDPM,讓我們再回到開頭,看看最初我們想訓練的那個“以文生圖”模型吧!
當我們擁有了能夠產生逼真圖片的模型后,我們現在能進一步用文字信息去引導它產生符合我們意圖的模型了。通常來說,文生圖模型遵循以下公式(圖片來自李宏毅老師課堂PPT):
Text Encoder:一個能對輸入文字做語義解析的Encoder,一般是一個預訓練好的模型。在實際應用中,CLIP模型由于在訓練過程中采用了圖像和文字的對比學習,使得學得的文字特征對圖像更加具有魯棒性,因此它的text encoder常被直接用來做文生圖模型的text encoder(比如DALLE2)
Generation Model: 輸入為文字token和圖片噪聲,輸出為一個關于圖片的壓縮產物(latent space)。這里通常指的就是擴散模型,采用文字作為引導(guidance)的擴散模型原理,我們將在這個系列的后文中出講解。
Decoder:用圖片的中間產物作為輸入,產出最終的圖片。Decoder的選擇也有很多,同樣也能用一個擴散模型作為Decoder。
5.1 DALLE2
DALLE2就套用了這個公式。它曾嘗試用Autoregressive和Diffusion分別來做Generation Model,但實驗發現Diffusion的效果更好。所以最后它的2和3都是一個Diffusion Model。
Stable Diffusion
大名鼎鼎Stable Diffsuion也能按這個公式進行拆解。
5.3 Imagen
Google的Imagen,小圖生大圖,遵循的也是這個公式。
按這個套路一看,是不是文生圖模型,就不難理解了呢?我們在這個系列后續文章中,也會對這些效果驚艷的模型,進行解讀。
-
源碼
+關注
關注
8文章
633瀏覽量
29138 -
模型
+關注
關注
1文章
3171瀏覽量
48711
原文標題:深入淺出擴散模型系列:基石DDPM(模型架構篇),最詳細的DDPM架構圖解
文章出處:【微信號:GiantPandaCV,微信公眾號:GiantPandaCV】歡迎添加關注!文章轉載請注明出處。
發布評論請先 登錄
相關推薦
評論