精品国产人成在线_亚洲高清无码在线观看_国产在线视频国产永久2021_国产AV综合第一页一个的一区免费影院黑人_最近中文字幕MV高清在线视频

0
  • 聊天消息
  • 系統消息
  • 評論與回復
登錄后你可以
  • 下載海量資料
  • 學習在線課程
  • 觀看技術視頻
  • 寫文章/發帖/加入社區
會員中心
創作中心

完善資料讓更多小伙伴認識你,還能領取20積分哦,立即完善>

3天內不再提示

大模型系列:Flash Attention V2整體運作流程

深度學習自然語言處理 ? 來源:大猿搬磚簡記 ? 2024-02-21 11:38 ? 次閱讀

大家好,這就為您獻上不知鴿了多久的Flash Attention V2原理解讀。

在V1的講解中,我們通過詳細的圖解和公式推導,一起學習了Flash Attention的整體運作流程。如果大家理解了V1的這塊內容,就會發現V2的原理其實非常簡單:無非是將V1計算邏輯中的內外循環相互交換,以此減少在shared memory上的讀寫次數,實現進一步提速。那當你交換了循環位置之后,在cuda層面就可以配套做一些并行計算優化。這就是V2的整體內容。

總結起來一句話:“交換了循環位置“,雖是短短一句話,卻蘊含著深深的人生哲理:只要基座選得好,回回都有迭代點,年年勇破okr!

回歸正題,本文也分兩個部分進行講解:原理與cuda層面的并行計算。

在閱讀本文前,需要先閱讀V1的講解,本文會沿用V1的表達符號及推演思路。

一、Flash Attention V2整體運作流程

1.1 V1的運作流程

我們先快速回顧一下V1的運作流程:以K,V為外循環,Q為內循環。

,遍歷:

89ee976a-cfc5-11ee-a297-92fbcf53809c.png

,遍歷:

89f50cc6-cfc5-11ee-a297-92fbcf53809c.png

為了幫助大家更好理解v1中數據塊的流轉過程,在圖中我們畫了6塊O。但實際上最終只有三塊O:。

為例,它可理解成是由經過某些處理后匯總而來的。進一步說,

我們在外循環j = 0時,先遍歷一次所有的i,在這個階段中我們產出,并將它和一些別的重要數據寫回HBM中

接下來我們進行第二次外循環,即j=1,在這個階段中我們產出。同時我們把和那些重要的數據從HBM傳入shared memory中,然后從shared memory中讀取它們,以配合產出最終的

(關于如何得到的細節我們在V1講解中詳細推導過,這里不再贅述)

在這個過程中,你是不是隱隱覺得有些別扭:

其實都和有關系,那我為什么不以Q為外循環,以KV為內循環做遍歷呢?這樣我不就能避免往shared memory上讀寫中間結果,從而一次性把乃至最終的給算出來?

同時,softmax這個操作也是在row維度上的,所以我固定Q循環KV的方式,更天然符合softmax的特性。

1.2 V2的運作流程

基于1.1中的思想,我們在V2中將原本的內外循環置換了位置(示意圖就不畫了,基本可以對比V1示意圖想象出來)。我們直接來看V2的偽代碼(如果對以下偽代碼符號表示或解讀有疑惑的朋友,最好先看一下V1的講解)。

(1)V2 FWD

8a0872e8-cfc5-11ee-a297-92fbcf53809c.png

現在,想象自己固定住了一塊Q(i),依此循環K和V的分塊(j),在這個想象下我們來解讀這份FWD為代碼。

第8行,計算分塊

第9行:

表示截止到當前分塊(包含當前分塊)為止的rowmax

表示使用當前每行最大值計算歸一化前的(我們在V1中說過,不帶波浪號的P表示(s-rowmax)/rowsum的結果,帶波浪號表示(s-rowmax))

表示截止到當前分塊(包含當前分塊為止)的rowsum

