From 1cc7ca32955f618f9dfd081d787769fb898497c1 Mon Sep 17 00:00:00 2001 From: Joao Gante Date: Wed, 12 Feb 2025 11:37:19 +0000 Subject: [PATCH] Whisper: remove redundant assisted generation tests (#34814) * remove redundant test * delete another test * revert default max_length * (wrong place, moving) --- .../generation/candidate_generator.py | 7 +- src/transformers/generation/flax_utils.py | 2 +- src/transformers/generation/utils.py | 4 - tests/generation/test_utils.py | 145 ------------------ 4 files changed, 4 insertions(+), 154 deletions(-) diff --git a/src/transformers/generation/candidate_generator.py b/src/transformers/generation/candidate_generator.py index 9689ca2b52..2c4ab9c2a9 100644 --- a/src/transformers/generation/candidate_generator.py +++ b/src/transformers/generation/candidate_generator.py @@ -124,7 +124,7 @@ class AssistedCandidateGenerator(CandidateGenerator): # Prepare the kwargs for the assistant model assistant_kwargs = {} for key, value in model_kwargs.items(): # deepcopy crashes if we attempt to copy encoder outputs with grads - if key not in ("encoder_outputs", "assistant_encoder_outputs", "past_key_values"): + if key not in ("encoder_outputs", "past_key_values"): assistant_kwargs[key] = ( value.detach().to(device) if isinstance(value, torch.Tensor) else copy.deepcopy(value) ) @@ -133,9 +133,8 @@ class AssistedCandidateGenerator(CandidateGenerator): if "logits_to_keep" in assistant_kwargs.keys() and not assistant_model._supports_logits_to_keep(): del assistant_kwargs["logits_to_keep"] - if "assistant_encoder_outputs" in model_kwargs: - assistant_kwargs["encoder_outputs"] = model_kwargs["assistant_encoder_outputs"] - elif assistant_model.config.is_encoder_decoder: + # If the assistant is an encoder-decoder model, assume the encoder is different on the assistant. + if assistant_model.config.is_encoder_decoder: inputs_tensor, model_input_name, assistant_kwargs = assistant_model._prepare_model_inputs( inputs_tensor, assistant_model.generation_config.bos_token_id, assistant_kwargs ) diff --git a/src/transformers/generation/flax_utils.py b/src/transformers/generation/flax_utils.py index 134666d45a..0f6f2a0041 100644 --- a/src/transformers/generation/flax_utils.py +++ b/src/transformers/generation/flax_utils.py @@ -396,7 +396,7 @@ class FlaxGenerationMixin: "(https://huggingface.co/docs/transformers/main/en/main_classes/text_generation)" ) generation_config.max_length = generation_config.max_new_tokens + input_ids_seq_length - else: # by default let's always generate 10 new tokens + else: # by default let's always generate 20 new tokens if generation_config.max_length == GenerationConfig().max_length: generation_config.max_length = generation_config.max_length + input_ids_seq_length max_position_embeddings = getattr(self.config, "max_position_embeddings", None) diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py index a773c4a1d9..22081b3845 100644 --- a/src/transformers/generation/utils.py +++ b/src/transformers/generation/utils.py @@ -1385,10 +1385,6 @@ class GenerationMixin: decoder_model_args = set(inspect.signature(decoder.forward).parameters) model_args |= {f"decoder_{x}" for x in decoder_model_args} - # allow assistant_encoder_outputs to be passed if we're doing assisted generating - if "assistant_encoder_outputs" in model_kwargs: - model_args |= {"assistant_encoder_outputs"} - for key, value in model_kwargs.items(): if value is not None and key not in model_args: unused_model_args.append(key) diff --git a/tests/generation/test_utils.py b/tests/generation/test_utils.py index 6833fd476e..dc88091f4c 100644 --- a/tests/generation/test_utils.py +++ b/tests/generation/test_utils.py @@ -63,7 +63,6 @@ if is_torch_available(): AutoModelForVision2Seq, AutoProcessor, AutoTokenizer, - BartForCausalLM, BartForConditionalGeneration, BartTokenizer, GPT2LMHeadModel, @@ -3629,150 +3628,6 @@ class GenerationIntegrationTests(unittest.TestCase, GenerationIntegrationTestsMi ) self.assertListEqual(outputs_assisted.tolist(), outputs_tti.tolist()) - def test_model_kwarg_assisted_decoding_encoder_decoder(self): - """ - Tests that the following scenario is compatible with assisted generation: - 1. encoder-decoder main model - 2. encoder-decoder assistant model - 3. both have a custom input - (e.g. Whisper) - """ - - # PT-only test: TF doesn't support assisted decoding yet. - # Bart subclass with a kwarg that distorts the output - class FakeBart(BartForConditionalGeneration): - def forward(self, input_ids, past_key_values, foo=False, **kwargs): - outs = super().forward(input_ids, past_key_values=past_key_values, **kwargs) - if foo: - outs["logits"][:, :, :] = 0.0 - return outs - - def prepare_inputs_for_generation(self, *args, foo=False, encoder_outputs=None, **kwargs): - kwargs["encoder_outputs"] = encoder_outputs - inputs = super().prepare_inputs_for_generation(*args, **kwargs) - inputs["foo"] = foo - return inputs - - model = FakeBart.from_pretrained("hf-internal-testing/tiny-random-BartForConditionalGeneration").to( - torch_device - ) - tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-BartForConditionalGeneration") - - text = "Hello world" - tokenized_inputs = tokenizer([text], return_tensors="pt") - input_ids = tokenized_inputs.input_ids.to(torch_device) - - # Traditional way of generating text - outputs_normal = model.generate(input_ids) - self.assertEqual(outputs_normal.shape, (1, 20)) - - # Should be different with foo - outputs_foo = model.generate(input_ids, foo=True) - with self.assertRaises(AssertionError): - self.assertListEqual(outputs_foo.tolist(), outputs_normal.tolist()) - - # Assistant model - assistant = FakeBart.from_pretrained("hf-internal-testing/tiny-random-BartForConditionalGeneration").to( - torch_device - ) - - # If assisted generation passes model_kwargs correctly, should be same as previous - outputs_assisted = model.generate( - input_ids, - foo=True, - assistant_model=assistant, - ) - self.assertListEqual(outputs_assisted.tolist(), outputs_foo.tolist()) - - # Check that passing encoder_outputs directly also works as expected - encoder_outputs = assistant.get_encoder()(input_ids) - - outputs_assisted = model.generate( - foo=True, - assistant_model=assistant, - encoder_outputs=encoder_outputs, - assistant_encoder_outputs=encoder_outputs, - ) - self.assertListEqual(outputs_assisted.tolist(), outputs_foo.tolist()) - - def test_assisted_decoding_encoder_decoder_shared_encoder(self): - """ - Tests that the following scenario is compatible with assisted generation: - 1. encoder-decoder main model - 2. decoder-only assistant model - 3. both have a custom input - (e.g. DistilWhisper) - """ - - # PT-only test: TF doesn't support assisted decoding yet. - # Bart subclass with a kwarg called foo that distorts the output - class FakeBartSeq2Seq(BartForConditionalGeneration): - def forward(self, input_ids, foo=False, **kwargs): - outs = super().forward(input_ids, **kwargs) - if foo: - outs["logits"][:, :, :] = 0.0 - return outs - - def prepare_inputs_for_generation(self, *args, foo=False, encoder_outputs=None, **kwargs): - kwargs["encoder_outputs"] = encoder_outputs - inputs = super().prepare_inputs_for_generation(*args, **kwargs) - inputs["foo"] = foo - return inputs - - class FakeBartCausalLM(BartForCausalLM): - def forward(self, input_ids, attention_mask, past_key_values, foo=False, **kwargs): - outs = super().forward(input_ids, attention_mask, past_key_values=past_key_values, **kwargs) - if foo: - outs["logits"][:, :, :] = 0.0 - return outs - - def prepare_inputs_for_generation(self, *args, foo=False, encoder_outputs=None, **kwargs): - kwargs["encoder_outputs"] = encoder_outputs - inputs = super().prepare_inputs_for_generation(*args, **kwargs) - inputs["foo"] = foo - return inputs - - model = FakeBartSeq2Seq.from_pretrained("hf-internal-testing/tiny-random-BartForConditionalGeneration").to( - torch_device - ) - tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-BartForConditionalGeneration") - - text = "Hello world" - tokenized_inputs = tokenizer([text], return_tensors="pt") - input_ids = tokenized_inputs.input_ids.to(torch_device) - - # Traditional way of generating text - outputs_normal = model.generate(input_ids) - self.assertEqual(outputs_normal.shape, (1, 20)) - - # Should be different with foo - outputs_foo = model.generate(input_ids, foo=True) - with self.assertRaises(AssertionError): - self.assertListEqual(outputs_foo.tolist(), outputs_normal.tolist()) - - # Assistant model - assistant = FakeBartCausalLM.from_pretrained( - "hf-internal-testing/tiny-random-BartForConditionalGeneration" - ).to(torch_device) - - # If assisted generation passes model_kwargs correctly, should be same as previous - outputs_assisted = model.generate( - input_ids, - foo=True, - assistant_model=assistant, - ) - self.assertListEqual(outputs_assisted.tolist(), outputs_foo.tolist()) - - # Check that passing encoder_outputs directly also works as expected - encoder_outputs = model.get_encoder()(input_ids) - - outputs_assisted = model.generate( - foo=True, - assistant_model=assistant, - encoder_outputs=encoder_outputs, - ) - self.assertListEqual(outputs_assisted.tolist(), outputs_foo.tolist()) - def test_assisted_decoding_num_assistant_tokens_heuristic_schedule(self): # This test ensures that the assisted generation num_assistant_tokens 'heuristic' schedule works properly.