From 95ec1d08bef7b780653f9f2c59fddb4a97873cff Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?R=C3=A9mi=20Louf?= Date: Wed, 16 Oct 2019 20:55:42 +0200 Subject: [PATCH] separate inputs into encoder & decoder inputs --- transformers/modeling_seq2seq.py | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/transformers/modeling_seq2seq.py b/transformers/modeling_seq2seq.py index b326f2bc1e..8f27224a56 100644 --- a/transformers/modeling_seq2seq.py +++ b/transformers/modeling_seq2seq.py @@ -130,7 +130,7 @@ class PreTrainedSeq2seq(nn.Module): return model - def forward(self, *inputs, **kwargs): + def forward(self, encoder_input_ids, decoder_input_ids, **kwargs): """ The forward pass on a seq2eq depends what we are performing: - During training we perform one forward pass through both the encoder @@ -142,6 +142,11 @@ class PreTrainedSeq2seq(nn.Module): Therefore, we skip the forward pass on the encoder if an argument named `encoder_hidden_state` is passed to this function. + Params: + encoder_input_ids: ``torch.LongTensor`` of shape ``(batch_size, sequence_length)`` + Indices of encoder input sequence tokens in the vocabulary. + decoder_input_ids: ``torch.LongTensor`` of shape ``(batch_size, sequence_length)`` + Indices of decoder input sequence tokens in the vocabulary. """ # Separate the encoder- and decoder- specific kwargs. A kwarg is # decoder-specific it the key starts with `decoder_` @@ -154,14 +159,14 @@ class PreTrainedSeq2seq(nn.Module): # Encode if needed (training, first prediction pass) encoder_hidden_states = kwargs_encoder.pop('encoder_hidden_states', None) if encoder_hidden_states is None: - encoder_outputs = self.encoder(*inputs, **kwargs_encoder) + encoder_outputs = self.encoder(encoder_input_ids, **kwargs_encoder) encoder_hidden_states = encoder_outputs[0] else: encoder_outputs = () # Decode kwargs_decoder['encoder_hidden_states'] = encoder_hidden_states - decoder_outputs = self.decoder(**kwargs_decoder) + decoder_outputs = self.decoder(decoder_input_ids, **kwargs_decoder) return decoder_outputs + encoder_outputs