[Encoder-Decoder] Force models outputs to always have batch_size as their first dim (#3536)
* solve conflicts * improve comments
This commit is contained in:
committed by
GitHub
parent
ab5d06a094
commit
390c128592
@@ -116,7 +116,6 @@ class PretrainedBartModel(PreTrainedModel):
|
||||
config_class = BartConfig
|
||||
base_model_prefix = "model"
|
||||
pretrained_model_archive_map = BART_PRETRAINED_MODEL_ARCHIVE_MAP
|
||||
encoder_outputs_batch_dim_idx = 1 # outputs shaped (seq_len, bs, ...)
|
||||
|
||||
def _init_weights(self, module):
|
||||
std = self.config.init_std
|
||||
@@ -294,7 +293,10 @@ class BartEncoder(nn.Module):
|
||||
if self.output_hidden_states:
|
||||
encoder_states.append(x)
|
||||
|
||||
# T x B x C -> B x T x C
|
||||
encoder_states = [hidden_state.transpose(0, 1) for hidden_state in encoder_states]
|
||||
x = x.transpose(0, 1)
|
||||
|
||||
return x, encoder_states, all_attentions
|
||||
|
||||
|
||||
@@ -448,7 +450,11 @@ class BartDecoder(nn.Module):
|
||||
|
||||
x = self.layernorm_embedding(x)
|
||||
x = F.dropout(x, p=self.dropout, training=self.training)
|
||||
x = x.transpose(0, 1) # (seq_len, BS, model_dim)
|
||||
|
||||
# Convert to Bart output format: (seq_len, BS, model_dim) -> (BS, seq_len, model_dim)
|
||||
x = x.transpose(0, 1)
|
||||
encoder_hidden_states = encoder_hidden_states.transpose(0, 1)
|
||||
|
||||
# decoder layers
|
||||
all_hidden_states = ()
|
||||
all_self_attns = ()
|
||||
@@ -477,9 +483,10 @@ class BartDecoder(nn.Module):
|
||||
if self.output_attentions:
|
||||
all_self_attns += (layer_self_attn,)
|
||||
|
||||
# Convert shapes from (seq_len, BS, model_dim) to (BS, seq_len, model_dim)
|
||||
# Convert to standart output format: (seq_len, BS, model_dim) -> (BS, seq_len, model_dim)
|
||||
all_hidden_states = [hidden_state.transpose(0, 1) for hidden_state in all_hidden_states]
|
||||
x = x.transpose(0, 1)
|
||||
encoder_hidden_states = encoder_hidden_states.transpose(0, 1)
|
||||
|
||||
if self.output_past:
|
||||
next_cache = ((encoder_hidden_states, encoder_padding_mask), next_decoder_cache)
|
||||
@@ -930,10 +937,9 @@ class BartForConditionalGeneration(PretrainedBartModel):
|
||||
layer_past_new = {
|
||||
attn_key: _reorder_buffer(attn_cache, beam_idx) for attn_key, attn_cache in layer_past.items()
|
||||
}
|
||||
# reordered_layer_past = [layer_past[:, i].unsqueeze(1).clone().detach() for i in beam_idx]
|
||||
# reordered_layer_past = torch.cat(reordered_layer_past, dim=1)
|
||||
reordered_past.append(layer_past_new)
|
||||
new_enc_out = enc_out if enc_out is None else enc_out.index_select(1, beam_idx)
|
||||
|
||||
new_enc_out = enc_out if enc_out is None else enc_out.index_select(0, beam_idx)
|
||||
new_enc_mask = enc_mask if enc_mask is None else enc_mask.index_select(0, beam_idx)
|
||||
|
||||
past = ((new_enc_out, new_enc_mask), reordered_past)
|
||||
|
||||
Reference in New Issue
Block a user