From 390c1285925dd119705e69a266202ef04490d012 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Thu, 2 Apr 2020 15:18:33 +0200 Subject: [PATCH] [Encoder-Decoder] Force models outputs to always have batch_size as their first dim (#3536) * solve conflicts * improve comments --- src/transformers/modeling_bart.py | 18 ++++++++++++------ src/transformers/modeling_t5.py | 1 - src/transformers/modeling_utils.py | 13 ++++++++----- 3 files changed, 20 insertions(+), 12 deletions(-) diff --git a/src/transformers/modeling_bart.py b/src/transformers/modeling_bart.py index d237c0448d..bec39811b6 100644 --- a/src/transformers/modeling_bart.py +++ b/src/transformers/modeling_bart.py @@ -116,7 +116,6 @@ class PretrainedBartModel(PreTrainedModel): config_class = BartConfig base_model_prefix = "model" pretrained_model_archive_map = BART_PRETRAINED_MODEL_ARCHIVE_MAP - encoder_outputs_batch_dim_idx = 1 # outputs shaped (seq_len, bs, ...) def _init_weights(self, module): std = self.config.init_std @@ -294,7 +293,10 @@ class BartEncoder(nn.Module): if self.output_hidden_states: encoder_states.append(x) + # T x B x C -> B x T x C encoder_states = [hidden_state.transpose(0, 1) for hidden_state in encoder_states] + x = x.transpose(0, 1) + return x, encoder_states, all_attentions @@ -448,7 +450,11 @@ class BartDecoder(nn.Module): x = self.layernorm_embedding(x) x = F.dropout(x, p=self.dropout, training=self.training) - x = x.transpose(0, 1) # (seq_len, BS, model_dim) + + # Convert to Bart internal format: (BS, seq_len, model_dim) -> (seq_len, BS, model_dim) + x = x.transpose(0, 1) + encoder_hidden_states = encoder_hidden_states.transpose(0, 1) + # decoder layers all_hidden_states = () all_self_attns = () @@ -477,9 +483,10 @@ class BartDecoder(nn.Module): if self.output_attentions: all_self_attns += (layer_self_attn,) - # Convert shapes from (seq_len, BS, model_dim) to (BS, seq_len, model_dim) + # Convert to standard output format: (seq_len, BS, model_dim) -> (BS, seq_len, model_dim) 
all_hidden_states = [hidden_state.transpose(0, 1) for hidden_state in all_hidden_states] x = x.transpose(0, 1) + encoder_hidden_states = encoder_hidden_states.transpose(0, 1) if self.output_past: next_cache = ((encoder_hidden_states, encoder_padding_mask), next_decoder_cache) @@ -930,10 +937,9 @@ class BartForConditionalGeneration(PretrainedBartModel): layer_past_new = { attn_key: _reorder_buffer(attn_cache, beam_idx) for attn_key, attn_cache in layer_past.items() } - # reordered_layer_past = [layer_past[:, i].unsqueeze(1).clone().detach() for i in beam_idx] - # reordered_layer_past = torch.cat(reordered_layer_past, dim=1) reordered_past.append(layer_past_new) - new_enc_out = enc_out if enc_out is None else enc_out.index_select(1, beam_idx) + + new_enc_out = enc_out if enc_out is None else enc_out.index_select(0, beam_idx) new_enc_mask = enc_mask if enc_mask is None else enc_mask.index_select(0, beam_idx) past = ((new_enc_out, new_enc_mask), reordered_past) diff --git a/src/transformers/modeling_t5.py b/src/transformers/modeling_t5.py index 98fc319a4e..5235629c60 100644 --- a/src/transformers/modeling_t5.py +++ b/src/transformers/modeling_t5.py @@ -457,7 +457,6 @@ class T5PreTrainedModel(PreTrainedModel): pretrained_model_archive_map = T5_PRETRAINED_MODEL_ARCHIVE_MAP load_tf_weights = load_tf_weights_in_t5 base_model_prefix = "transformer" - encoder_outputs_batch_dim_idx = 0 # outputs shaped (bs, ...) 
@property def dummy_inputs(self): diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index 2aad833ce9..685605e773 100644 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -948,18 +948,21 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin): device=next(self.parameters()).device, ) cur_len = 1 - batch_idx = self.encoder_outputs_batch_dim_idx + assert ( - batch_size == encoder_outputs[0].shape[batch_idx] - ), f"expected encoder_outputs[0] to have 1st dimension bs={batch_size}, got {encoder_outputs[0].shape[1]} " - expanded_idx = ( + batch_size == encoder_outputs[0].shape[0] + ), f"expected encoder_outputs[0] to have 1st dimension bs={batch_size}, got {encoder_outputs[0].shape[0]} " + + # expand batch_idx to assign correct encoder output for expanded input_ids (due to num_beams > 1 and num_return_sequences > 1) + expanded_batch_idxs = ( torch.arange(batch_size) .view(-1, 1) .repeat(1, num_beams * effective_batch_mult) .view(-1) .to(input_ids.device) ) - encoder_outputs = (encoder_outputs[0].index_select(batch_idx, expanded_idx), *encoder_outputs[1:]) + # expand encoder_outputs + encoder_outputs = (encoder_outputs[0].index_select(0, expanded_batch_idxs), *encoder_outputs[1:]) else: encoder_outputs = None