第10行:表示截止到當前分塊(包含當前分塊)為止計算出的O值。由第9和第10行知,當我們固定Q循環KV時,我們每個分塊都是用當前最新的rowmax和rowsum計算的,同理對應的也是用當前最新的rowmax和rowsum計算的。這樣當我們遍歷完所有的KV時,得到的就等于最終全局的結果。相關的證明我們在V1講解中給過,這里不再贅述,只額外提兩點:

可能在有些朋友下載的V2論文中,第十行這里O前面的因子項是,這個公式應該是錯誤的(大家動手推一下就可知,初次看到時讓我困擾了很久)。在作者個人主頁的論文鏈接中,這個typo已經被修正。

你可能已發現這個O的計算中缺少歸一化的一項,這一項其實放到了第12行做統一計算。這也是V2優化的一個點:盡量減少非矩陣的計算,因為在GPU中,非矩陣計算比矩陣計算慢16倍。

比起V1,V2中不用再存每一Q分塊對應的了。但是在BWD的過程中,我們仍需要來做的重計算,這樣才能用鏈式求導法則把dQ,dK,dV正常算出來。V2在這里用了一個很巧妙的方法,它只存一個東西(代碼13行,這樣又能進一步減少shared memory的讀寫):,這個等式中小寫的m和l可以理解成是全局的rowmax和rowsum。在接下來BWD的講解中,我們會來看到這一項的妙用。

(2)V2 BWD

一個建議:如果你在閱讀本節中覺得很困惑,一定記得先去看V1的BWD部分,有非常詳細的推導介紹。看完再來看本節就很順暢了。

8a269336-cfc5-11ee-a297-92fbcf53809c.png

我們觀察到,在V2 BWD中,內外循環的位置又換回來了,即還是KV外循環,Q內循環,這是為什么呢?

