[Encoder-Decoder] Force models outputs to always have batch_size as their first dim (#3536)
* solve conflicts * improve comments
This commit is contained in:
committed by
GitHub
parent
ab5d06a094
commit
390c128592
@@ -116,7 +116,6 @@ class PretrainedBartModel(PreTrainedModel):
|
|||||||
config_class = BartConfig
|
config_class = BartConfig
|
||||||
base_model_prefix = "model"
|
base_model_prefix = "model"
|
||||||
pretrained_model_archive_map = BART_PRETRAINED_MODEL_ARCHIVE_MAP
|
pretrained_model_archive_map = BART_PRETRAINED_MODEL_ARCHIVE_MAP
|
||||||
encoder_outputs_batch_dim_idx = 1 # outputs shaped (seq_len, bs, ...)
|
|
||||||
|
|
||||||
def _init_weights(self, module):
|
def _init_weights(self, module):
|
||||||
std = self.config.init_std
|
std = self.config.init_std
|
||||||
@@ -294,7 +293,10 @@ class BartEncoder(nn.Module):
|
|||||||
if self.output_hidden_states:
|
if self.output_hidden_states:
|
||||||
encoder_states.append(x)
|
encoder_states.append(x)
|
||||||
|
|
||||||
|
# T x B x C -> B x T x C
|
||||||
encoder_states = [hidden_state.transpose(0, 1) for hidden_state in encoder_states]
|
encoder_states = [hidden_state.transpose(0, 1) for hidden_state in encoder_states]
|
||||||
|
x = x.transpose(0, 1)
|
||||||
|
|
||||||
return x, encoder_states, all_attentions
|
return x, encoder_states, all_attentions
|
||||||
|
|
||||||
|
|
||||||
@@ -448,7 +450,11 @@ class BartDecoder(nn.Module):
|
|||||||
|
|
||||||
x = self.layernorm_embedding(x)
|
x = self.layernorm_embedding(x)
|
||||||
x = F.dropout(x, p=self.dropout, training=self.training)
|
x = F.dropout(x, p=self.dropout, training=self.training)
|
||||||
x = x.transpose(0, 1) # (seq_len, BS, model_dim)
|
|
||||||
|
# Convert to Bart output format: (seq_len, BS, model_dim) -> (BS, seq_len, model_dim)
|
||||||
|
x = x.transpose(0, 1)
|
||||||
|
encoder_hidden_states = encoder_hidden_states.transpose(0, 1)
|
||||||
|
|
||||||
# decoder layers
|
# decoder layers
|
||||||
all_hidden_states = ()
|
all_hidden_states = ()
|
||||||
all_self_attns = ()
|
all_self_attns = ()
|
||||||
@@ -477,9 +483,10 @@ class BartDecoder(nn.Module):
|
|||||||
if self.output_attentions:
|
if self.output_attentions:
|
||||||
all_self_attns += (layer_self_attn,)
|
all_self_attns += (layer_self_attn,)
|
||||||
|
|
||||||
# Convert shapes from (seq_len, BS, model_dim) to (BS, seq_len, model_dim)
|
# Convert to standart output format: (seq_len, BS, model_dim) -> (BS, seq_len, model_dim)
|
||||||
all_hidden_states = [hidden_state.transpose(0, 1) for hidden_state in all_hidden_states]
|
all_hidden_states = [hidden_state.transpose(0, 1) for hidden_state in all_hidden_states]
|
||||||
x = x.transpose(0, 1)
|
x = x.transpose(0, 1)
|
||||||
|
encoder_hidden_states = encoder_hidden_states.transpose(0, 1)
|
||||||
|
|
||||||
if self.output_past:
|
if self.output_past:
|
||||||
next_cache = ((encoder_hidden_states, encoder_padding_mask), next_decoder_cache)
|
next_cache = ((encoder_hidden_states, encoder_padding_mask), next_decoder_cache)
|
||||||
@@ -930,10 +937,9 @@ class BartForConditionalGeneration(PretrainedBartModel):
|
|||||||
layer_past_new = {
|
layer_past_new = {
|
||||||
attn_key: _reorder_buffer(attn_cache, beam_idx) for attn_key, attn_cache in layer_past.items()
|
attn_key: _reorder_buffer(attn_cache, beam_idx) for attn_key, attn_cache in layer_past.items()
|
||||||
}
|
}
|
||||||
# reordered_layer_past = [layer_past[:, i].unsqueeze(1).clone().detach() for i in beam_idx]
|
|
||||||
# reordered_layer_past = torch.cat(reordered_layer_past, dim=1)
|
|
||||||
reordered_past.append(layer_past_new)
|
reordered_past.append(layer_past_new)
|
||||||
new_enc_out = enc_out if enc_out is None else enc_out.index_select(1, beam_idx)
|
|
||||||
|
new_enc_out = enc_out if enc_out is None else enc_out.index_select(0, beam_idx)
|
||||||
new_enc_mask = enc_mask if enc_mask is None else enc_mask.index_select(0, beam_idx)
|
new_enc_mask = enc_mask if enc_mask is None else enc_mask.index_select(0, beam_idx)
|
||||||
|
|
||||||
past = ((new_enc_out, new_enc_mask), reordered_past)
|
past = ((new_enc_out, new_enc_mask), reordered_past)
|
||||||
|
|||||||
@@ -457,7 +457,6 @@ class T5PreTrainedModel(PreTrainedModel):
|
|||||||
pretrained_model_archive_map = T5_PRETRAINED_MODEL_ARCHIVE_MAP
|
pretrained_model_archive_map = T5_PRETRAINED_MODEL_ARCHIVE_MAP
|
||||||
load_tf_weights = load_tf_weights_in_t5
|
load_tf_weights = load_tf_weights_in_t5
|
||||||
base_model_prefix = "transformer"
|
base_model_prefix = "transformer"
|
||||||
encoder_outputs_batch_dim_idx = 0 # outputs shaped (bs, ...)
|
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def dummy_inputs(self):
|
def dummy_inputs(self):
|
||||||
|
|||||||
@@ -948,18 +948,21 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
|
|||||||
device=next(self.parameters()).device,
|
device=next(self.parameters()).device,
|
||||||
)
|
)
|
||||||
cur_len = 1
|
cur_len = 1
|
||||||
batch_idx = self.encoder_outputs_batch_dim_idx
|
|
||||||
assert (
|
assert (
|
||||||
batch_size == encoder_outputs[0].shape[batch_idx]
|
batch_size == encoder_outputs[0].shape[0]
|
||||||
), f"expected encoder_outputs[0] to have 1st dimension bs={batch_size}, got {encoder_outputs[0].shape[1]} "
|
), f"expected encoder_outputs[0] to have 1st dimension bs={batch_size}, got {encoder_outputs[0].shape[0]} "
|
||||||
expanded_idx = (
|
|
||||||
|
# expand batch_idx to assign correct encoder output for expanded input_ids (due to num_beams > 1 and num_return_sequences > 1)
|
||||||
|
expanded_batch_idxs = (
|
||||||
torch.arange(batch_size)
|
torch.arange(batch_size)
|
||||||
.view(-1, 1)
|
.view(-1, 1)
|
||||||
.repeat(1, num_beams * effective_batch_mult)
|
.repeat(1, num_beams * effective_batch_mult)
|
||||||
.view(-1)
|
.view(-1)
|
||||||
.to(input_ids.device)
|
.to(input_ids.device)
|
||||||
)
|
)
|
||||||
encoder_outputs = (encoder_outputs[0].index_select(batch_idx, expanded_idx), *encoder_outputs[1:])
|
# expand encoder_outputs
|
||||||
|
encoder_outputs = (encoder_outputs[0].index_select(0, expanded_batch_idxs), *encoder_outputs[1:])
|
||||||
|
|
||||||
else:
|
else:
|
||||||
encoder_outputs = None
|
encoder_outputs = None
|
||||||
|
|||||||
Reference in New Issue
Block a user