前言
標準的Transformer Block并不簡介,每個block由attention, MLP, skip connection, normalization各子模塊構成。一些看似微小的修改可能導致模型訓練速度下降,甚至導致模型無法收斂。
在本篇工作中,我們探索了Transformer Block精簡的方式。結合了信號傳播理論以及一些經驗性的觀察,我們在不損失訓練速度的前提下,移除了skip connection, out project, value project, normalization操作 以及串行組織block的形式。在Decoder-only和Encoder-only兩類模型上,我們減少了15%可訓練參數,并提高了15%的訓練速度。
官方倉庫:
bobby-he/simplified_transformers
論文:Simplifying Transformer Blocks.
一些標記注解:
?
每個transformer block如上述公式組成,每個子模塊都配備了一個系數,這個后續會使用到
Removing Skip Connection
作者先前的一項工作Deep Transformers without Shortcuts: Modifying Self-attention for Faithful Signal Propagation 刪除了殘差連接,提出的操作Value-SkipInit,將自注意力相關操作修改為:
其中I代表的是一個Identity操作,A(X)表示原始注意力操作。這兩個操作各自有一個可訓練標量 和 ,初始化為 , 。
這個設計的insight是每個token在訓練前期更多的是關注自身相關性,類似的如Pre-LN操作,在Batch Normalization Biases Residual Blocks Towards the Identity Function in Deep Networks這項工作發現,Pre-LN相當于把 skip-branch 權重提高,降低residual-branch權重,以在較深的神經網絡里仍然有良好的信號傳播。
而The Shaped Transformer: Attention Models in the Infinite Depth-and-Width Limit 該工作里提出了Shape Attention,也是收到信號傳播理論的啟發,將注意力公式更改為:
相比之下多了一個C矩陣,這是個常量矩陣(論文稱其為centering matrix),不參與訓練。他的值被設置為當 querykey dot 為0時候,A(x)的值,那么我們回去看A(x)公式,就剩一個mask值,因此代碼里是這么寫的:
?
?
#?Centered?attention,?from?https://arxiv.org/abs/2306.17759 ????????uniform_causal_attn_mat?=?torch.ones( ????????????(max_positions,?max_positions),?dtype=torch.float32 ????????)?/?torch.arange(1,?max_positions?+?1).view(-1,?1) ????????self.register_buffer( ????????????"uniform_causal_attn_mat", ????????????torch.tril( ????????????????uniform_causal_attn_mat, ????????????).view(1,?1,?max_positions,?max_positions), ????????????persistent=False, ????????)
?
?
對于CausalLM來說,MASK是個下三角矩陣,形狀為(S, S)的矩陣,第i行,只有前i個位置有值,經過softmax后,1.0概率被平分到有值的位置,這就是為什么它要做一個 ones / arange 的操作,一段示例代碼為:
?
?
import?torch max_positions?=?32 mask?=?torch.tril(torch.ones(max_positions,?max_positions))?+?torch.triu(torch.ones(max_positions,?max_positions),?1)?*?-65536 print(torch.softmax(mask,?-1)) tensor([[1.0000,?0.0000,?0.0000,??...,?0.0000,?0.0000,?0.0000], ????????[0.5000,?0.5000,?0.0000,??...,?0.0000,?0.0000,?0.0000], ????????[0.3333,?0.3333,?0.3333,??...,?0.0000,?0.0000,?0.0000], ????????..., ????????[0.0333,?0.0333,?0.0333,??...,?0.0333,?0.0000,?0.0000], ????????[0.0323,?0.0323,?0.0323,??...,?0.0323,?0.0323,?0.0000], ????????[0.0312,?0.0312,?0.0312,??...,?0.0312,?0.0312,?0.0312]])
?
?
而新的可訓練標量 = ,以保證初始化時,
其中這些可訓練標量如果改成headwise,即每個注意力頭獨立,則性能有部分提升。當然作者還是強調其中的一個重要的點是,顯式的將MLP Block的系數降低:
論文里針對18層Transformer,設置為0.1
Recovering Training Speed
在引入shape attention并移除殘差連接后,訓是沒問題了,但是會導致收斂變慢:
經過前面的修改,那么對于Attention模塊里,在訓練初期其實就簡化成X和Vproject矩陣和OutProject矩陣做矩陣乘操作。
眾所周知,這種沒有殘差連接的網絡訓練是要比帶殘差結構的網絡要慢的。我們從別的工作也可以得知,Pre-LN操作,是會降低殘差分支的占比系數,相當于降低了學習率,也縮減了線性層里參數更新的scale
X matmul W,那么計算X的梯度公式有一項就是W嘛
這促使我們開始引入重參數化操作思考V矩陣和OutProject矩陣
作者針對Vproject和Outproject兩個矩陣乘操作,給殘差分支和跳躍分支各引入一個可訓練參數 , ,通過實驗發現,大部分層最終系數比值 收斂到了0
這意味著 和 兩個矩陣是一個Identity矩陣,因此作者將這兩個參數移除掉,并稱為Simplified Attention Sub-block (SAS),使用SAS比原始Pre-LN block收斂更快了:
REMOVING THE MLP SUB-BLOCK SKIP CONNECTION
在這部分實驗里,作者把目光投向了GPT-J里提出的Parallel Block,其移除了MLP的殘差分支,保留了另外一個殘差分支:
對應公式為:
作者直接將SAS Block進行替換,得到Parallel形式的 SAS-P Block。我們比較下和原始串行的實現:
?
在訓練初期,Attention部分是Identity輸出,因此兩種形式的SAS Block在訓練初期是等價的。
REMOVING NORMALISATION LAYERS
最后作者嘗試將Norm層給移除,得到
作者的idea來自于,先前PreLN的作用(如把 skip-branch 權重提高,降低residual-branch權重)已經通過前面的一系列修改實現了,因此可以直接刪除Norm層
當然還是得看實驗效果,回到這張圖,可以看到移除了Norm對收斂還是有一定影響的。作者猜測在信號傳播理論范圍之外,Norm層能加速訓練收斂,如Scaling Vision Transformers to 22 Billion Parameters
引入了更多LayerNorm層,將ViT縮放至22B參數量上
因此作者還是主張保留PreLN結構:
最后實驗
作者也補充了一些訓練速度benchmark,模型準確率,以及收斂趨勢的實驗:
總結
作者對Transformer Block移除了各種參數,減少了15%參數量,提高了15%的訓練速度,各個環節都有做充分的實驗,但一些經驗性得到的結論也并沒有直接回答一些問題(如LN為什么影響收斂速度)。
實驗規模并不大,而標準的TransformerBlock還是在各個Scale里得到廣泛驗證的,期待有人進一步試驗
你說的對,但我還是套LLAMA結構
審核編輯:黃飛
?
評論
查看更多