In practice, given the same set of queries, keys, and values, we may want our model to combine knowledge from different behaviors of the same attention mechanism, such as capturing dependencies of various ranges (e.g., shorter-range vs. longer-range) within a sequence. Thus, it may be beneficial to allow our attention mechanism to jointly use different representation subspaces of queries, keys, and values.
To this end, instead of performing a single attention pooling, queries, keys, and values can be transformed with h independently learned linear projections. Then these h projected queries, keys, and values are fed into attention pooling in parallel. In the end, the h attention pooling outputs are concatenated and transformed with another learned linear projection to produce the final output. This design is called multi-head attention, where each of the h attention pooling outputs is a head (Vaswani et al., 2017). Using fully connected layers to perform learnable linear transformations, Fig. 11.5.1 describes multi-head attention.
Fig. 11.5.1 Multi-head attention, where multiple heads are concatenated and then linearly transformed.
# PyTorch
import math
import torch
from torch import nn
from d2l import torch as d2l
# MXNet
import math
from mxnet import autograd, np, npx
from mxnet.gluon import nn
from d2l import mxnet as d2l

npx.set_np()
# JAX
import jax
from flax import linen as nn
from jax import numpy as jnp
from d2l import jax as d2l
# TensorFlow
import tensorflow as tf
from d2l import tensorflow as d2l
11.5.1. Model
Before providing the implementation of multi-head attention, let us formalize this model mathematically. Given a query $\mathbf{q} \in \mathbb{R}^{d_q}$, a key $\mathbf{k} \in \mathbb{R}^{d_k}$, and a value $\mathbf{v} \in \mathbb{R}^{d_v}$, each attention head $\mathbf{h}_i$ ($i = 1, \ldots, h$) is computed as
$$\mathbf{h}_i = f(\mathbf{W}_i^{(q)}\mathbf{q}, \mathbf{W}_i^{(k)}\mathbf{k}, \mathbf{W}_i^{(v)}\mathbf{v}) \in \mathbb{R}^{p_v}, \tag{11.5.1}$$
where $\mathbf{W}_i^{(q)} \in \mathbb{R}^{p_q \times d_q}$, $\mathbf{W}_i^{(k)} \in \mathbb{R}^{p_k \times d_k}$, and $\mathbf{W}_i^{(v)} \in \mathbb{R}^{p_v \times d_v}$ are learnable parameters and $f$ is attention pooling, such as the additive attention and scaled dot-product attention in Section 11.3. The multi-head attention output is another linear transformation, via learnable parameters $\mathbf{W}_o \in \mathbb{R}^{p_o \times h p_v}$, of the concatenation of the $h$ heads:
$$\mathbf{W}_o \begin{bmatrix} \mathbf{h}_1 \\ \vdots \\ \mathbf{h}_h \end{bmatrix} \in \mathbb{R}^{p_o}. \tag{11.5.2}$$
Based on this design, each head may attend to different parts of the input, allowing functions more sophisticated than the simple weighted average to be expressed.
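To make (11.5.1) and (11.5.2) concrete, below is a minimal sketch (written in PyTorch, one of the frameworks used in this section) that computes the heads one by one for a single query against a few key-value pairs, using scaled dot-product attention as $f$. All sizes and variable names are chosen only for illustration; the actual implementation later in this section batches all heads into a single attention call instead of looping.

import math
import torch
from torch import nn

# Illustrative only: literal per-head computation of (11.5.1)-(11.5.2) with
# scaled dot-product attention as f. All sizes here are made up.
d_q = d_k = d_v = 6      # input dimensions of the query, key, and value
p_q = p_k = p_v = 4      # per-head projection sizes
p_o, h, m = 8, 2, 5      # output size, number of heads, number of key-value pairs

q = torch.randn(d_q)     # one query
K = torch.randn(m, d_k)  # m keys
V = torch.randn(m, d_v)  # m values

W_q = [nn.Linear(d_q, p_q, bias=False) for _ in range(h)]
W_k = [nn.Linear(d_k, p_k, bias=False) for _ in range(h)]
W_v = [nn.Linear(d_v, p_v, bias=False) for _ in range(h)]
W_o = nn.Linear(h * p_v, p_o, bias=False)

heads = []
for i in range(h):
    q_i, K_i, V_i = W_q[i](q), W_k[i](K), W_v[i](V)   # per-head projections
    scores = K_i @ q_i / math.sqrt(p_k)               # scaled dot-product scores
    heads.append(torch.softmax(scores, dim=0) @ V_i)  # h_i in R^{p_v}

output = W_o(torch.cat(heads))  # W_o [h_1; ...; h_h], shape (p_o,) = (8,)
print(output.shape)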
11.5.2. Implementation
In our implementation, we choose scaled dot-product attention for each head of the multi-head attention. To avoid significant growth of computational cost and parameterization cost, we set $p_q = p_k = p_v = p_o / h$. Note that the $h$ heads can be computed in parallel if we set the number of outputs of the linear transformations for the query, key, and value to $p_q h = p_k h = p_v h = p_o$. In the following implementation, $p_o$ is specified via the argument num_hiddens.
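For a quick sense of the shape arithmetic (using the same illustrative numbers as the toy test further below), num_hiddens = 100 with num_heads = 5 gives each head a working dimension of 20:

# Shape arithmetic only; the numbers mirror the toy test at the end of this section.
num_hiddens, num_heads = 100, 5       # p_o = 100, h = 5
head_dim = num_hiddens // num_heads   # p_q = p_k = p_v = p_o / h = 20
# A (batch_size, num_queries, num_hiddens) = (2, 4, 100) input is rearranged into
# (batch_size * num_heads, num_queries, head_dim) = (10, 4, 20), so all five heads
# are handled by a single batched scaled dot-product attention call.
print(head_dim)  # 20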
# PyTorch
class MultiHeadAttention(d2l.Module):  #@save
    """Multi-head attention."""
    def __init__(self, num_hiddens, num_heads, dropout, bias=False, **kwargs):
        super().__init__()
        self.num_heads = num_heads
        self.attention = d2l.DotProductAttention(dropout)
        self.W_q = nn.LazyLinear(num_hiddens, bias=bias)
        self.W_k = nn.LazyLinear(num_hiddens, bias=bias)
        self.W_v = nn.LazyLinear(num_hiddens, bias=bias)
        self.W_o = nn.LazyLinear(num_hiddens, bias=bias)

    def forward(self, queries, keys, values, valid_lens):
        # Shape of queries, keys, or values:
        # (batch_size, no. of queries or key-value pairs, num_hiddens)
        # Shape of valid_lens: (batch_size,) or (batch_size, no. of queries)
        # After transposing, shape of output queries, keys, or values:
        # (batch_size * num_heads, no. of queries or key-value pairs,
        # num_hiddens / num_heads)
        queries = self.transpose_qkv(self.W_q(queries))
        keys = self.transpose_qkv(self.W_k(keys))
        values = self.transpose_qkv(self.W_v(values))

        if valid_lens is not None:
            # On axis 0, copy the first item (scalar or vector) for num_heads
            # times, then copy the next item, and so on
            valid_lens = torch.repeat_interleave(
                valid_lens, repeats=self.num_heads, dim=0)

        # Shape of output: (batch_size * num_heads, no. of queries,
        # num_hiddens / num_heads)
        output = self.attention(queries, keys, values, valid_lens)
        # Shape of output_concat: (batch_size, no. of queries, num_hiddens)
        output_concat = self.transpose_output(output)
        return self.W_o(output_concat)
# MXNet
class MultiHeadAttention(d2l.Module):  #@save
    """Multi-head attention."""
    def __init__(self, num_hiddens, num_heads, dropout, use_bias=False,
                 **kwargs):
        super().__init__()
        self.num_heads = num_heads
        self.attention = d2l.DotProductAttention(dropout)
        self.W_q = nn.Dense(num_hiddens, use_bias=use_bias, flatten=False)
        self.W_k = nn.Dense(num_hiddens, use_bias=use_bias, flatten=False)
        self.W_v = nn.Dense(num_hiddens, use_bias=use_bias, flatten=False)
        self.W_o = nn.Dense(num_hiddens, use_bias=use_bias, flatten=False)

    def forward(self, queries, keys, values, valid_lens):
        # Shape of queries, keys, or values:
        # (batch_size, no. of queries or key-value pairs, num_hiddens)
        # Shape of valid_lens: (batch_size,) or (batch_size, no. of queries)
        # After transposing, shape of output queries, keys, or values:
        # (batch_size * num_heads, no. of queries or key-value pairs,
        # num_hiddens / num_heads)
        queries = self.transpose_qkv(self.W_q(queries))
        keys = self.transpose_qkv(self.W_k(keys))
        values = self.transpose_qkv(self.W_v(values))

        if valid_lens is not None:
            # On axis 0, copy the first item (scalar or vector) for num_heads
            # times, then copy the next item, and so on
            valid_lens = valid_lens.repeat(self.num_heads, axis=0)

        # Shape of output: (batch_size * num_heads, no. of queries,
        # num_hiddens / num_heads)
        output = self.attention(queries, keys, values, valid_lens)
        # Shape of output_concat: (batch_size, no. of queries, num_hiddens)
        output_concat = self.transpose_output(output)
        return self.W_o(output_concat)
# JAX
class MultiHeadAttention(nn.Module):  #@save
    num_hiddens: int
    num_heads: int
    dropout: float
    bias: bool = False

    def setup(self):
        self.attention = d2l.DotProductAttention(self.dropout)
        self.W_q = nn.Dense(self.num_hiddens, use_bias=self.bias)
        self.W_k = nn.Dense(self.num_hiddens, use_bias=self.bias)
        self.W_v = nn.Dense(self.num_hiddens, use_bias=self.bias)
        self.W_o = nn.Dense(self.num_hiddens, use_bias=self.bias)

    @nn.compact
    def __call__(self, queries, keys, values, valid_lens, training=False):
        # Shape of queries, keys, or values:
        # (batch_size, no. of queries or key-value pairs, num_hiddens)
        # Shape of valid_lens: (batch_size,) or (batch_size, no. of queries)
        # After transposing, shape of output queries, keys, or values:
        # (batch_size * num_heads, no. of queries or key-value pairs,
        # num_hiddens / num_heads)
        queries = self.transpose_qkv(self.W_q(queries))
        keys = self.transpose_qkv(self.W_k(keys))
        values = self.transpose_qkv(self.W_v(values))

        if valid_lens is not None:
            # On axis 0, copy the first item (scalar or vector) for num_heads
            # times, then copy the next item, and so on
            valid_lens = jnp.repeat(valid_lens, self.num_heads, axis=0)

        # Shape of output: (batch_size * num_heads, no. of queries,
        # num_hiddens / num_heads)
        output, attention_weights = self.attention(
            queries, keys, values, valid_lens, training=training)
        # Shape of output_concat: (batch_size, no. of queries, num_hiddens)
        output_concat = self.transpose_output(output)
        return self.W_o(output_concat), attention_weights
# TensorFlow
class MultiHeadAttention(d2l.Module):  #@save
    """Multi-head attention."""
    def __init__(self, key_size, query_size, value_size, num_hiddens,
                 num_heads, dropout, bias=False, **kwargs):
        super().__init__()
        self.num_heads = num_heads
        self.attention = d2l.DotProductAttention(dropout)
        self.W_q = tf.keras.layers.Dense(num_hiddens, use_bias=bias)
        self.W_k = tf.keras.layers.Dense(num_hiddens, use_bias=bias)
        self.W_v = tf.keras.layers.Dense(num_hiddens, use_bias=bias)
        self.W_o = tf.keras.layers.Dense(num_hiddens, use_bias=bias)

    def call(self, queries, keys, values, valid_lens, **kwargs):
        # Shape of queries, keys, or values:
        # (batch_size, no. of queries or key-value pairs, num_hiddens)
        # Shape of valid_lens: (batch_size,) or (batch_size, no. of queries)
        # After transposing, shape of output queries, keys, or values:
        # (batch_size * num_heads, no. of queries or key-value pairs,
        # num_hiddens / num_heads)
        queries = self.transpose_qkv(self.W_q(queries))
        keys = self.transpose_qkv(self.W_k(keys))
        values = self.transpose_qkv(self.W_v(values))

        if valid_lens is not None:
            # On axis 0, copy the first item (scalar or vector) for num_heads
            # times, then copy the next item, and so on
            valid_lens = tf.repeat(valid_lens, repeats=self.num_heads, axis=0)

        # Shape of output: (batch_size * num_heads, no. of queries,
        # num_hiddens / num_heads)
        output = self.attention(queries, keys, values, valid_lens, **kwargs)
        # Shape of output_concat: (batch_size, no. of queries, num_hiddens)
        output_concat = self.transpose_output(output)
        return self.W_o(output_concat)
To allow for parallel computation of multiple heads, the MultiHeadAttention class above uses the two transposition methods defined below. Specifically, the transpose_output method reverses the operation of the transpose_qkv method.
# PyTorch
@d2l.add_to_class(MultiHeadAttention)  #@save
def transpose_qkv(self, X):
    """Transposition for parallel computation of multiple attention heads."""
    # Shape of input X: (batch_size, no. of queries or key-value pairs,
    # num_hiddens). Shape of output X: (batch_size, no. of queries or
    # key-value pairs, num_heads, num_hiddens / num_heads)
    X = X.reshape(X.shape[0], X.shape[1], self.num_heads, -1)
    # Shape of output X: (batch_size, num_heads, no. of queries or key-value
    # pairs, num_hiddens / num_heads)
    X = X.permute(0, 2, 1, 3)
    # Shape of output: (batch_size * num_heads, no. of queries or key-value
    # pairs, num_hiddens / num_heads)
    return X.reshape(-1, X.shape[2], X.shape[3])

@d2l.add_to_class(MultiHeadAttention)  #@save
def transpose_output(self, X):
    """Reverse the operation of transpose_qkv."""
    X = X.reshape(-1, self.num_heads, X.shape[1], X.shape[2])
    X = X.permute(0, 2, 1, 3)
    return X.reshape(X.shape[0], X.shape[1], -1)
# MXNet
@d2l.add_to_class(MultiHeadAttention)  #@save
def transpose_qkv(self, X):
    """Transposition for parallel computation of multiple attention heads."""
    # Shape of input X: (batch_size, no. of queries or key-value pairs,
    # num_hiddens). Shape of output X: (batch_size, no. of queries or
    # key-value pairs, num_heads, num_hiddens / num_heads)
    X = X.reshape(X.shape[0], X.shape[1], self.num_heads, -1)
    # Shape of output X: (batch_size, num_heads, no. of queries or key-value
    # pairs, num_hiddens / num_heads)
    X = X.transpose(0, 2, 1, 3)
    # Shape of output: (batch_size * num_heads, no. of queries or key-value
    # pairs, num_hiddens / num_heads)
    return X.reshape(-1, X.shape[2], X.shape[3])

@d2l.add_to_class(MultiHeadAttention)  #@save
def transpose_output(self, X):
    """Reverse the operation of transpose_qkv."""
    X = X.reshape(-1, self.num_heads, X.shape[1], X.shape[2])
    X = X.transpose(0, 2, 1, 3)
    return X.reshape(X.shape[0], X.shape[1], -1)
# JAX
@d2l.add_to_class(MultiHeadAttention)  #@save
def transpose_qkv(self, X):
    """Transposition for parallel computation of multiple attention heads."""
    # Shape of input X: (batch_size, no. of queries or key-value pairs,
    # num_hiddens). Shape of output X: (batch_size, no. of queries or
    # key-value pairs, num_heads, num_hiddens / num_heads)
    X = X.reshape((X.shape[0], X.shape[1], self.num_heads, -1))
    # Shape of output X: (batch_size, num_heads, no. of queries or key-value
    # pairs, num_hiddens / num_heads)
    X = jnp.transpose(X, (0, 2, 1, 3))
    # Shape of output: (batch_size * num_heads, no. of queries or key-value
    # pairs, num_hiddens / num_heads)
    return X.reshape((-1, X.shape[2], X.shape[3]))

@d2l.add_to_class(MultiHeadAttention)  #@save
def transpose_output(self, X):
    """Reverse the operation of transpose_qkv."""
    X = X.reshape((-1, self.num_heads, X.shape[1], X.shape[2]))
    X = jnp.transpose(X, (0, 2, 1, 3))
    return X.reshape((X.shape[0], X.shape[1], -1))
# TensorFlow
@d2l.add_to_class(MultiHeadAttention)  #@save
def transpose_qkv(self, X):
    """Transposition for parallel computation of multiple attention heads."""
    # Shape of input X: (batch_size, no. of queries or key-value pairs,
    # num_hiddens). Shape of output X: (batch_size, no. of queries or
    # key-value pairs, num_heads, num_hiddens / num_heads)
    X = tf.reshape(X, shape=(X.shape[0], X.shape[1], self.num_heads, -1))
    # Shape of output X: (batch_size, num_heads, no. of queries or key-value
    # pairs, num_hiddens / num_heads)
    X = tf.transpose(X, perm=(0, 2, 1, 3))
    # Shape of output: (batch_size * num_heads, no. of queries or key-value
    # pairs, num_hiddens / num_heads)
    return tf.reshape(X, shape=(-1, X.shape[2], X.shape[3]))

@d2l.add_to_class(MultiHeadAttention)  #@save
def transpose_output(self, X):
    """Reverse the operation of transpose_qkv."""
    X = tf.reshape(X, shape=(-1, self.num_heads, X.shape[1], X.shape[2]))
    X = tf.transpose(X, perm=(0, 2, 1, 3))
    return tf.reshape(X, shape=(X.shape[0], X.shape[1], -1))
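As a quick sanity check, the sketch below (using the PyTorch definitions above, with arbitrarily chosen sizes) verifies that transpose_output exactly reverses transpose_qkv:

import torch

# Hypothetical sizes: num_hiddens = 8 split across num_heads = 2 heads.
attn = MultiHeadAttention(num_hiddens=8, num_heads=2, dropout=0)
X = torch.arange(2 * 3 * 8, dtype=torch.float32).reshape(2, 3, 8)
X_heads = attn.transpose_qkv(X)  # (batch * heads, steps, 8 / 2) = (4, 3, 4)
print(X_heads.shape)
# transpose_output is the exact inverse, so the round trip recovers X
assert torch.equal(attn.transpose_output(X_heads), X)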
Let us test our implemented MultiHeadAttention class using a toy example where keys and values are the same. As a result, the shape of the multi-head attention output is (batch_size, num_queries, num_hiddens).
# PyTorch
num_hiddens, num_heads = 100, 5
attention = MultiHeadAttention(num_hiddens, num_heads, 0.5)
batch_size, num_queries, num_kvpairs = 2, 4, 6
valid_lens = torch.tensor([3, 2])
X = torch.ones((batch_size, num_queries, num_hiddens))
Y = torch.ones((batch_size, num_kvpairs, num_hiddens))
d2l.check_shape(attention(X, Y, Y, valid_lens),
                (batch_size, num_queries, num_hiddens))
# MXNet
num_hiddens, num_heads = 100, 5
attention = MultiHeadAttention(num_hiddens, num_heads, 0.5)
attention.initialize()
batch_size, num_queries, num_kvpairs = 2, 4, 6
valid_lens = np.array([3, 2])
X = np.ones((batch_size, num_queries, num_hiddens))
Y = np.ones((batch_size, num_kvpairs, num_hiddens))
d2l.check_shape(attention(X, Y, Y, valid_lens),
                (batch_size, num_queries, num_hiddens))
# JAX
num_hiddens, num_heads = 100, 5
attention = MultiHeadAttention(num_hiddens, num_heads, 0.5)
batch_size, num_queries, num_kvpairs = 2, 4, 6
valid_lens = jnp.array([3, 2])
X = jnp.ones((batch_size, num_queries, num_hiddens))
Y = jnp.ones((batch_size, num_kvpairs, num_hiddens))
d2l.check_shape(attention.init_with_output(d2l.get_key(), X, Y, Y, valid_lens,
                                           training=False)[0][0],
                (batch_size, num_queries, num_hiddens))
# TensorFlow
num_hiddens, num_heads = 100, 5
attention = MultiHeadAttention(num_hiddens, num_hiddens, num_hiddens,
                               num_hiddens, num_heads, 0.5)
batch_size, num_queries, num_kvpairs = 2, 4, 6
valid_lens = tf.constant([3, 2])
X = tf.ones((batch_size, num_queries, num_hiddens))
Y = tf.ones((batch_size, num_kvpairs, num_hiddens))
d2l.check_shape(attention(X, Y, Y, valid_lens, training=False),
                (batch_size, num_queries, num_hiddens))
11.5.3. Summary
Multi-head attention combines knowledge of the same attention pooling via different representation subspaces of queries, keys, and values. To compute multiple heads of multi-head attention in parallel, proper tensor manipulation is needed.
11.5.4. Exercises
1. Visualize the attention weights of multiple heads in this experiment.
2. Suppose that we have a trained model based on multi-head attention and we want to prune the least important attention heads to increase prediction speed. How can we design experiments to measure the importance of an attention head?