From 9e5c28c573978f33ecb2eeeb1670d7279d7ab484 Mon Sep 17 00:00:00 2001 From: Joao Gante Date: Thu, 14 Dec 2023 13:31:13 +0000 Subject: [PATCH] Generate: assisted decoding now uses `generate` for the assistant (#28030) generate refactor --- .../generation/candidate_generator.py | 74 +++----- src/transformers/generation/utils.py | 2 +- tests/generation/test_utils.py | 160 ++++++++++-------- 3 files changed, 116 insertions(+), 120 deletions(-) diff --git a/src/transformers/generation/candidate_generator.py b/src/transformers/generation/candidate_generator.py index 7cceac3364..ccfd4cfad7 100644 --- a/src/transformers/generation/candidate_generator.py +++ b/src/transformers/generation/candidate_generator.py @@ -15,7 +15,7 @@ import copy import warnings -from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union import torch @@ -28,7 +28,7 @@ if TYPE_CHECKING: class CandidateGenerator: """Abstract base class for all candidate generators that can be applied during assisted generation.""" - def get_candidates(self, input_ids: torch.LongTensor) -> torch.LongTensor: + def get_candidates(self, input_ids: torch.LongTensor) -> Tuple[torch.LongTensor, Optional[torch.FloatTensor]]: """ Fetches the candidates to be tried for the current input. @@ -37,8 +37,9 @@ class CandidateGenerator: Indices of input sequence tokens in the vocabulary. [What are input IDs?](../glossary#input-ids) Return: - `torch.LongTensor` of shape `(num_candidates, candidate_length)`: The candidate sequences to be assessed by - the model. + `torch.LongTensor` of shape `(batch_size, candidate_length)` containing the candidate sequences to be + assessed by the model and, optionally, a `torch.FloatTensor` of shape `(batch_size, candidate_length, + vocabulary_size)` containing the logits associated to each candidate. """ raise NotImplementedError( f"{self.__class__} is an abstract class. Only classes inheriting this class can call `get_candidates`." @@ -152,7 +153,7 @@ class AssistedCandidateGenerator(CandidateGenerator): ) self.logits_processor = logits_processor - def get_candidates(self, input_ids: torch.LongTensor) -> torch.LongTensor: + def get_candidates(self, input_ids: torch.LongTensor) -> Tuple[torch.LongTensor, Optional[torch.FloatTensor]]: """ Fetches the candidates to be tried for the current input. @@ -161,7 +162,9 @@ class AssistedCandidateGenerator(CandidateGenerator): Indices of input sequence tokens in the vocabulary. [What are input IDs?](../glossary#input-ids) Return: - `torch.LongTensor` of shape `(num_candidates, candidate_length)`: The candidate sequences to be tried. + `torch.LongTensor` of shape `(batch_size, candidate_length)` containing the candidate sequences to be + assessed by the model and a `torch.FloatTensor` of shape `(batch_size, candidate_length, + vocabulary_size)` containing the logits associated to each candidate. """ # 1. If it is not the first round of candidate generation, prepare the inputs based on the input_ids length # (which implicitly contains the number of accepted candidates from the previous round) @@ -179,51 +182,24 @@ class AssistedCandidateGenerator(CandidateGenerator): ) self.assistant_kwargs = _prepare_token_type_ids(self.assistant_kwargs, new_cur_len) - # 2. Forecast next N tokens using the assistant model. This `for` block can be replaced with a `.generate()` - # call if we decide to add `past_key_values` as a possible output of generate, as we need access to the - # assistant cache to secure strong speedups. - candidate_input_ids = input_ids - for _ in range(int(self.num_assistant_tokens)): - # 2.1 prepare assistant model inputs - assistant_inputs = self.assistant_model.prepare_inputs_for_generation( - candidate_input_ids, - **self.assistant_kwargs, - ) + # 2. Forecast next N tokens using the assistant model. + assistant_generation_kwargs = { + self.input_ids_key: input_ids, + "do_sample": False, + "num_beams": 1, + "max_new_tokens": int(self.num_assistant_tokens), + "return_dict_in_generate": True, + "output_scores": True, + } + assistant_output = self.assistant_model.generate(**assistant_generation_kwargs, **self.assistant_kwargs) - # 2.2. check if the input ids length is correct - has_past_key_values = assistant_inputs.get("past_key_values", None) is not None - if has_past_key_values and assistant_inputs[self.input_ids_key].shape[-1] not in (1, 2): - raise ValueError("The length of the input ids in assistant inputs should be 1 or 2") + # 3. Update variables for the next round of candidate generation + self.assistant_kwargs["past_key_values"] = assistant_output.past_key_values - # 2.3. use the assistant model to obtain the next candidate logits - assistant_model_outputs = self.assistant_model(**assistant_inputs) - - # 2.4. greedily select the next candidate token - if len(self.logits_processor) > 0: - assistant_model_outputs.logits[:, -1, :] = self.logits_processor( - candidate_input_ids, assistant_model_outputs.logits[:, -1, :] - ) - new_token = assistant_model_outputs.logits[:, -1, :].argmax(dim=-1) - candidate_input_ids = torch.cat((candidate_input_ids, new_token[:, None]), dim=-1) - - # 2.5. update assistant model inputs - if self.assistant_kwargs.get(self.attention_key, None) is not None: - mask = self.assistant_kwargs[self.attention_key] - self.assistant_kwargs[self.attention_key] = torch.cat( - [mask, mask.new_ones((mask.shape[0], 1))], dim=-1 - ) - self.assistant_kwargs["past_key_values"] = assistant_model_outputs.past_key_values - - # 2.6. stop assistant generation on EOS - if self.eos_token_id_tensor is not None: - last_assistant_token_is_eos = new_token.tile(self.eos_token_id_tensor.shape[0], 1) - last_assistant_token_is_eos = ( - ~last_assistant_token_is_eos.ne(self.eos_token_id_tensor.unsqueeze(1)).prod(dim=0).bool() - ) - if last_assistant_token_is_eos: - break - - return candidate_input_ids + # 4. Prepare variables for output + candidate_logits = torch.stack(assistant_output.scores, dim=1) + candidate_ids = assistant_output.sequences + return candidate_ids, candidate_logits def update_candidate_strategy(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, num_matches: int): """ diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py index d7510951b1..d23f7f9245 100644 --- a/src/transformers/generation/utils.py +++ b/src/transformers/generation/utils.py @@ -4585,7 +4585,7 @@ class GenerationMixin: cur_len = input_ids.shape[-1] # 1. Fetch candidate sequences from a `CandidateGenerator` - candidate_input_ids = candidate_generator.get_candidates(input_ids) + candidate_input_ids, candidate_logits = candidate_generator.get_candidates(input_ids) candidate_length = candidate_input_ids.shape[1] - input_ids.shape[1] last_assistant_token_is_eos = ( ~candidate_input_ids[:, -1] diff --git a/tests/generation/test_utils.py b/tests/generation/test_utils.py index 6e11818f69..973f54f003 100644 --- a/tests/generation/test_utils.py +++ b/tests/generation/test_utils.py @@ -3128,85 +3128,26 @@ class GenerationIntegrationTests(unittest.TestCase, GenerationIntegrationTestsMi self.assertListEqual(outputs_assisted.tolist(), outputs_tti.tolist()) def test_model_kwarg_assisted_decoding_encoder_decoder(self): + """ + Tests that the following scenario is compatible with assisted generation: + 1. encoder-decoder main model + 2. encoder-decoder assistant model + 3. both have a custom input + (e.g. Whisper) + """ + # PT-only test: TF doesn't support assisted decoding yet. # Bart subclass with a kwarg that distorts the output class FakeBart(BartForConditionalGeneration): - def forward(self, input_ids, foo=False, **kwargs): - outs = super().forward(input_ids, **kwargs) - + def forward(self, input_ids, past_key_values, foo=False, **kwargs): + outs = super().forward(input_ids, past_key_values=past_key_values, **kwargs) if foo: outs["logits"][:, :, :] = 0.0 - return outs def prepare_inputs_for_generation(self, *args, foo=False, encoder_outputs=None, **kwargs): kwargs["encoder_outputs"] = encoder_outputs inputs = super().prepare_inputs_for_generation(*args, **kwargs) - - inputs["foo"] = foo - return inputs - - model = FakeBart.from_pretrained("hf-internal-testing/tiny-random-BartForConditionalGeneration").to( - torch_device - ) - tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-BartForConditionalGeneration") - - text = "Hello world" - tokenized_inputs = tokenizer([text], return_tensors="pt") - input_ids = tokenized_inputs.input_ids.to(torch_device) - - # Traditional way of generating text - outputs_normal = model.generate(input_ids) - self.assertEqual(outputs_normal.shape, (1, 20)) - - # Should be different with foo - outputs_foo = model.generate( - input_ids, - foo=True, - ) - with self.assertRaises(AssertionError): - self.assertListEqual(outputs_foo.tolist(), outputs_normal.tolist()) - - # Assistant model - assistant = AutoModelForSeq2SeqLM.from_pretrained( - "hf-internal-testing/tiny-random-BartForConditionalGeneration" - ).to(torch_device) - - # If assisted generation passes model_kwargs correctly, should be same as previous - outputs_assisted = model.generate( - input_ids, - foo=True, - assistant_model=assistant, - ) - self.assertListEqual(outputs_assisted.tolist(), outputs_foo.tolist()) - - # Check that passing encoder_outputs directly also works as expected - encoder_outputs = assistant.get_encoder()(input_ids) - - outputs_assisted = model.generate( - foo=True, - assistant_model=assistant, - encoder_outputs=encoder_outputs, - assistant_encoder_outputs=encoder_outputs, - ) - self.assertListEqual(outputs_assisted.tolist(), outputs_foo.tolist()) - - def test_assisted_decoding_encoder_decoder_shared_encoder(self): - # PT-only test: TF doesn't support assisted decoding yet. - # Bart subclass with a kwarg called foo that distorts the output - class FakeBart(BartForConditionalGeneration): - def forward(self, input_ids, foo=False, **kwargs): - outs = super().forward(input_ids, **kwargs) - - if foo: - outs["logits"][:, :, :] = 0.0 - - return outs - - def prepare_inputs_for_generation(self, *args, foo=False, encoder_outputs=None, **kwargs): - kwargs["encoder_outputs"] = encoder_outputs - inputs = super().prepare_inputs_for_generation(*args, **kwargs) - inputs["foo"] = foo return inputs @@ -3229,7 +3170,7 @@ class GenerationIntegrationTests(unittest.TestCase, GenerationIntegrationTestsMi self.assertListEqual(outputs_foo.tolist(), outputs_normal.tolist()) # Assistant model - assistant = BartForCausalLM.from_pretrained("hf-internal-testing/tiny-random-BartForConditionalGeneration").to( + assistant = FakeBart.from_pretrained("hf-internal-testing/tiny-random-BartForConditionalGeneration").to( torch_device ) @@ -3241,6 +3182,85 @@ class GenerationIntegrationTests(unittest.TestCase, GenerationIntegrationTestsMi ) self.assertListEqual(outputs_assisted.tolist(), outputs_foo.tolist()) + # Check that passing encoder_outputs directly also works as expected + encoder_outputs = assistant.get_encoder()(input_ids) + + outputs_assisted = model.generate( + foo=True, + assistant_model=assistant, + encoder_outputs=encoder_outputs, + assistant_encoder_outputs=encoder_outputs, + ) + self.assertListEqual(outputs_assisted.tolist(), outputs_foo.tolist()) + + def test_assisted_decoding_encoder_decoder_shared_encoder(self): + """ + Tests that the following scenario is compatible with assisted generation: + 1. encoder-decoder main model + 2. decoder-only assistant model + 3. both have a custom input + (e.g. DistilWhisper) + """ + + # PT-only test: TF doesn't support assisted decoding yet. + # Bart subclass with a kwarg called foo that distorts the output + class FakeBartSeq2Seq(BartForConditionalGeneration): + def forward(self, input_ids, foo=False, **kwargs): + outs = super().forward(input_ids, **kwargs) + if foo: + outs["logits"][:, :, :] = 0.0 + return outs + + def prepare_inputs_for_generation(self, *args, foo=False, encoder_outputs=None, **kwargs): + kwargs["encoder_outputs"] = encoder_outputs + inputs = super().prepare_inputs_for_generation(*args, **kwargs) + inputs["foo"] = foo + return inputs + + class FakeBartCausalLM(BartForCausalLM): + def forward(self, input_ids, attention_mask, past_key_values, foo=False, **kwargs): + outs = super().forward(input_ids, attention_mask, past_key_values=past_key_values, **kwargs) + if foo: + outs["logits"][:, :, :] = 0.0 + return outs + + def prepare_inputs_for_generation(self, *args, foo=False, encoder_outputs=None, **kwargs): + kwargs["encoder_outputs"] = encoder_outputs + inputs = super().prepare_inputs_for_generation(*args, **kwargs) + inputs["foo"] = foo + return inputs + + model = FakeBartSeq2Seq.from_pretrained("hf-internal-testing/tiny-random-BartForConditionalGeneration").to( + torch_device + ) + tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-BartForConditionalGeneration") + + text = "Hello world" + tokenized_inputs = tokenizer([text], return_tensors="pt") + input_ids = tokenized_inputs.input_ids.to(torch_device) + + # Traditional way of generating text + outputs_normal = model.generate(input_ids) + self.assertEqual(outputs_normal.shape, (1, 20)) + + # Should be different with foo + outputs_foo = model.generate(input_ids, foo=True) + with self.assertRaises(AssertionError): + self.assertListEqual(outputs_foo.tolist(), outputs_normal.tolist()) + + # Assistant model + assistant = FakeBartCausalLM.from_pretrained( + "hf-internal-testing/tiny-random-BartForConditionalGeneration" + ).to(torch_device) + + # If assisted generation passes model_kwargs correctly, should be same as previous + outputs_assisted = model.generate( + input_ids, + foo=True, + assistant_model=assistant, + ) + self.assertListEqual(outputs_assisted.tolist(), outputs_foo.tolist()) + # Check that passing encoder_outputs directly also works as expected encoder_outputs = model.get_encoder()(input_ids)