大家好,這就為您獻上不知鴿了多久的Flash Attention V2原理解讀。
在V1的講解中,我們通過詳細的圖解和公式推導,一起學習了Flash Attention的整體運作流程。如果大家理解了V1的這塊內容,就會發現V2的原理其實非常簡單:無非是將V1計算邏輯中的內外循環相互交換,以此減少在shared memory上的讀寫次數,實現進一步提速。那當你交換了循環位置之后,在cuda層面就可以配套做一些并行計算優化。這就是V2的整體內容。
總結起來一句話:“交換了循環位置“,雖是短短一句話,卻蘊含著深深的人生哲理:只要基座選得好,回回都有迭代點,年年勇破okr!
回歸正題,本文也分兩個部分進行講解:原理與cuda層面的并行計算。
在閱讀本文前,需要先閱讀V1的講解,本文會沿用V1的表達符號及推演思路。
一、Flash Attention V2整體運作流程
1.1 V1的運作流程
我們先快速回顧一下V1的運作流程:以K,V為外循環,Q為內循環。
,遍歷:
,遍歷:
為了幫助大家更好理解v1中數據塊的流轉過程,在圖中我們畫了6塊O。但實際上最終只有三塊O:。
以為例,它可理解成是由經過某些處理后匯總而來的。進一步說,
我們在外循環j = 0時,先遍歷一次所有的i,在這個階段中我們產出,并將它和一些別的重要數據寫回HBM中
接下來我們進行第二次外循環,即j=1,在這個階段中我們產出。同時我們把和那些重要的數據從HBM傳入shared memory中,然后從shared memory中讀取它們,以配合產出最終的
(關于如何得到的細節我們在V1講解中詳細推導過,這里不再贅述)
在這個過程中,你是不是隱隱覺得有些別扭:
其實都和有關系,那我為什么不以Q為外循環,以KV為內循環做遍歷呢?這樣我不就能避免往shared memory上讀寫中間結果,從而一次性把乃至最終的給算出來?
同時,softmax這個操作也是在row維度上的,所以我固定Q循環KV的方式,更天然符合softmax的特性。
1.2 V2的運作流程
基于1.1中的思想,我們在V2中將原本的內外循環置換了位置(示意圖就不畫了,基本可以對比V1示意圖想象出來)。我們直接來看V2的偽代碼(如果對以下偽代碼符號表示或解讀有疑惑的朋友,最好先看一下V1的講解)。
(1)V2 FWD
現在,想象自己固定住了一塊Q(i),依此循環K和V的分塊(j),在這個想象下我們來解讀這份FWD為代碼。
第8行,計算分塊
第9行:
表示截止到當前分塊(包含當前分塊)為止的rowmax
表示使用當前每行最大值計算歸一化前的(我們在V1中說過,不帶波浪號的P表示(s-rowmax)/rowsum的結果,帶波浪號表示(s-rowmax))
表示截止到當前分塊(包含當前分塊為止)的rowsum
第10行:表示截止到當前分塊(包含當前分塊)為止計算出的O值。由第9和第10行知,當我們固定Q循環KV時,我們每個分塊都是用當前最新的rowmax和rowsum計算的,同理對應的也是用當前最新的rowmax和rowsum計算的。這樣當我們遍歷完所有的KV時,得到的就等于最終全局的結果。相關的證明我們在V1講解中給過,這里不再贅述,只額外提兩點:
可能在有些朋友下載的V2論文中,第十行這里O前面的因子項是,這個公式應該是錯誤的(大家動手推一下就可知,初次看到時讓我困擾了很久)。在作者個人主頁的論文鏈接中,這個typo已經被修正。
你可能已發現這個O的計算中缺少歸一化的一項,這一項其實放到了第12行做統一計算。這也是V2優化的一個點:盡量減少非矩陣的計算,因為在GPU中,非矩陣計算比矩陣計算慢16倍。
比起V1,V2中不用再存每一Q分塊對應的和了。但是在BWD的過程中,我們仍需要來做和的重計算,這樣才能用鏈式求導法則把dQ,dK,dV正常算出來。V2在這里用了一個很巧妙的方法,它只存一個東西(代碼13行,這樣又能進一步減少shared memory的讀寫):,這個等式中小寫的m和l可以理解成是全局的rowmax和rowsum。在接下來BWD的講解中,我們會來看到這一項的妙用。
(2)V2 BWD
一個建議:如果你在閱讀本節中覺得很困惑,一定記得先去看V1的BWD部分,有非常詳細的推導介紹。看完再來看本節就很順暢了。
我們觀察到,在V2 BWD中,內外循環的位置又換回來了,即還是KV外循環,Q內循環,這是為什么呢?
我們知道在BWD的過程中,我們主要是求(為了求它們還需要求中間結果,我們來總結一下這些梯度都需要沿著哪些方向AllReduce:
:沿著i方向做AllReduce,也就是需要每行的結果加總
:沿著i方向做AllReduce,也就是需要每行的結果加總
:沿著j方向做AllReduce,也就是需要每列的結果加總
:只與當前i,j相關
基于此,如果你還是保持Q外循環,KV外循環不變的話,這種操作其實是固定行,遍歷列的,那么在這些梯度中,只有從中受益了,K和V的梯度則進入了別扭的循環(也意味著要往shared memory上寫更多的中間結果);但如果你采用KV外循環,Q內循環,這樣K和V都受益,只有Q獨自別扭,因此是一種更好的選擇。(S和P的計算不受循環變動影響)。
前面說過,在BWD過程中讀寫我們要用全局的重新計算,計算公式如下:
但如此一來,我們就要從shared memory上同時讀取,似乎有點消耗讀寫。所以在V2中,我們只存儲,然后計算:
很容易發現這兩個計算是等價的,但V2的做法節省了讀寫量
好,現在我們就把V2相對于V1在計算原理上的改進介紹完了。接下來我們總結一下V2相對于V1所有的改進點。
二、V2相對V1的改進點
之所以把這塊內容放到“V2整體流程介紹”之后,是想讓大家在先理解V2是怎么做的基礎上,更好體會V2的優點。
總體來說,V2從以下三個方面做了改進:
置換內外循環位置,同時減少非矩陣的計算量。(這兩點我們在第一部分中已給出詳細說明)
優化Attention部分thread blocks的并行化計算,新增seq_len維度的并行,使SM的利用率盡量打滿。這其實也是內外循環置換這個總體思想配套的改進措施
優化thread blocks內部warp級別的工作模式,盡量減少warp間的通訊和讀取shared memory的次數。
第二和第三點都可以歸結為是cuda gemm層面的優化,我們馬上來細看這兩點。
三、V2中的thread blocks排布
//gridDiminV1 //params.b=batch_size,params.h=num_heads dim3grid(params.b,params.h); //gridDiminV2 constintnum_m_block=(params.seqlen_q+Kernel_traits::kBlockM-1)/Kernel_traits::kBlockM; dim3grid(num_m_block,params.b,params.h);
這段代碼整合自flash attention github下的cutlass實現,為了方便講解做了一點改寫。
這段代碼告訴我們:
在V1中,我們是按batch_size和num_heads來劃分block的,也就是說一共有batch_size * num_heads個block,每個block負責計算O矩陣的一部分
在V2中,我們是按batch_size,num_heads和num_m_block來劃分block的,其中num_m_block可理解成是沿著Q矩陣行方向做的切分。例如Q矩陣行方向長度為seqlen_q(其實就是我們熟悉的輸入序列長度seq_len,也就是圖例中的N),我們將其劃分成num_m_block份,每份長度為kBlockM(也就是每份維護kBlockM個token)。這樣就一共有batch_size * num_heads * num_m_block個block,每個block負責計算矩陣O的一部分。
為什么相比于V1,V2在劃分thread block時,要新增Q的seq_len維度上的劃分呢?
先說結論,這樣做的目的是盡量讓SM打滿。我們知道block是會被發去SM上執行的。以1塊A100 GPU為例,它有108個SM,如果此時我們的block數量比較大(例如論文中所說>=80時),我們就認為GPU的計算資源得到了很好的利用。現在回到我們的輸入數據上來,當batch_size和num_heads都比較大時,block也比較多,此時SM利用率比較高。但是如果我們的數據seq_len比較長,此時往往對應著較小的batch_size和num_heads,這是就會有SM在空轉了。而為了解決這個問題,我們就可以引入在Q的seq_len上的劃分。
看到這里你可能還是有點懵,沒關系,我們通過圖解的方式,來一起看看V1和V2上的thread block到底長什么樣。
3.1 V1 thread block
假設batch_size = 1,num_heads = 2,我們用不同的顏色來表示不同的head。
我們知道在Multihead Attention中,各個head是可以獨立進行計算的,在計算完畢后將結果拼接起來即可。所以我們將1個head劃分給1個block,這樣就能實現block間的并行計算,如此每個block只要在計算完畢后把結果寫入自己所維護的O的對應位置即可。
而每個block內,就能執行V1中的"KV外循環,Q內循環”的過程了,這個過程是由block的再下級warp level層面進行組織,thread實行計算的。這塊我們放在第四部分中講解。
3.2 V2 thread block
現在我們繼續假設batch_size = 1,num_heads = 2。
與V1不同的是,我們在Q的seq_len維度上也做了切分,將其分成四份,即num_m_block = 4。所以現在我們共有124 = 8個block在跑。這些block之間的運算也是獨立的,因為:
head的計算是獨立的,所以紅色block和藍色block互不干擾
采用Q做外循環,KV做內循環時,行與行之間的block是獨立的,因此不同行的block互相不干擾。
每個block從Q上加載對應位置的切塊,同時從KV上加載head0的切塊,計算出自己所維護的那部分O,然后寫入O的對應位置。
在這里你可能想問,為什么只對Q的seq_len做了切分,而不對KV的seq_len做切分呢?
在V2的cutlass實現中,確實也提供了對KV的seq_len做切分的方法。但除非你認為SM真得打不滿,否則盡量不要在KV維度上做切分,因為如此一來,不同的block之間是沒法獨立計算的(比如對于O的某一行,它的各個部分來自不同的block,為了得到全局的softmax結果,這些block的結果還需要匯總做一次計算)。
3.3 seq parallel不是V2特有
如果你看過V1的代碼,你會發現,其實在V1后期的版本中,也出現了seq維度的并行:
//V1seqparallel:csrc/flash_attn/src/fmha_fwd_launch_template.h dim3grid(launch_params.params.b,launch_params.params.h,launch_params.params.num_splits); //nums_splits計算方法 //Findthenumberofsplitsthatmaximizestheoccupancy.Forexample,ifwehave //batch*n_heads=48andwehave108SMs,having2splits(efficiency=0.89)is //betterthanhaving3splits(efficiency=0.67).However,wealsodon'twanttoomany //splitsasthatwouldincurmoreHBMreads/writes. //Sowefindthebestefficiency,thenfindthesmallestnumberofsplitsthatgets95% //ofthebestefficiency. //[2022-11-25]TD:Markthisas"inline"otherwiseweget"multipledefinition"error. inlineintnum_splits_heuristic_fwd(intbatch_nheads,intnum_SMs,intctas_per_sm,intmax_splits){ floatmax_efficiency=0.f; std::vectorefficiency; efficiency.reserve(max_splits); for(intnum_splits=1;num_splits<=?max_splits;?num_splits++)?{ ????????float?n_waves?=?float(batch_nheads?*?num_splits)?/?(num_SMs?*?ctas_per_sm); ????????float?eff?=?n_waves?/?ceil(n_waves); ????????//?printf("num_splits?=?%d,?eff?=?%f ",?num_splits,?eff); ????????if?(eff?>max_efficiency){max_efficiency=eff;} efficiency.push_back(eff); } for(intnum_splits=1;num_splits<=?max_splits;?num_splits++)?{ ????????if?(efficiency[num_splits?-?1]?>0.95*max_efficiency){ //printf("num_splitschosen=%d ",num_splits); returnnum_splits; } } return1; } .... //可以發現num_splits也是由Q的seq_len維度切分來的 launch_params.params.num_splits=num_splits_heuristic_fwd( launch_params.params.b*launch_params.params.h,dprops->multiProcessorCount, ctas_per_sm, /*max_splits=*/std::min(30,(launch_params.params.seqlen_q+M-1/M)) );
上圖代碼中的num_splits也是在由Q的seq_len維度切分來的。通過這段代碼,我猜想作者在V1后期引入seq_len維度切分的原因是:V1也需要解決seq_len過長時,batch_size和num_heads較小而造成SM打不滿的問題。
num_splits_heuristic_fwd這個函數的作用概括起來就是,我先提供一連串num_splits值的備選,然后由這個函數計算出每個備選值下SM的利用率。計算完之后,我先找到最高的利用率,然后再找出滿足利用率>=0.95 * max(利用率)的那個最小的num_split值,作為最終的選擇。
細心的你此時可能已經觀察到了,雖然V1也引進過seq parallel,但是它的grid組織形式時(batch_size, num_heads, num_m_blocks),但V2的組織形式是(num_m_blocks, batch_size, num_heads),這種順序調換的意義是什么呢?
直接說結論,這樣的調換是為了提升L2 cache hit rate。大家可以看下3.2中的圖(雖然block實際執行時不一定按照圖中的序號),對于同一列的block,它們讀的是KV的相同部分,因此同一列block在讀取數據時,有很大概率可以直接從L2 cache上讀到自己要的數據(別的block之前取過的)。
3.4 FWD和BWD過程中的thread block劃分
在3.1~3.3中,我們其實給出的是FWD過程中thread block的劃分方式,我們知道V2中FWD和BWD的內外循環不一致,所以對應來說,thread block的劃分也會有所不同,我們詳細來看:
在圖中:
worker表示thread block,不同的thread block用不同顏色表示
整個大方框表示輸出矩陣O
我們先看左圖,它表示FWD下thread block的結構。每一行都有一個worker,它表示O矩陣的每一行都是由一個thread block計算出來的(假設num_heads = 1),這就對應到我們3.1~3.3中說的劃分方式。那么白色的部分表示什么呢?我們知道如果采用的是casual attention,那么有一部分是會被mask掉的,所以這里用白色來表示。但這不意味著thread block不需要加載白色部分數據對應的KV塊,只是說在計算的過程中它們會因被mask掉而免于計算(論文中的casual mask一節有提過)。
我們再看右圖,它表示BWD下thread block的結構,每一列對應一個worker,這是因為BWD中我們是KV做外循環,Q做內循環,這種情況下dK, dV都是按行累加的,而dQ是按列累加的,少數服從多數,因此這里thread_block是按的列劃分的。
四、Warp級別并行
講完了thread block,我們就可以再下一級,看到warp level級別的并行了。左圖表示V1,右圖表示V2。不管是V1還是V2,在Ampere架構下,每個block內進一步被劃分為4個warp,在Hopper架構下則是8個warp。
在左圖(V1)中,每個warp都從shared memory上讀取相同的Q塊以及自己所負責計算的KV塊。在V1中,每個warp只是計算出了列方向上的結果,這些列方向上的結果必須匯總起來,才能得到最終O矩陣行方向上的對應結果。所以每個warp需要把自己算出來的中間結果寫到shared memory上,再由一個warp(例如warp1)進行統一的整合。所以各個warp間需要通訊、需要寫中間結果,這就影響了計算效率。
在左圖(V2)中,每個warp都從shared memory上讀取相同的KV塊以及自己所負責計算的Q塊。在V2中,行方向上的計算是完全獨立的,即每個warp把自己計算出的結果寫到O的對應位置即可,warp間不需要再做通訊,通過這種方式提升了計算效率。不過這種warp并行方式在V2的BWD過程中就有缺陷了:由于bwd中dK和dV是在行方向上的AllReduce,所以這種切分方式會導致warp間需要通訊。
針對V2 warp切分影響BWD這點,作者在論文中依然給出了“BWD過程相比V1也有提升”的結論,針對這點,我在github issue上找到了一條作者的回復(在“安裝報錯”組成的issue海洋里撈出的寶貴一條):
最關鍵的可能是第1和第2點,關于第1點,我想作者應該是說,之前需要反復讀取KV的數據,現在只用反復讀取Q的數據,因此從一定程度上節省了shared memory的讀寫次數。第2點理解起來有點復雜,個人覺得是將warp處理的tile劃分得更像方形。這樣做的好處是在做casual mask的時候可以方便寫代碼大塊丟掉被mask掉的tile(見論文casual masking部分),進一步加速計算。第3點是關于一些底層的優化,就不提了。
好!關于V2我們就介紹到這了,寫這篇文章的時候,我剛粗過了一遍triton的flash attention實現,以及掃了一下cutlass實現的入口。如果后續有時間,我會出一些源碼解讀的文章(從cuda gemm -> triton gemm -> triton flash attention,看,又給自己挖了一個坑)。如果出不了,那一定不是我鴿人,那肯定是我不會(沒錯,就是這樣)。
審核編輯:黃飛
-
gpu
+關注
關注
27文章
4591瀏覽量
128144 -
并行計算
+關注
關注
0文章
27瀏覽量
9403 -
大模型
+關注
關注
2文章
2136瀏覽量
1980
原文標題:圖解大模型計算加速系列:Flash Attention V2,從原理到并行計算
文章出處:【微信號:zenRRan,微信公眾號:深度學習自然語言處理】歡迎添加關注!文章轉載請注明出處。
發布評論請先 登錄
相關推薦
評論