[Encoder-Decoder] Force model outputs to always have batch_size as their first dim (#3536)

* solve conflicts

* improve comments
This commit is contained in:
Patrick von Platen
2020-04-02 15:18:33 +02:00
committed by GitHub
parent ab5d06a094
commit 390c128592
3 changed files with 20 additions and 12 deletions

View File

@@ -116,7 +116,6 @@ class PretrainedBartModel(PreTrainedModel):
config_class = BartConfig
base_model_prefix = "model"
pretrained_model_archive_map = BART_PRETRAINED_MODEL_ARCHIVE_MAP
encoder_outputs_batch_dim_idx = 1 # outputs shaped (seq_len, bs, ...)
def _init_weights(self, module):
std = self.config.init_std
@@ -294,7 +293,10 @@ class BartEncoder(nn.Module):
if self.output_hidden_states:
encoder_states.append(x)
# T x B x C -> B x T x C
encoder_states = [hidden_state.transpose(0, 1) for hidden_state in encoder_states]
x = x.transpose(0, 1)
return x, encoder_states, all_attentions
@@ -448,7 +450,11 @@ class BartDecoder(nn.Module):
x = self.layernorm_embedding(x)
x = F.dropout(x, p=self.dropout, training=self.training)
x = x.transpose(0, 1) # (seq_len, BS, model_dim)
# Convert to Bart's internal layer format: (BS, seq_len, model_dim) -> (seq_len, BS, model_dim)
x = x.transpose(0, 1)
encoder_hidden_states = encoder_hidden_states.transpose(0, 1)
# decoder layers
all_hidden_states = ()
all_self_attns = ()
@@ -477,9 +483,10 @@ class BartDecoder(nn.Module):
if self.output_attentions:
all_self_attns += (layer_self_attn,)
# Convert shapes from (seq_len, BS, model_dim) to (BS, seq_len, model_dim)
# Convert to standard output format: (seq_len, BS, model_dim) -> (BS, seq_len, model_dim)
all_hidden_states = [hidden_state.transpose(0, 1) for hidden_state in all_hidden_states]
x = x.transpose(0, 1)
encoder_hidden_states = encoder_hidden_states.transpose(0, 1)
if self.output_past:
next_cache = ((encoder_hidden_states, encoder_padding_mask), next_decoder_cache)
@@ -930,10 +937,9 @@ class BartForConditionalGeneration(PretrainedBartModel):
layer_past_new = {
attn_key: _reorder_buffer(attn_cache, beam_idx) for attn_key, attn_cache in layer_past.items()
}
# reordered_layer_past = [layer_past[:, i].unsqueeze(1).clone().detach() for i in beam_idx]
# reordered_layer_past = torch.cat(reordered_layer_past, dim=1)
reordered_past.append(layer_past_new)
new_enc_out = enc_out if enc_out is None else enc_out.index_select(1, beam_idx)
new_enc_out = enc_out if enc_out is None else enc_out.index_select(0, beam_idx)
new_enc_mask = enc_mask if enc_mask is None else enc_mask.index_select(0, beam_idx)
past = ((new_enc_out, new_enc_mask), reordered_past)