[Seq2Seq Generation] Call encoder before expanding input_ids (#3370)

This commit is contained in:
Sam Shleifer
2020-03-26 18:41:19 -04:00
committed by GitHub
parent 39371ee454
commit 1a5aefc95c
3 changed files with 29 additions and 15 deletions

View File

@@ -895,6 +895,21 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
effective_batch_size = batch_size
effective_batch_mult = 1
if self.config.is_encoder_decoder:
if decoder_start_token_id is None:
decoder_start_token_id = bos_token_id
assert (
decoder_start_token_id is not None
), "decoder_start_token_id or bos_token_id has to be defined for encoder-decoder generation"
assert hasattr(self, "get_encoder"), "{} should have a 'get_encoder' function defined".format(self)
assert callable(self.get_encoder), "{} should be a method".format(self.get_encoder)
# get encoder and store encoder outputs
encoder = self.get_encoder()
encoder_outputs = encoder(input_ids, attention_mask=attention_mask)
# Expand input ids if num_beams > 1 or num_return_sequences > 1
if num_return_sequences > 1 or num_beams > 1:
input_ids_len = input_ids.shape[-1]
@@ -911,20 +926,6 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
) # shape: (batch_size * num_return_sequences * num_beams, cur_len)
if self.config.is_encoder_decoder:
if decoder_start_token_id is None:
decoder_start_token_id = bos_token_id
assert (
decoder_start_token_id is not None
), "decoder_start_token_id or bos_token_id has to be defined for encoder-decoder generation"
assert hasattr(self, "get_encoder"), "{} should have a 'get_encoder' function defined".format(self)
assert callable(self.get_encoder), "{} should be a method".format(self.get_encoder)
# get encoder and store encoder outputs
encoder = self.get_encoder()
encoder_outputs = encoder(input_ids, attention_mask=attention_mask)
# create empty decoder_input_ids
input_ids = torch.full(
(effective_batch_size * num_beams, 1),
@@ -933,6 +934,18 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
device=next(self.parameters()).device,
)
cur_len = 1
batch_idx = self.encoder_outputs_batch_dim_idx
assert (
batch_size == encoder_outputs[0].shape[batch_idx]
), f"expected encoder_outputs[0] to have 1st dimension bs={batch_size}, got {encoder_outputs[0].shape[1]} "
expanded_idx = (
torch.arange(batch_size)
.view(-1, 1)
.repeat(1, num_beams * effective_batch_mult)
.view(-1)
.to(input_ids.device)
)
encoder_outputs = (encoder_outputs[0].index_select(batch_idx, expanded_idx), *encoder_outputs[1:])
else:
encoder_outputs = None
cur_len = input_ids.shape[-1]