[Seq2Seq Generation] Call encoder before expanding input_ids (#3370)
This commit is contained in:
@@ -113,6 +113,7 @@ class PretrainedBartModel(PreTrainedModel):
|
|||||||
config_class = BartConfig
|
config_class = BartConfig
|
||||||
base_model_prefix = "model"
|
base_model_prefix = "model"
|
||||||
pretrained_model_archive_map = BART_PRETRAINED_MODEL_ARCHIVE_MAP
|
pretrained_model_archive_map = BART_PRETRAINED_MODEL_ARCHIVE_MAP
|
||||||
|
encoder_outputs_batch_dim_idx = 1 # outputs shaped (seq_len, bs, ...)
|
||||||
|
|
||||||
def _init_weights(self, module):
|
def _init_weights(self, module):
|
||||||
std = self.config.init_std
|
std = self.config.init_std
|
||||||
@@ -888,7 +889,6 @@ class BartForConditionalGeneration(PretrainedBartModel):
|
|||||||
encoder_outputs, decoder_cached_states = past, None
|
encoder_outputs, decoder_cached_states = past, None
|
||||||
else:
|
else:
|
||||||
encoder_outputs, decoder_cached_states = past
|
encoder_outputs, decoder_cached_states = past
|
||||||
|
|
||||||
return {
|
return {
|
||||||
"input_ids": None, # encoder_outputs is defined. input_ids not needed
|
"input_ids": None, # encoder_outputs is defined. input_ids not needed
|
||||||
"encoder_outputs": encoder_outputs,
|
"encoder_outputs": encoder_outputs,
|
||||||
|
|||||||
@@ -457,6 +457,7 @@ class T5PreTrainedModel(PreTrainedModel):
|
|||||||
pretrained_model_archive_map = T5_PRETRAINED_MODEL_ARCHIVE_MAP
|
pretrained_model_archive_map = T5_PRETRAINED_MODEL_ARCHIVE_MAP
|
||||||
load_tf_weights = load_tf_weights_in_t5
|
load_tf_weights = load_tf_weights_in_t5
|
||||||
base_model_prefix = "transformer"
|
base_model_prefix = "transformer"
|
||||||
|
encoder_outputs_batch_dim_idx = 0 # outputs shaped (bs, ...)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def dummy_inputs(self):
|
def dummy_inputs(self):
|
||||||
|
|||||||
@@ -895,6 +895,21 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
|
|||||||
effective_batch_size = batch_size
|
effective_batch_size = batch_size
|
||||||
effective_batch_mult = 1
|
effective_batch_mult = 1
|
||||||
|
|
||||||
|
if self.config.is_encoder_decoder:
|
||||||
|
if decoder_start_token_id is None:
|
||||||
|
decoder_start_token_id = bos_token_id
|
||||||
|
|
||||||
|
assert (
|
||||||
|
decoder_start_token_id is not None
|
||||||
|
), "decoder_start_token_id or bos_token_id has to be defined for encoder-decoder generation"
|
||||||
|
assert hasattr(self, "get_encoder"), "{} should have a 'get_encoder' function defined".format(self)
|
||||||
|
assert callable(self.get_encoder), "{} should be a method".format(self.get_encoder)
|
||||||
|
|
||||||
|
# get encoder and store encoder outputs
|
||||||
|
encoder = self.get_encoder()
|
||||||
|
|
||||||
|
encoder_outputs = encoder(input_ids, attention_mask=attention_mask)
|
||||||
|
|
||||||
# Expand input ids if num_beams > 1 or num_return_sequences > 1
|
# Expand input ids if num_beams > 1 or num_return_sequences > 1
|
||||||
if num_return_sequences > 1 or num_beams > 1:
|
if num_return_sequences > 1 or num_beams > 1:
|
||||||
input_ids_len = input_ids.shape[-1]
|
input_ids_len = input_ids.shape[-1]
|
||||||
@@ -911,20 +926,6 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
|
|||||||
) # shape: (batch_size * num_return_sequences * num_beams, cur_len)
|
) # shape: (batch_size * num_return_sequences * num_beams, cur_len)
|
||||||
|
|
||||||
if self.config.is_encoder_decoder:
|
if self.config.is_encoder_decoder:
|
||||||
if decoder_start_token_id is None:
|
|
||||||
decoder_start_token_id = bos_token_id
|
|
||||||
|
|
||||||
assert (
|
|
||||||
decoder_start_token_id is not None
|
|
||||||
), "decoder_start_token_id or bos_token_id has to be defined for encoder-decoder generation"
|
|
||||||
assert hasattr(self, "get_encoder"), "{} should have a 'get_encoder' function defined".format(self)
|
|
||||||
assert callable(self.get_encoder), "{} should be a method".format(self.get_encoder)
|
|
||||||
|
|
||||||
# get encoder and store encoder outputs
|
|
||||||
encoder = self.get_encoder()
|
|
||||||
|
|
||||||
encoder_outputs = encoder(input_ids, attention_mask=attention_mask)
|
|
||||||
|
|
||||||
# create empty decoder_input_ids
|
# create empty decoder_input_ids
|
||||||
input_ids = torch.full(
|
input_ids = torch.full(
|
||||||
(effective_batch_size * num_beams, 1),
|
(effective_batch_size * num_beams, 1),
|
||||||
@@ -933,6 +934,18 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
|
|||||||
device=next(self.parameters()).device,
|
device=next(self.parameters()).device,
|
||||||
)
|
)
|
||||||
cur_len = 1
|
cur_len = 1
|
||||||
|
batch_idx = self.encoder_outputs_batch_dim_idx
|
||||||
|
assert (
|
||||||
|
batch_size == encoder_outputs[0].shape[batch_idx]
|
||||||
|
), f"expected encoder_outputs[0] to have 1st dimension bs={batch_size}, got {encoder_outputs[0].shape[1]} "
|
||||||
|
expanded_idx = (
|
||||||
|
torch.arange(batch_size)
|
||||||
|
.view(-1, 1)
|
||||||
|
.repeat(1, num_beams * effective_batch_mult)
|
||||||
|
.view(-1)
|
||||||
|
.to(input_ids.device)
|
||||||
|
)
|
||||||
|
encoder_outputs = (encoder_outputs[0].index_select(batch_idx, expanded_idx), *encoder_outputs[1:])
|
||||||
else:
|
else:
|
||||||
encoder_outputs = None
|
encoder_outputs = None
|
||||||
cur_len = input_ids.shape[-1]
|
cur_len = input_ids.shape[-1]
|
||||||
|
|||||||
Reference in New Issue
Block a user