循環(huán)神經(jīng)網(wǎng)絡(luò)(RNN)是用于自然語(yǔ)言建模的主流架構(gòu),通常,RNN按順序讀取輸入的token,再輸出每個(gè)token的分布式表示。通過(guò)利用相同的函數(shù)來(lái)循環(huán)更新隱藏狀態(tài),RNN的計(jì)算成本將保持不變。雖然這一特點(diǎn)對(duì)于某些應(yīng)用來(lái)說(shuō)很常見,但在語(yǔ)言處理過(guò)程中,并不是所有token都同等重要,關(guān)鍵要學(xué)會(huì)取舍。例如,在問(wèn)答題中,只對(duì)重要部分進(jìn)行大量計(jì)算,不相關(guān)部分分配較少的計(jì)算才是有效的方法。
雖然有注意力模型和LSTM等方法提高計(jì)算效率或挑選重要任務(wù),但它們的表現(xiàn)都不夠好。在本篇論文中,研究人員提出了“Skim-RNN”的概念,用很少的時(shí)間進(jìn)行快速閱讀,不影響讀者的主要目標(biāo)。
Skim-RNN的構(gòu)成
受人類快速閱讀原理的啟發(fā),Skim-RNN的結(jié)構(gòu)由兩個(gè)RNN模型構(gòu)成:較大的默認(rèn)RNN模型d和較小的RNN模型d’。d和d’是用戶定義的超參數(shù),并且d’<< d。
如果當(dāng)前token比較重要,Skim-RNN就會(huì)使用大的RNN;如果不重要,它就會(huì)轉(zhuǎn)向使用小的RNN。由于小RNN比大RNN需要的浮點(diǎn)運(yùn)算次數(shù)更少,所以該模型比單獨(dú)使用大RNN所得結(jié)果更快,甚至更好。
推理過(guò)程
在每一步驟t中,Skim-RNN將輸入的Xt∈Rd和之前的隱藏狀態(tài)ht-1∈Rd作為其參數(shù),輸出新的狀態(tài)ht。k代表每一步做出harddecision的次數(shù)。在Skim-RNN中,不論是完全閱讀或跳過(guò),k=2。
研究人員使用多項(xiàng)隨機(jī)變量Qt對(duì)選擇概率分布Pt的決策過(guò)程進(jìn)行建模。Pt表示為:
這里,W∈Rk×2d,b∈Rk。
接下來(lái)我們定義隨機(jī)變量Qt,通過(guò)從概率分布Pt對(duì)Qt進(jìn)行采樣:
如果Qt=1,那么該模型與標(biāo)準(zhǔn)RNN模型一樣。如果Qt=2,那么模型選用了較小RNN模型以獲取較小的隱藏狀態(tài)。即:
其中f是帶有d輸出的完全RNN,而f'是帶有d'輸出的小RNN,d'<< d。
實(shí)驗(yàn)結(jié)果
研究人員在七組數(shù)據(jù)集上對(duì)Skim-RNN進(jìn)行測(cè)試,包括分類測(cè)試和問(wèn)答題兩種形式,目的是為了檢驗(yàn)?zāi)P偷臏?zhǔn)確度和浮點(diǎn)運(yùn)算減少率(Flop-R)。
文本分類
在這項(xiàng)任務(wù)中,輸入的是單詞序列,輸出的是分類概率的向量。最終,下表顯示出Skim-RNN模型與LSTM、LSTM-Jump的精確度和計(jì)算成本對(duì)比。
以SST、爛番茄、IMDB和AGnews四個(gè)網(wǎng)站為例進(jìn)行本文分類,在標(biāo)準(zhǔn)LSTM、Skim-RNN、LSTM-Jump和最先進(jìn)的模型(SOTA)上進(jìn)行對(duì)比
改變較小隱藏狀態(tài)的尺寸的影響,以及參數(shù)γ對(duì)精確度和計(jì)算成本的影響(默認(rèn)d=100,d'=10,γ=0.02)
下圖是IMDB數(shù)據(jù)集中的一個(gè)例子,其中Skim-RNN的參數(shù)為:d=200,d'=10,γ=0.01,最終將本段文字正確分類的概率為92%。
其中黑色的字被略過(guò)(用小LSTM模型,d'=10),藍(lán)色的字表示被閱讀(用較大的LSTM模型,d=200)
和預(yù)期的一樣,模型忽略了類似介詞等不重要的詞語(yǔ),而注意到了非常重要的單詞,例如“喜歡”、“可怕”、“討厭的”。
回答問(wèn)題
這項(xiàng)任務(wù)的目的是在給定段落中找到答案的位置。為了檢測(cè)Skim-RNN的準(zhǔn)確度,研究人員建立了兩個(gè)不同的模型:LSTM+注意力和BiDAF。結(jié)果如下所示:
F1和EM值可表明Skim-RNN的準(zhǔn)確度。最終發(fā)現(xiàn),速讀(skimming)模型的F1分?jǐn)?shù)比默認(rèn)沒(méi)有速讀(non-skimming)的模型相同甚至更高,并且計(jì)算成本消耗得更少(大于1.4倍)。
LSTM+注意力模型中,不同層的LSTM速度率(skimming rate)隨γ的變化而變化的情況
LSTM+注意力模型的F1分?jǐn)?shù)。計(jì)算成本越大,模型表現(xiàn)得越好。在同樣的計(jì)算成本下,Skim LSTM(紅色)比標(biāo)準(zhǔn)LSTM(藍(lán)色)的表現(xiàn)要好。另外,Skim-LSTM的F1分?jǐn)?shù)在不同參數(shù)和計(jì)算成本下都更穩(wěn)定
F1分?jǐn)?shù)與Flop-R之間的關(guān)系
下圖是模型回答問(wèn)題的一個(gè)例子,問(wèn)題為:最大的建筑項(xiàng)目(construction project)也稱作什么?(正確答案:megaprojects)
模型給出的答案:megaprojects。
紅色代表閱讀,白色代表略過(guò)
運(yùn)行時(shí)間
上圖顯示了與標(biāo)準(zhǔn)LSTM相比,Skim-LSTM的相對(duì)速度增益的隱藏狀態(tài)有不同大小和速度速率。在這一過(guò)程中,研究人員使用的是NumPy,并在CPU的單個(gè)線程上進(jìn)行推論。
可以看到,實(shí)際增益(實(shí)線)和理論增益(虛線)之間的差距無(wú)法避免。隨著隱藏狀態(tài)增大,這一差距會(huì)減小。所以對(duì)于更大的隱藏狀態(tài),Skim-RNN的表現(xiàn)會(huì)更好。
結(jié)語(yǔ)
本次研究表明,新型循環(huán)神經(jīng)網(wǎng)絡(luò)Skim-RNN可以根據(jù)輸入的重要性決定使用大的RNN還是小的RNN,同時(shí)計(jì)算成本比RNN更低,準(zhǔn)確度與標(biāo)準(zhǔn)LSTM和LSTM-Jump相比類似甚至更好。由于Skim-RNN與RNN具有相同的輸入輸出接口,因此可以輕松替換現(xiàn)有應(yīng)用中的RNN。
所以,這樣工作適用于需要更高隱藏狀態(tài)的應(yīng)用,比如理解視頻,同時(shí)還可以利用小RNN做不同程度的略讀。
-
循環(huán)神經(jīng)網(wǎng)絡(luò)
+關(guān)注
關(guān)注
0文章
31瀏覽量
2957
原文標(biāo)題:用Skim-RNN顯著降低計(jì)算成本,實(shí)現(xiàn)“速讀”
文章出處:【微信號(hào):jqr_AI,微信公眾號(hào):論智】歡迎添加關(guān)注!文章轉(zhuǎn)載請(qǐng)注明出處。
發(fā)布評(píng)論請(qǐng)先 登錄
相關(guān)推薦
評(píng)論