在深度學習中,我們經常使用 CNN 或 RNN 對序列進行編碼。現在考慮到注意力機制,想象一下將一系列標記輸入注意力機制,這樣在每個步驟中,每個標記都有自己的查詢、鍵和值。在這里,當在下一層計算令牌表示的值時,令牌可以(通過其查詢向量)參與每個其他令牌(基于它們的鍵向量進行匹配)。使用完整的查詢鍵兼容性分數集,我們可以通過在其他標記上構建適當的加權和來為每個標記計算表示。因為每個標記都關注另一個標記(不同于解碼器步驟關注編碼器步驟的情況),這種架構通常被描述為自注意力模型 (Lin等。, 2017 年, Vaswani等人。, 2017 ),以及其他地方描述的內部注意力模型 ( Cheng et al. , 2016 , Parikh et al. , 2016 , Paulus et al. , 2017 )。在本節中,我們將討論使用自注意力的序列編碼,包括使用序列順序的附加信息。
11.6.1。自注意力
給定一系列輸入標記 x1,…,xn任何地方 xi∈Rd(1≤i≤n), 它的self-attention輸出一個相同長度的序列 y1,…,yn, 在哪里
根據 (11.1.1)中attention pooling的定義。使用多頭注意力,以下代碼片段計算具有形狀(批量大小、時間步數或標記中的序列長度, d). 輸出張量具有相同的形狀。
num_hiddens, num_heads = 100, 5
attention = d2l.MultiHeadAttention(num_hiddens, num_heads, 0.5)
batch_size, num_queries, valid_lens = 2, 4, torch.tensor([3, 2])
X = torch.ones((batch_size, num_queries, num_hiddens))
d2l.check_shape(attention(X, X, X, valid_lens),
(batch_size, num_queries, num_hiddens))
num_hiddens, num_heads = 100, 5
attention = d2l.MultiHeadAttention(num_hiddens, num_heads, 0.5)
attention.initialize()
batch_size, num_queries, valid_lens = 2, 4, np.array([3, 2])
X = np.ones((batch_size, num_queries, num_hiddens))
d2l.check_shape(attention(X, X, X, valid_lens),
(batch_size, num_queries, num_hiddens))
num_hiddens, num_heads = 100, 5
attention = d2l.MultiHeadAttention(num_hiddens, num_heads, 0.5)
batch_size, num_queries, valid_lens = 2, 4, jnp.array([3, 2])
X = jnp.ones((batch_size, num_queries, num_hiddens))
d2l.check_shape(attention.init_with_output(d2l.get_key(), X, X, X, valid_lens,
training=False)[0][0],
(batch_size, num_queries, num_hiddens))
num_hiddens, num_heads = 100, 5
attention = d2l.MultiHeadAttention(num_hiddens, num_hiddens, num_hiddens,
num_hiddens, num_heads, 0.5)
batch_size, num_queries, valid_lens = 2, 4, tf.constant([3, 2])
X = tf.ones((batch_size, num_queries, num_hiddens))
d2l.check_shape(attention(X, X, X, valid_lens, training=False),
(batch_size, num_queries, num_hiddens))
11.6.2。比較 CNN、RNN 和自注意力
讓我們比較一下映射一系列的架構n標記到另一個等長序列,其中每個輸入或輸出標記由一個d維向量。具體來說,我們將考慮 CNN、RNN 和自注意力。我們將比較它們的計算復雜度、順序操作和最大路徑長度。請注意,順序操作會阻止并行計算,而序列位置的任意組合之間的較短路徑可以更容易地學習序列內的遠程依賴關系 (Hochreiter等人,2001 年)。
考慮一個卷積層,其內核大小為k. 我們將在后面的章節中提供有關使用 CNN 進行序列處理的更多詳細信息。現在,我們只需要知道,因為序列長度是n,輸入和輸出通道的數量都是 d, 卷積層的計算復雜度為 O(knd2). 如圖11.6.1 所示,CNN 是分層的,因此有O(1) 順序操作和最大路徑長度是
評論
查看更多