From 1a5aefc95c2fb78f712422c2bbdbce5a00ae862c Mon Sep 17 00:00:00 2001 From: Sam Shleifer Date: Thu, 26 Mar 2020 18:41:19 -0400 Subject: [PATCH] [Seq2Seq Generation] Call encoder before expanding input_ids (#3370) --- src/transformers/modeling_bart.py | 2 +- src/transformers/modeling_t5.py | 1 + src/transformers/modeling_utils.py | 41 ++++++++++++++++++++---------- 3 files changed, 29 insertions(+), 15 deletions(-) diff --git a/src/transformers/modeling_bart.py b/src/transformers/modeling_bart.py index a35fae9ae9..23c513393c 100644 --- a/src/transformers/modeling_bart.py +++ b/src/transformers/modeling_bart.py @@ -113,6 +113,7 @@ class PretrainedBartModel(PreTrainedModel): config_class = BartConfig base_model_prefix = "model" pretrained_model_archive_map = BART_PRETRAINED_MODEL_ARCHIVE_MAP + encoder_outputs_batch_dim_idx = 1 # outputs shaped (seq_len, bs, ...) def _init_weights(self, module): std = self.config.init_std @@ -888,7 +889,6 @@ class BartForConditionalGeneration(PretrainedBartModel): encoder_outputs, decoder_cached_states = past, None else: encoder_outputs, decoder_cached_states = past - return { "input_ids": None, # encoder_outputs is defined. input_ids not needed "encoder_outputs": encoder_outputs, diff --git a/src/transformers/modeling_t5.py b/src/transformers/modeling_t5.py index b56917ae1d..d03c40f5a6 100644 --- a/src/transformers/modeling_t5.py +++ b/src/transformers/modeling_t5.py @@ -457,6 +457,7 @@ class T5PreTrainedModel(PreTrainedModel): pretrained_model_archive_map = T5_PRETRAINED_MODEL_ARCHIVE_MAP load_tf_weights = load_tf_weights_in_t5 base_model_prefix = "transformer" + encoder_outputs_batch_dim_idx = 0 # outputs shaped (bs, ...) @property def dummy_inputs(self): diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index 61f00c6eb6..e1f5fd2af2 100644 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -895,6 +895,21 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin): effective_batch_size = batch_size effective_batch_mult = 1 + if self.config.is_encoder_decoder: + if decoder_start_token_id is None: + decoder_start_token_id = bos_token_id + + assert ( + decoder_start_token_id is not None + ), "decoder_start_token_id or bos_token_id has to be defined for encoder-decoder generation" + assert hasattr(self, "get_encoder"), "{} should have a 'get_encoder' function defined".format(self) + assert callable(self.get_encoder), "{} should be a method".format(self.get_encoder) + + # get encoder and store encoder outputs + encoder = self.get_encoder() + + encoder_outputs = encoder(input_ids, attention_mask=attention_mask) + # Expand input ids if num_beams > 1 or num_return_sequences > 1 if num_return_sequences > 1 or num_beams > 1: input_ids_len = input_ids.shape[-1] @@ -911,20 +926,6 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin): ) # shape: (batch_size * num_return_sequences * num_beams, cur_len) if self.config.is_encoder_decoder: - if decoder_start_token_id is None: - decoder_start_token_id = bos_token_id - - assert ( - decoder_start_token_id is not None - ), "decoder_start_token_id or bos_token_id has to be defined for encoder-decoder generation" - assert hasattr(self, "get_encoder"), "{} should have a 'get_encoder' function defined".format(self) - assert callable(self.get_encoder), "{} should be a method".format(self.get_encoder) - - # get encoder and store encoder outputs - encoder = self.get_encoder() - - encoder_outputs = encoder(input_ids, attention_mask=attention_mask) - # create empty decoder_input_ids input_ids = torch.full( (effective_batch_size * num_beams, 1), @@ -933,6 +934,18 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin): device=next(self.parameters()).device, ) cur_len = 1 + batch_idx = self.encoder_outputs_batch_dim_idx + assert ( + batch_size == encoder_outputs[0].shape[batch_idx] + ), f"expected encoder_outputs[0] to have 1st dimension bs={batch_size}, got {encoder_outputs[0].shape[1]} " + expanded_idx = ( + torch.arange(batch_size) + .view(-1, 1) + .repeat(1, num_beams * effective_batch_mult) + .view(-1) + .to(input_ids.device) + ) + encoder_outputs = (encoder_outputs[0].index_select(batch_idx, expanded_idx), *encoder_outputs[1:]) else: encoder_outputs = None cur_len = input_ids.shape[-1]