我們知道在BWD的過程中,我們主要是求(為了求它們還需要求中間結果,我們來總結一下這些梯度都需要沿著哪些方向AllReduce:

:沿著i方向做AllReduce,也就是需要每行的結果加總

:沿著i方向做AllReduce,也就是需要每行的結果加總

:沿著j方向做AllReduce,也就是需要每列的結果加總

:只與當前i,j相關

基于此,如果你還是保持Q外循環,KV外循環不變的話,這種操作其實是固定行,遍歷列的,那么在這些梯度中,只有從中受益了,K和V的梯度則進入了別扭的循環(也意味著要往shared memory上寫更多的中間結果);但如果你采用KV外循環,Q內循環,這樣K和V都受益,只有Q獨自別扭,因此是一種更好的選擇。(S和P的計算不受循環變動影響)。

前面說過,在BWD過程中讀寫我們要用全局的重新計算,計算公式如下:

但如此一來,我們就要從shared memory上同時讀取,似乎有點消耗讀寫。所以在V2中,我們只存儲,然后計算:

很容易發現這兩個計算是等價的,但V2的做法節省了讀寫量

好,現在我們就把V2相對于V1在計算原理上的改進介紹完了。接下來我們總結一下V2相對于V1所有的改進點。

二、V2相對V1的改進點

之所以把這塊內容放到“V2整體流程介紹”之后,是想讓大家在先理解V2是怎么做的基礎上,更好體會V2的優點。

總體來說,V2從以下三個方面做了改進:

置換內外循環位置,同時減少非矩陣的計算量。(這兩點我們在第一部分中已給出詳細說明)

優化Attention部分thread blocks的并行化計算,新增seq_len維度的并行,使SM的利用率盡量打滿。這其實也是內外循環置換這個總體思想配套的改進措施

優化thread blocks內部warp級別的工作模式,盡量減少warp間的通訊和讀取shared memory的次數。

第二和第三點都可以歸結為是cuda gemm層面的優化,我們馬上來細看這兩點。

三、V2中的thread blocks排布

//gridDiminV1
//params.b=batch_size,params.h=num_heads
dim3grid(params.b,params.h);

//gridDiminV2
constintnum_m_block=(params.seqlen_q+Kernel_traits::kBlockM-1)/Kernel_traits::kBlockM;
dim3grid(num_m_block,params.b,params.h);

這段代碼整合自flash attention github下的cutlass實現,為了方便講解做了一點改寫。

這段代碼告訴我們:

在V1中,我們是按batch_size和num_heads來劃分block的,也就是說一共有batch_size * num_heads個block,每個block負責計算O矩陣的一部分

在V2中,我們是按batch_size,num_heads和num_m_block來劃分block的,其中num_m_block可理解成是沿著Q矩陣行方向做的切分。例如Q矩陣行方向長度為seqlen_q(其實就是我們熟悉的輸入序列長度seq_len,也就是圖例中的N),我們將其劃分成num_m_block份,每份長度為kBlockM(也就是每份維護kBlockM個token)。這樣就一共有batch_size * num_heads * num_m_block個block,每個block負責計算矩陣O的一部分。

為什么相比于V1,V2在劃分thread block時,要新增Q的seq_len維度上的劃分呢?

先說結論,這樣做的目的是盡量讓SM打滿。我們知道block是會被發去SM上執行的。以1塊A100 GPU為例,它有108個SM,如果此時我們的block數量比較大(例如論文中所說>=80時),我們就認為GPU的計算資源得到了很好的利用。現在回到我們的輸入數據上來,當batch_size和num_heads都比較大時,block也比較多,此時SM利用率比較高。但是如果我們的數據seq_len比較長,此時往往對應著較小的batch_size和num_heads,這是就會有SM在空轉了。而為了解決這個問題,我們就可以引入在Q的seq_len上的劃分。

看到這里你可能還是有點懵,沒關系,我們通過圖解的方式,來一起看看V1和V2上的thread block到底長什么樣。

3.1 V1 thread block

8a4e28f6-cfc5-11ee-a297-92fbcf53809c.png

假設batch_size = 1,num_heads = 2,我們用不同的顏色來表示不同的head。

我們知道在Multihead Attention中,各個head是可以獨立進行計算的,在計算完畢后將結果拼接起來即可。所以我們將1個head劃分給1個block,這樣就能實現block間的并行計算,如此每個block只要在計算完畢后把結果寫入自己所維護的O的對應位置即可。

而每個block內,就能執行V1中的"KV外循環,Q內循環”的過程了,這個過程是由block的再下級warp level層面進行組織,thread實行計算的。這塊我們放在第四部分中講解。

3.2 V2 thread block

8a5b8564-cfc5-11ee-a297-92fbcf53809c.png

現在我們繼續假設batch_size = 1,num_heads = 2。

與V1不同的是,我們在Q的seq_len維度上也做了切分,將其分成四份,即num_m_block = 4。所以現在我們共有124 = 8個block在跑。這些block之間的運算也是獨立的,因為:

head的計算是獨立的,所以紅色block和藍色block互不干擾

采用Q做外循環,KV做內循環時,行與行之間的block是獨立的,因此不同行的block互相不干擾。

每個block從Q上加載對應位置的切塊,同時從KV上加載head0的切塊,計算出自己所維護的那部分O,然后寫入O的對應位置。

在這里你可能想問,為什么只對Q的seq_len做了切分,而不對KV的seq_len做切分呢?

在V2的cutlass實現中,確實也提供了對KV的seq_len做切分的方法。但除非你認為SM真得打不滿,否則盡量不要在KV維度上做切分,因為如此一來,不同的block之間是沒法獨立計算的(比如對于O的某一行,它的各個部分來自不同的block,為了得到全局的softmax結果,這些block的結果還需要匯總做一次計算)。

3.3 seq parallel不是V2特有

如果你看過V1的代碼,你會發現,其實在V1后期的版本中,也出現了seq維度的并行:

//V1seqparallel:csrc/flash_attn/src/fmha_fwd_launch_template.h
dim3grid(launch_params.params.b,launch_params.params.h,launch_params.params.num_splits);

//nums_splits計算方法
//Findthenumberofsplitsthatmaximizestheoccupancy.Forexample,ifwehave
//batch*n_heads=48andwehave108SMs,having2splits(efficiency=0.89)is
//betterthanhaving3splits(efficiency=0.67).However,wealsodon'twanttoomany
//splitsasthatwouldincurmoreHBMreads/writes.
//Sowefindthebestefficiency,thenfindthesmallestnumberofsplitsthatgets95%
//ofthebestefficiency.
//[2022-11-25]TD:Markthisas"inline"otherwiseweget"multipledefinition"error.
inlineintnum_splits_heuristic_fwd(intbatch_nheads,intnum_SMs,intctas_per_sm,intmax_splits){
floatmax_efficiency=0.f;
std::vectorefficiency;
efficiency.reserve(max_splits);
for(intnum_splits=1;num_splits<=?max_splits;?num_splits++)?{
????????float?n_waves?=?float(batch_nheads?*?num_splits)?/?(num_SMs?*?ctas_per_sm);
????????float?eff?=?n_waves?/?ceil(n_waves);
????????//?printf("num_splits?=?%d,?eff?=?%f
",?num_splits,?eff);
????????if?(eff?>max_efficiency){max_efficiency=eff;}
efficiency.push_back(eff);
}
for(intnum_splits=1;num_splits<=?max_splits;?num_splits++)?{
????????if?(efficiency[num_splits?-?1]?>0.95*max_efficiency){
//printf("num_splitschosen=%d
",num_splits);
returnnum_splits;
}
}
return1;
}

....
//可以發現num_splits也是由Q的seq_len維度切分來的
launch_params.params.num_splits=num_splits_heuristic_fwd(
launch_params.params.b*launch_params.params.h,dprops->multiProcessorCount,
ctas_per_sm,
/*max_splits=*/std::min(30,(launch_params.params.seqlen_q+M-1/M))
);

上圖代碼中的num_splits也是在由Q的seq_len維度切分來的。通過這段代碼,我猜想作者在V1后期引入seq_len維度切分的原因是:V1也需要解決seq_len過長時,batch_size和num_heads較小而造成SM打不滿的問題。

num_splits_heuristic_fwd這個函數的作用概括起來就是,我先提供一連串num_splits值的備選,然后由這個函數計算出每個備選值下SM的利用率。計算完之后,我先找到最高的利用率,然后再找出滿足利用率>=0.95 * max(利用率)的那個最小的num_split值,作為最終的選擇。

細心的你此時可能已經觀察到了,雖然V1也引進過seq parallel,但是它的grid組織形式時(batch_size, num_heads, num_m_blocks),但V2的組織形式是(num_m_blocks, batch_size, num_heads),這種順序調換的意義是什么呢?

直接說結論,這樣的調換是為了提升L2 cache hit rate。大家可以看下3.2中的圖(雖然block實際執行時不一定按照圖中的序號),對于同一列的block,它們讀的是KV的相同部分,因此同一列block在讀取數據時,有很大概率可以直接從L2 cache上讀到自己要的數據(別的block之前取過的)。

3.4 FWD和BWD過程中的thread block劃分

在3.1~3.3中,我們其實給出的是FWD過程中thread block的劃分方式,我們知道V2中FWD和BWD的內外循環不一致,所以對應來說,thread block的劃分也會有所不同,我們詳細來看:

8a5fe514-cfc5-11ee-a297-92fbcf53809c.png

在圖中:

worker表示thread block,不同的thread block用不同顏色表示

整個大方框表示輸出矩陣O

我們先看左圖,它表示FWD下thread block的結構。每一行都有一個worker,它表示O矩陣的每一行都是由一個thread block計算出來的(假設num_heads = 1),這就對應到我們3.1~3.3中說的劃分方式。那么白色的部分表示什么呢?我們知道如果采用的是casual attention,那么有一部分是會被mask掉的,所以這里用白色來表示。但這不意味著thread block不需要加載白色部分數據對應的KV塊,只是說在計算的過程中它們會因被mask掉而免于計算(論文中的casual mask一節有提過)。

我們再看右圖,它表示BWD下thread block的結構,每一列對應一個worker,這是因為BWD中我們是KV做外循環,Q做內循環,這種情況下dK, dV都是按行累加的,而dQ是按列累加的,少數服從多數,因此這里thread_block是按的列劃分的。

四、Warp級別并行

8a6a9bda-cfc5-11ee-a297-92fbcf53809c.png

講完了thread block,我們就可以再下一級,看到warp level級別的并行了。左圖表示V1,右圖表示V2。不管是V1還是V2,在Ampere架構下,每個block內進一步被劃分為4個warp,在Hopper架構下則是8個warp。

在左圖(V1)中,每個warp都從shared memory上讀取相同的Q塊以及自己所負責計算的KV塊。在V1中,每個warp只是計算出了列方向上的結果,這些列方向上的結果必須匯總起來,才能得到最終O矩陣行方向上的對應結果。所以每個warp需要把自己算出來的中間結果寫到shared memory上,再由一個warp(例如warp1)進行統一的整合。所以各個warp間需要通訊、需要寫中間結果,這就影響了計算效率。

在左圖(V2)中,每個warp都從shared memory上讀取相同的KV塊以及自己所負責計算的Q塊。在V2中,行方向上的計算是完全獨立的,即每個warp把自己計算出的結果寫到O的對應位置即可,warp間不需要再做通訊,通過這種方式提升了計算效率。不過這種warp并行方式在V2的BWD過程中就有缺陷了:由于bwd中dK和dV是在行方向上的AllReduce,所以這種切分方式會導致warp間需要通訊。

針對V2 warp切分影響BWD這點,作者在論文中依然給出了“BWD過程相比V1也有提升”的結論,針對這點,我在github issue上找到了一條作者的回復(在“安裝報錯”組成的issue海洋里撈出的寶貴一條):

8a815a64-cfc5-11ee-a297-92fbcf53809c.png

最關鍵的可能是第1和第2點,關于第1點,我想作者應該是說,之前需要反復讀取KV的數據,現在只用反復讀取Q的數據,因此從一定程度上節省了shared memory的讀寫次數。第2點理解起來有點復雜,個人覺得是將warp處理的tile劃分得更像方形。這樣做的好處是在做casual mask的時候可以方便寫代碼大塊丟掉被mask掉的tile(見論文casual masking部分),進一步加速計算。第3點是關于一些底層的優化,就不提了。

好!關于V2我們就介紹到這了,寫這篇文章的時候,我剛粗過了一遍triton的flash attention實現,以及掃了一下cutlass實現的入口。如果后續有時間,我會出一些源碼解讀的文章(從cuda gemm -> triton gemm -> triton flash attention,看,又給自己挖了一個坑)。如果出不了,那一定不是我鴿人,那肯定是我不會(沒錯,就是這樣)。

審核編輯:黃飛

聲明:本文內容及配圖由入駐作者撰寫或者入駐合作網站授權轉載。文章觀點僅代表作者本人,不代表電子發燒友網立場。文章及其配圖僅供工程師學習之用,如有內容侵權或者其他違規問題,請聯系本站處理。 舉報投訴
  • gpu
    gpu
    +關注

    關注

    27

    文章

    4591

    瀏覽量

    128144
  • 并行計算
    +關注

    關注

    0

    文章

    27

    瀏覽量

    9403
  • 大模型
    +關注

    關注

    2

    文章

    2136

    瀏覽量

    1980

原文標題:圖解大模型計算加速系列:Flash Attention V2,從原理到并行計算

文章出處:【微信號:zenRRan,微信公眾號:深度學習自然語言處理】歡迎添加關注!文章轉載請注明出處。

收藏 人收藏

    評論

    相關推薦

    我可以使用ST-LINK/V2為STM8S系列編程閃存嗎?

    大家好,我可以使用''ST-LINK / V2''為STM8S系列編程閃存嗎? #ST-LINK / V2以上來自于谷歌翻譯以下為原文 Hi to all,Can i use ''ST-LINK/
    發表于 09-17 12:53

    ST-LINK/V2 ST-LINK/V2ST-LINK / V2在調試器/編程器STM8和STM32

    電子發燒友網為你提供(ti)ST-LINK/V2相關產品參數、數據手冊,更有ST-LINK/V2的引腳圖、接線圖、封裝手冊、中文資料、英文資料,ST-LINK/V2真值表,ST-LINK/V2
    發表于 05-21 00:05

    Kinect v2(Microsoft Kinect for Windows v2 )配置移動電源解決方案

    Kinect v2配置移動電源解決方案Kinect v2如果用于移動機器人上(也可以是其他應用場景),為方便有效地展開后續工作,為其配置移動電源是十分必要的。一、選擇移動電源Kinect v2原裝
    發表于 01-05 14:51 ?5次下載
    Kinect <b class='flag-5'>v2</b>(Microsoft Kinect for Windows <b class='flag-5'>v2</b> )配置移動電源解決方案

    Kinect v2(Microsoft Kinect for Windows v2 )配置移動電源解決方案

    Kinect v2(Microsoft Kinect for Windows v2 )配置移動電源解決方案
    發表于 01-05 14:53 ?0次下載
    Kinect <b class='flag-5'>v2</b>(Microsoft Kinect for Windows <b class='flag-5'>v2</b> )配置移動電源解決方案

    學習V2更新板開源分享

    電子發燒友網站提供《學習V2更新板開源分享.zip》資料免費下載
    發表于 07-26 09:38 ?0次下載
    學習<b class='flag-5'>V2</b>更新板開源分享

    LED面板V2開源分享

    電子發燒友網站提供《LED面板V2開源分享.zip》資料免費下載
    發表于 08-02 09:37 ?2次下載
    LED面板<b class='flag-5'>V2</b>開源分享

    智能BMS V2開源設計

    電子發燒友網站提供《智能BMS V2開源設計.zip》資料免費下載
    發表于 08-08 11:38 ?17次下載
    智能BMS <b class='flag-5'>V2</b>開源設計

    Leaphy Motor shield V2開源

    電子發燒友網站提供《Leaphy Motor shield V2開源.zip》資料免費下載
    發表于 08-22 15:41 ?0次下載
    Leaphy Motor shield <b class='flag-5'>V2</b>開源

    NodeMCU V2 Amica V3 Lolin的盾牌

    電子發燒友網站提供《NodeMCU V2 Amica V3 Lolin的盾牌.zip》資料免費下載
    發表于 08-24 10:05 ?2次下載
    NodeMCU <b class='flag-5'>V2</b> Amica <b class='flag-5'>V</b>3 Lolin的盾牌

    V2 控制器的操作原理

    V2 控制器的操作原理
    發表于 11-14 21:08 ?0次下載
    <b class='flag-5'>V2</b> 控制器的操作原理

    智能電源模塊,Motion SPM? 55 V2 系列用戶指南

    智能電源模塊,Motion SPM? 55 V2 系列用戶指南
    發表于 11-15 20:07 ?0次下載
    智能電源模塊,Motion SPM? 55 <b class='flag-5'>V2</b> <b class='flag-5'>系列</b>用戶指南

    京東方柔性OLED屏幕賦能榮耀Magic V2系列及MagicPad平板

    7月12日,在榮耀舉辦的全場景新品發布會上,重磅推出了“革命性”折疊旗艦Magic V2系列以及首款MagicPad平板產品,榮耀Magic V2系列搭載BOE(京東方)全新一代柔性O
    的頭像 發表于 07-14 11:16 ?1735次閱讀

    Cadence 與 Arm 合作,成功利用 Cadence AI 驅動流程加速 Neoverse V2 數據中心設計

    內容提要 ● Cadence 優化了其 AI 驅動的 RTL-to-GDS 數字流程,并為 Arm Neoverse V2 平臺提供了相應的 5nm 和 3nm 快速應用工具包(RAK),助力設計人
    的頭像 發表于 09-05 12:10 ?3453次閱讀

    國民技術DS_N32WB452系列數據手冊V2

    國民技術DS_N32WB452系列數據手冊V2
    發表于 10-18 16:13 ?1次下載

    產品簡介 | RZ/V2系列MPU

    產品簡介 | RZ/V2系列MPU
    的頭像 發表于 05-08 08:06 ?294次閱讀
    產品簡介 | RZ/<b class='flag-5'>V2</b><b class='flag-5'>系列</b>MPU