separate inputs into encoder & decoder inputs
This commit is contained in:
@@ -130,7 +130,7 @@ class PreTrainedSeq2seq(nn.Module):
|
|||||||
|
|
||||||
return model
|
return model
|
||||||
|
|
||||||
def forward(self, *inputs, **kwargs):
|
def forward(self, encoder_input_ids, decoder_input_ids, **kwargs):
|
||||||
""" The forward pass on a seq2seq depends on what we are performing:
|
""" The forward pass on a seq2seq depends on what we are performing:
|
||||||
|
|
||||||
- During training we perform one forward pass through both the encoder
|
- During training we perform one forward pass through both the encoder
|
||||||
@@ -142,6 +142,11 @@ class PreTrainedSeq2seq(nn.Module):
|
|||||||
Therefore, we skip the forward pass on the encoder if an argument named
|
Therefore, we skip the forward pass on the encoder if an argument named
|
||||||
`encoder_hidden_states` is passed to this function.
|
`encoder_hidden_states` is passed to this function.
|
||||||
|
|
||||||
|
Params:
|
||||||
|
encoder_input_ids: ``torch.LongTensor`` of shape ``(batch_size, sequence_length)``
|
||||||
|
Indices of encoder input sequence tokens in the vocabulary.
|
||||||
|
decoder_input_ids: ``torch.LongTensor`` of shape ``(batch_size, sequence_length)``
|
||||||
|
Indices of decoder input sequence tokens in the vocabulary.
|
||||||
"""
|
"""
|
||||||
# Separate the encoder- and decoder- specific kwargs. A kwarg is
|
# Separate the encoder- and decoder- specific kwargs. A kwarg is
|
||||||
# decoder-specific if the key starts with `decoder_`
|
# decoder-specific if the key starts with `decoder_`
|
||||||
@@ -154,14 +159,14 @@ class PreTrainedSeq2seq(nn.Module):
|
|||||||
# Encode if needed (training, first prediction pass)
|
# Encode if needed (training, first prediction pass)
|
||||||
encoder_hidden_states = kwargs_encoder.pop('encoder_hidden_states', None)
|
encoder_hidden_states = kwargs_encoder.pop('encoder_hidden_states', None)
|
||||||
if encoder_hidden_states is None:
|
if encoder_hidden_states is None:
|
||||||
encoder_outputs = self.encoder(*inputs, **kwargs_encoder)
|
encoder_outputs = self.encoder(encoder_input_ids, **kwargs_encoder)
|
||||||
encoder_hidden_states = encoder_outputs[0]
|
encoder_hidden_states = encoder_outputs[0]
|
||||||
else:
|
else:
|
||||||
encoder_outputs = ()
|
encoder_outputs = ()
|
||||||
|
|
||||||
# Decode
|
# Decode
|
||||||
kwargs_decoder['encoder_hidden_states'] = encoder_hidden_states
|
kwargs_decoder['encoder_hidden_states'] = encoder_hidden_states
|
||||||
decoder_outputs = self.decoder(**kwargs_decoder)
|
decoder_outputs = self.decoder(decoder_input_ids, **kwargs_decoder)
|
||||||
|
|
||||||
return decoder_outputs + encoder_outputs
|
return decoder_outputs + encoder_outputs
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user