Add support to provide initial tokens to decoder of encoder-decoder type models (#7577)
* Add support to provide initial tokens for decoding
* Add docstring
* improve code quality
* code reformat
* code reformat
* minor change
* remove appending decoder start token

Co-authored-by: Ayush Jain <a.jain@sprinklr.com>
This commit is contained in:
@@ -111,6 +111,7 @@ class GenerationMixin:
|
||||
def generate(
|
||||
self,
|
||||
input_ids: Optional[torch.LongTensor] = None,
|
||||
decoder_input_ids: Optional[torch.LongTensor] = None,
|
||||
max_length: Optional[int] = None,
|
||||
min_length: Optional[int] = None,
|
||||
do_sample: Optional[bool] = None,
|
||||
@@ -151,6 +152,9 @@ class GenerationMixin:
|
||||
input_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
|
||||
The sequence used as a prompt for the generation. If :obj:`None` the method initializes
|
||||
it as an empty :obj:`torch.LongTensor` of shape :obj:`(1,)`.
|
||||
decoder_input_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
|
||||
Initial input_ids for the decoder of encoder-decoder type models. If :obj:`None`, then only
|
||||
:obj:`decoder_start_token_id` is passed as the first token to the decoder.
|
||||
max_length (:obj:`int`, `optional`, defaults to 20):
|
||||
The maximum length of the sequence to be generated.
|
||||
min_length (:obj:`int`, `optional`, defaults to 10):
|
||||
@@ -417,14 +421,19 @@ class GenerationMixin:
|
||||
) # shape: (batch_size * num_return_sequences * num_beams, cur_len)
|
||||
|
||||
if self.config.is_encoder_decoder:
|
||||
device = next(self.parameters()).device
|
||||
if decoder_input_ids is not None:
|
||||
# give initial decoder input ids
|
||||
input_ids = decoder_input_ids.repeat(effective_batch_size * num_beams, 1).to(device)
|
||||
else:
|
||||
# create empty decoder input_ids
|
||||
input_ids = torch.full(
|
||||
(effective_batch_size * num_beams, 1),
|
||||
decoder_start_token_id,
|
||||
dtype=torch.long,
|
||||
device=next(self.parameters()).device,
|
||||
device=device,
|
||||
)
|
||||
cur_len = 1
|
||||
cur_len = input_ids.shape[-1]
|
||||
|
||||
assert (
|
||||
batch_size == encoder_outputs.last_hidden_state.shape[0]
|
||||
|
||||
Reference in New Issue
Block a user