在一般的 seq2seq 問題中,如機器翻譯(第 10.5 節(jié)),輸入和輸出的長度不同且未對齊。處理這類數(shù)據(jù)的標(biāo)準(zhǔn)方法是設(shè)計一個編碼器-解碼器架構(gòu)(圖 10.6.1),它由兩個主要組件組成:一個 編碼器,它以可變長度序列作為輸入,以及一個 解碼器,作為一個條件語言模型,接收編碼輸入和目標(biāo)序列的向左上下文,并預(yù)測目標(biāo)序列中的后續(xù)標(biāo)記。
讓我們以從英語到法語的機器翻譯為例。給定一個英文輸入序列:“They”、“are”、“watching”、“.”,這種編碼器-解碼器架構(gòu)首先將可變長度輸入編碼為一個狀態(tài),然后對該狀態(tài)進(jìn)行解碼以生成翻譯后的序列,token通過標(biāo)記,作為輸出:“Ils”、“regardent”、“.”。由于編碼器-解碼器架構(gòu)構(gòu)成了后續(xù)章節(jié)中不同 seq2seq 模型的基礎(chǔ),因此本節(jié)將此架構(gòu)轉(zhuǎn)換為稍后將實現(xiàn)的接口。
import tensorflow as tf
from d2l import tensorflow as d2l
10.6.1。編碼器
在編碼器接口中,我們只是指定編碼器將可變長度序列作為輸入X
。實現(xiàn)將由繼承此基類的任何模型提供Encoder
。
class Encoder(tf.keras.layers.Layer): #@save
"""The base encoder interface for the encoder-decoder architecture."""
def __init__(self):
super().__init__()
# Later there can be additional arguments (e.g., length excluding padding)
def call(self, X, *args):
raise NotImplementedError
10.6.2。解碼器
在下面的解碼器接口中,我們添加了一個額外的init_state
方法來將編碼器輸出 ( enc_all_outputs
) 轉(zhuǎn)換為編碼狀態(tài)。請注意,此步驟可能需要額外的輸入,例如輸入的有效長度,這在 第 10.5 節(jié)中有解釋。為了逐個令牌生成可變長度序列令牌,每次解碼器都可以將輸入(例如,在先前時間步生成的令牌)和編碼狀態(tài)映射到當(dāng)前時間步的輸出令牌。
class Decoder(nn.Module): #@save
"""The base decoder interface for the encoder-decoder architecture."""
def __init__(self):
super().__init__()
# Later there can be additional arguments (e.g., length excluding padding)
def init_state(self, enc_all_outputs, *args):
raise NotImplementedError
def forward(self, X, state):
raise NotImplementedError
class Decoder(nn.Block): #@save
"""The base decoder interface for the encoder-decoder architecture."""
def __init__(self):
super().__init__()
# Later there can be additional arguments (e.g., length excluding padding)
def init_state(self, enc_all_outputs, *args):
raise NotImplementedError
def forward(self, X, state):
raise NotImplementedError
class Decoder(nn.Module): #@save
"""The base decoder interface for the encoder-decoder architecture."""
def setup(self):
raise NotImplementedError
# Later there can be additional arguments (e.g., length excluding padding)
def init_state(self, enc_all_outputs, *args):
raise NotImplementedError
def __call__(self, X, state):
raise NotImplementedError
class Decoder(tf.keras.layers.Layer): #@save
"""The base decoder interface for the encoder-decoder architecture."""
def __init__(self):
super().__init__()
# Later there can be additional arguments (e.g., length excluding padding)
def init_state(self, enc_all_outputs, *args):
raise NotImplementedError
def call(self, X, state):
raise NotImplementedError
10.6.3。將編碼器和解碼器放在一起
在前向傳播中,編碼器的輸出用于產(chǎn)生編碼狀態(tài),解碼器將進(jìn)一步使用該狀態(tài)作為其輸入之一。
class EncoderDecoder(d2l.Classifier): #@save
"""The base class for the encoder-decoder architecture."""
def __init__(self, encoder, decoder):
super().__init__()
self.encoder = encoder
self.decoder = decoder
def forward(self, enc_X, dec_X, *args):
enc_all_outputs = self.encoder(enc_X, *args)
dec_state = self.decoder.init_state(enc_all_outputs, *args)
# Return decoder output only
return self.decoder(dec_X, dec_state)[0]
class EncoderDecoder(d2l.Classifier): #@save
"""The base class for the encoder-decoder architecture."""
def __init__(self, encoder, decoder):
super().__init__()
self.encoder = encoder
self.decoder = decoder
def forward(self, enc_X, dec_X, *args):
enc_all_outputs = self.encoder(enc_X, *args)
dec_state = self.decoder.init_state(enc_all_outputs, *args)
# Return decoder output only
return self.decoder(dec_X, dec_state)[0]
class EncoderDecoder(d2l.Classifier): #@save
"""The base class for the encoder-decoder architecture."""
encoder: nn.Module
decoder: nn.Module
training: bool
def __call__(self, enc_X, dec_X, *args):
enc_all_outputs = self.encoder(enc_X, *args, training=self.training)
dec_state = self.decoder.init_state(enc_all_outputs, *args)
# Return decoder output only
return self.decoder(dec_X, dec_state, training=self.training)[0]
class EncoderDecoder(d2l.Classifier): #@save
"""The base class for the encoder-decoder architecture."""
def __init__(self, encoder, decoder):
super().__init__()
self.encoder = encoder
self.decoder = decoder
def call(self, enc_X, dec_X, *args):
enc_all_outputs = self.encoder(enc_X, *args, training=True)
dec_state = self.decoder.init_state(enc_all_outputs, *args)
# Return decoder output only
return self.decoder(dec_X, dec_state, training=True)[0]
在下一節(jié)中,我們將看到如何應(yīng)用 RNN 來設(shè)計基于這種編碼器-解碼器架構(gòu)的 seq2seq 模型。
10.6.4。概括
編碼器-解碼器架構(gòu)可以處理由可變長度序列組成的輸入和輸出,因此適用于機器翻譯等 seq2seq 問題。編碼器將可變長度序列作為輸入,并將其轉(zhuǎn)換為具有固定形狀的狀態(tài)。解碼器將固定形狀的編碼狀態(tài)映射到可變長度序列。
10.6.5。練習(xí)
-
假設(shè)我們使用神經(jīng)網(wǎng)絡(luò)來實現(xiàn)編碼器-解碼器架構(gòu)。編碼器和解碼器必須是同一類型的神經(jīng)網(wǎng)絡(luò)嗎?
-
除了機器翻譯,你能想到另一個可以應(yīng)用編碼器-解碼器架構(gòu)的應(yīng)用程序嗎?
評論
查看更多