在深度學習中,我們經常使用 CNN 或 RNN 對序列進行編碼。現在考慮到注意力機制,想象一下將一系列標記輸入注意力機制,這樣在每個步驟中,每個標記都有自己的查詢、鍵和值。在這里,當在下一層計算令牌表示的值時,令牌可以(通過其查詢向量)參與每個其他令牌(基于它們的鍵向量進行匹配)。使用完整的查詢鍵兼容性分數集,我們可以通過在其他標記上構建適當的加權和來為每個標記計算表示。因為每個標記都關注另一個標記(不同于解碼器步驟關注編碼器步驟的情況),這種架構通常被描述為自注意力模型 (Lin等。, 2017 年, Vaswani等人。, 2017 ),以及其他地方描述的內部注意力模型 ( Cheng et al. , 2016 , Parikh et al. , 2016 , Paulus et al. , 2017 )。在本節中,我們將討論使用自注意力的序列編碼,包括使用序列順序的附加信息。
import math import torch from torch import nn from d2l import torch as d2l
import math from mxnet import autograd, np, npx from mxnet.gluon import nn from d2l import mxnet as d2l npx.set_np()
import jax from flax import linen as nn from jax import numpy as jnp from d2l import jax as d2l
No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
import numpy as np import tensorflow as tf from d2l import tensorflow as d2l
11.6.1。自注意力
給定一系列輸入標記 x1,…,xn任何地方 xi∈Rd(1≤i≤n), 它的self-attention輸出一個相同長度的序列 y1,…,yn, 在哪里
(11.6.1)yi=f(xi,(x1,x1),…,(xn,xn))∈Rd
根據 (11.1.1)中attention pooling的定義。使用多頭注意力,以下代碼片段計算具有形狀(批量大小、時間步數或標記中的序列長度, d). 輸出張量具有相同的形狀。
num_hiddens, num_heads = 100, 5 attention = d2l.MultiHeadAttention(num_hiddens, num_heads, 0.5) batch_size, num_queries, valid_lens = 2, 4, torch.tensor([3, 2]) X = torch.ones((batch_size, num_queries, num_hiddens)) d2l.check_shape(attention(X, X, X, valid_lens), (batch_size, num_queries, num_hiddens))
num_hiddens, num_heads = 100, 5 attention = d2l.MultiHeadAttention(num_hiddens, num_heads, 0.5) attention.initialize() batch_size, num_queries, valid_lens = 2, 4, np.array([3, 2]) X = np.ones((batch_size, num_queries, num_hiddens)) d2l.check_shape(attention(X, X, X, valid_lens), (batch_size, num_queries, num_hiddens))
num_hiddens, num_heads = 100, 5 attention = d2l.MultiHeadAttention(num_hiddens, num_heads, 0.5) batch_size, num_queries, valid_lens = 2, 4, jnp.array([3, 2]) X = jnp.ones((batch_size, num_queries, num_hiddens)) d2l.check_shape(attention.init_with_output(d2l.get_key(), X, X, X, valid_lens, training=False)[0][0], (batch_size, num_queries, num_hiddens))
num_hiddens, num_heads = 100, 5 attention = d2l.MultiHeadAttention(num_hiddens, num_hiddens, num_hiddens, num_hiddens, num_heads, 0.5) batch_size, num_queries, valid_lens = 2, 4, tf.constant([3, 2]) X = tf.ones((batch_size, num_queries, num_hiddens)) d2l.check_shape(attention(X, X, X, valid_lens, training=False), (batch_size, num_queries, num_hiddens))
11.6.2。比較 CNN、RNN 和自注意力
讓我們比較一下映射一系列的架構n標記到另一個等長序列,其中每個輸入或輸出標記由一個d維向量。具體來說,我們將考慮 CNN、RNN 和自注意力。我們將比較它們的計算復雜度、順序操作和最大路徑長度。請注意,順序操作會阻止并行計算,而序列位置的任意組合之間的較短路徑可以更容易地學習序列內的遠程依賴關系 (Hochreiter等人,2001 年)。
圖 11.6.1比較 CNN(省略填充標記)、RNN 和自注意力架構。
考慮一個卷積層,其內核大小為k. 我們將在后面的章節中提供有關使用 CNN 進行序列處理的更多詳細信息。現在,我們只需要知道,因為序列長度是n,輸入和輸出通道的數量都是 d, 卷積層的計算復雜度為 O(knd2). 如圖11.6.1 所示,CNN 是分層的,因此有O(1) 順序操作和最大路徑長度是 O(n/k). 例如,x1和 x5位于圖 11.6.1中內核大小為 3 的雙層 CNN 的接受域內。
在更新 RNN 的隱藏狀態時,乘以 d×d權重矩陣和d維隱藏狀態的計算復雜度為O(d2). 由于序列長度為n,循環層的計算復雜度為O(nd2). 根據 圖 11.6.1,有O(n) 不能并行化的順序操作,最大路徑長度也是O(n).
在自注意力中,查詢、鍵和值都是 n×d矩陣。考慮(11.3.6)中的縮放點積注意力,其中n×d矩陣乘以d×n矩陣,然后是輸出 n×n矩陣乘以n×d矩陣。因此,self-attention 有一個O(n2d) 計算復雜度。正如我們在圖 11.6.1中看到的 ,每個標記都通過自注意力直接連接到任何其他標記。因此,計算可以與O(1)順序操作和最大路徑長度也是O(1).
總而言之,CNN 和 self-attention 都享有并行計算,并且 self-attention 具有最短的最大路徑長度。然而,關于序列長度的二次計算復雜度使得自注意力對于非常長的序列來說非常慢。
11.6.3。位置編碼
與循環一個接一個地處理序列標記的 RNN 不同,self-attention 摒棄順序操作以支持并行計算。但是請注意,self-attention 本身并不能保持序列的順序。如果模型知道輸入序列到達的順序真的很重要,我們該怎么辦?
保留有關標記順序的信息的主要方法是將其表示為與每個標記相關聯的附加輸入的模型。這些輸入稱為位置編碼。它們可以被學習或先驗固定。我們現在描述一種基于正弦和余弦函數的固定位置編碼的簡單方案(Vaswani等人,2017 年)。
假設輸入表示 X∈Rn×d包含 d-維度嵌入n序列的標記。位置編碼輸出X+P使用位置嵌入矩陣 P∈Rn×d形狀相同,其元素在ith行和 (2j)th或者(2j+1)th專欄是
(11.6.2)pi,2j=sin?(i100002j/d),pi,2j+1=cos?(i100002j/d).
乍一看,這種三角函數設計看起來很奇怪。在解釋這個設計之前,讓我們先在下面的 PositionalEncoding類中實現它。
class PositionalEncoding(nn.Module): #@save """Positional encoding.""" def __init__(self, num_hiddens, dropout, max_len=1000): super().__init__() self.dropout = nn.Dropout(dropout) # Create a long enough P self.P = torch.zeros((1, max_len, num_hiddens)) X = torch.arange(max_len, dtype=torch.float32).reshape( -1, 1) / torch.pow(10000, torch.arange( 0, num_hiddens, 2, dtype=torch.float32) / num_hiddens) self.P[:, :, 0::2] = torch.sin(X) self.P[:, :, 1::2] = torch.cos(X) def forward(self, X): X = X + self.P[:, :X.shape[1], :].to(X.device) return self.dropout(X)
class PositionalEncoding(nn.Block): #@save """Positional encoding.""" def __init__(self, num_hiddens, dropout, max_len=1000): super().__init__() self.dropout = nn.Dropout(dropout) # Create a long enough P self.P = np.zeros((1, max_len, num_hiddens)) X = np.arange(max_len).reshape(-1, 1) / np.power( 10000, np.arange(0, num_hiddens, 2) / num_hiddens) self.P[:, :, 0::2] = np.sin(X) self.P[:, :, 1::2] = np.cos(X) def forward(self, X): X = X + self.P[:, :X.shape[1], :].as_in_ctx(X.ctx) return self.dropout(X)
class PositionalEncoding(nn.Module): #@save """Positional encoding.""" num_hiddens: int dropout: float max_len: int = 1000 def setup(self): # Create a long enough P self.P = jnp.zeros((1, self.max_len, self.num_hiddens)) X = jnp.arange(self.max_len, dtype=jnp.float32).reshape( -1, 1) / jnp.power(10000, jnp.arange( 0, self.num_hiddens, 2, dtype=jnp.float32) / self.num_hiddens) self.P = self.P.at[:, :, 0::2].set(jnp.sin(X)) self.P = self.P.at[:, :, 1::2].set(jnp.cos(X)) @nn.compact def __call__(self, X, training=False): # Flax sow API is used to capture intermediate variables self.sow('intermediates', 'P', self.P) X = X + self.P[:, :X.shape[1], :] return nn.Dropout(self.dropout)(X, deterministic=not training)
class PositionalEncoding(tf.keras.layers.Layer): #@save """Positional encoding.""" def __init__(self, num_hiddens, dropout, max_len=1000): super().__init__() self.dropout = tf.keras.layers.Dropout(dropout) # Create a long enough P self.P = np.zeros((1, max_len, num_hiddens)) X = np.arange(max_len, dtype=np.float32).reshape( -1,1)/np.power(10000, np.arange( 0, num_hiddens, 2, dtype=np.float32) / num_hiddens) self.P[:, :, 0::2] = np.sin(X) self.P[:, :, 1::2] = np.cos(X) def call(self, X, **kwargs): X = X + self.P[:, :X.shape[1], :] return self.dropout(X, **kwargs)
在位置嵌入矩陣中P,行對應于序列中的位置,列代表不同的位置編碼維度。在下面的示例中,我們可以看到6th和7th位置嵌入矩陣的列具有比 8th和9th列。之間的偏移量6th和 7th(同樣的8th和 9th) 列是由于正弦和余弦函數的交替。
encoding_dim, num_steps = 32, 60 pos_encoding = PositionalEncoding(encoding_dim, 0) X = pos_encoding(torch.zeros((1, num_steps, encoding_dim))) P = pos_encoding.P[:, :X.shape[1], :] d2l.plot(torch.arange(num_steps), P[0, :, 6:10].T, xlabel='Row (position)', figsize=(6, 2.5), legend=["Col %d" % d for d in torch.arange(6, 10)])
encoding_dim, num_steps = 32, 60 pos_encoding = PositionalEncoding(encoding_dim, 0) pos_encoding.initialize() X = pos_encoding(np.zeros((1, num_steps, encoding_dim))) P = pos_encoding.P[:, :X.shape[1], :] d2l.plot(np.arange(num_steps), P[0, :, 6:10].T, xlabel='Row (position)', figsize=(6, 2.5), legend=["Col %d" % d for d in np.arange(6, 10)])
encoding_dim, num_steps = 32, 60 pos_encoding = PositionalEncoding(encoding_dim, 0) params = pos_encoding.init(d2l.get_key(), jnp.zeros((1, num_steps, encoding_dim))) X, inter_vars = pos_encoding.apply(params, jnp.zeros((1, num_steps, encoding_dim)), mutable='intermediates') P = inter_vars['intermediates']['P'][0] # retrieve intermediate value P P = P[:, :X.shape[1], :] d2l.plot(jnp.arange(num_steps), P[0, :, 6:10].T, xlabel='Row (position)', figsize=(6, 2.5), legend=["Col %d" % d for d in jnp.arange(6, 10)])
encoding_dim, num_steps = 32, 60 pos_encoding = PositionalEncoding(encoding_dim, 0) X = pos_encoding(tf.zeros((1, num_steps, encoding_dim)), training=False) P = pos_encoding.P[:, :X.shape[1], :] d2l.plot(np.arange(num_steps), P[0, :, 6:10].T, xlabel='Row (position)', figsize=(6, 2.5), legend=["Col %d" % d for d in np.arange(6, 10)])
11.6.3.1。絕對位置信息
為了了解沿編碼維度單調降低的頻率與絕對位置信息的關系,讓我們打印出的二進制表示0,1,…,7. 正如我們所看到的,最低位、第二低位和第三低位分別在每個數字、每兩個數字和每四個數字上交替出現。
for i in range(8): print(f'{i} in binary is {i:>03b}')
0 in binary is 000 1 in binary is 001 2 in binary is 010 3 in binary is 011 4 in binary is 100 5 in binary is 101 6 in binary is 110 7 in binary is 111
for i in range(8): print(f'{i} in binary is {i:>03b}')
0 in binary is 000 1 in binary is 001 2 in binary is 010 3 in binary is 011 4 in binary is 100 5 in binary is 101 6 in binary is 110 7 in binary is 111
for i in range(8): print(f'{i} in binary is {i:>03b}')
0 in binary is 000 1 in binary is 001 2 in binary is 010 3 in binary is 011 4 in binary is 100 5 in binary is 101 6 in binary is 110 7 in binary is 111
for i in range(8): print(f'{i} in binary is {i:>03b}')
0 in binary is 000 1 in binary is 001 2 in binary is 010 3 in binary is 011 4 in binary is 100 5 in binary is 101 6 in binary is 110 7 in binary is 111
在二進制表示中,較高位的頻率比較低位低。類似地,如下面的熱圖所示,位置編碼通過使用三角函數降低編碼維度上的頻率。由于輸出是浮點數,因此這種連續表示比二進制表示更節省空間。
P = P[0, :, :].unsqueeze(0).unsqueeze(0) d2l.show_heatmaps(P, xlabel='Column (encoding dimension)', ylabel='Row (position)', figsize=(3.5, 4), cmap='Blues')
P = np.expand_dims(np.expand_dims(P[0, :, :], 0), 0) d2l.show_heatmaps(P, xlabel='Column (encoding dimension)', ylabel='Row (position)', figsize=(3.5, 4), cmap='Blues')
P = jnp.expand_dims(jnp.expand_dims(P[0, :, :], axis=0), axis=0) d2l.show_heatmaps(P, xlabel='Column (encoding dimension)', ylabel='Row (position)', figsize=(3.5, 4), cmap='Blues')
P = tf.expand_dims(tf.expand_dims(P[0, :, :], axis=0), axis=0) d2l.show_heatmaps(P, xlabel='Column (encoding dimension)', ylabel='Row (position)', figsize=(3.5, 4), cmap='Blues')
11.6.3.2。相對位置信息
除了捕獲絕對位置信息外,上述位置編碼還允許模型輕松學習相對位置的注意。這是因為對于任何固定位置偏移δ, 位置的位置編碼i+δ可以用位置的線性投影表示i.
這個投影可以用數學來解釋。表示 ωj=1/100002j/d, 任何一對 (pi,2j,pi,2j+1)在 (11.6.2)中可以線性投影到 (pi+δ,2j,pi+δ,2j+1)對于任何固定偏移 δ:
(11.6.3)[cos?(δωj)sin?(δωj)?sin?(δωj)cos?(δωj)][pi,2jpi,2j+1]=[cos?(δωj)sin?(iωj)+sin?(δωj)cos?(iωj)?sin?(δωj)sin?(iωj)+cos?(δωj)cos?(iωj)]=[sin?((i+δ)ωj)cos?((i+δ)ωj)]=[pi+δ,2jpi+δ,2j+1],
在哪里2×2投影矩陣不依賴于任何位置索引i.
11.6.4。概括
在自我關注中,查詢、鍵和值都來自同一個地方。CNN 和 self-attention 都享有并行計算,并且 self-attention 具有最短的最大路徑長度。然而,關于序列長度的二次計算復雜度使得自注意力對于非常長的序列來說非常慢。要使用序列順序信息,我們可以通過向輸入表示添加位置編碼來注入絕對或相對位置信息。
11.6.5。練習
假設我們設計了一個深度架構來表示一個序列,通過使用位置編碼堆疊自注意力層。可能是什么問題?
你能設計一個可學習的位置編碼方法嗎?
我們能否根據在自注意力中比較的查詢和鍵之間的不同偏移來分配不同的學習嵌入?提示:你可以參考相對位置嵌入 (Huang et al. , 2018 , Shaw et al. , 2018)。
-
pytorch
+關注
關注
2文章
803瀏覽量
13149
發布評論請先 登錄
相關推薦
評論