From 776e82d2bedcc4dea509c6904ee4f48370d08838 Mon Sep 17 00:00:00 2001 From: ayushtiku5 <40797286+ayushtiku5@users.noreply.github.com> Date: Mon, 19 Oct 2020 12:26:08 +0530 Subject: [PATCH] Add support to provide initial tokens to decoder of encoder-decoder type models (#7577) * Add support to provide initial tokens for decoding * Add docstring * improve code quality * code reformat * code reformat * minor change * remove appending decoder start token Co-authored-by: Ayush Jain --- src/transformers/generation_utils.py | 25 +++++++++++++++++-------- 1 file changed, 17 insertions(+), 8 deletions(-) diff --git a/src/transformers/generation_utils.py b/src/transformers/generation_utils.py index 827d8dc421..627beeca60 100644 --- a/src/transformers/generation_utils.py +++ b/src/transformers/generation_utils.py @@ -111,6 +111,7 @@ class GenerationMixin: def generate( self, input_ids: Optional[torch.LongTensor] = None, + decoder_input_ids: Optional[torch.LongTensor] = None, max_length: Optional[int] = None, min_length: Optional[int] = None, do_sample: Optional[bool] = None, @@ -151,6 +152,9 @@ class GenerationMixin: input_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`): The sequence used as a prompt for the generation. If :obj:`None` the method initializes it as an empty :obj:`torch.LongTensor` of shape :obj:`(1,)`. + decoder_input_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`): + Initial input_ids for the decoder of encoder-decoder type models. If :obj:`None`, only + :obj:`decoder_start_token_id` is passed as the first token to the decoder. max_length (:obj:`int`, `optional`, defaults to 20): The maximum length of the sequence to be generated. 
min_length (:obj:`int`, `optional`, defaults to 10): @@ -417,14 +421,19 @@ class GenerationMixin: ) # shape: (batch_size * num_return_sequences * num_beams, cur_len) if self.config.is_encoder_decoder: - # create empty decoder input_ids - input_ids = torch.full( - (effective_batch_size * num_beams, 1), - decoder_start_token_id, - dtype=torch.long, - device=next(self.parameters()).device, - ) - cur_len = 1 + device = next(self.parameters()).device + if decoder_input_ids is not None: + # give initial decoder input ids + input_ids = decoder_input_ids.repeat(effective_batch_size * num_beams, 1).to(device) + else: + # create empty decoder input_ids + input_ids = torch.full( + (effective_batch_size * num_beams, 1), + decoder_start_token_id, + dtype=torch.long, + device=device, + ) + cur_len = input_ids.shape[-1] assert ( batch_size == encoder_outputs.last_hidden_state.shape[0]