Whisper: remove redundant assisted generation tests (#34814)

* remove redundant test

* delete another test

* revert default max_length

* (wrong place, moving)
This commit is contained in:
Joao Gante
2025-02-12 11:37:19 +00:00
committed by GitHub
parent 0cd5e2dfd0
commit 1cc7ca3295
4 changed files with 4 additions and 154 deletions

View File

@@ -124,7 +124,7 @@ class AssistedCandidateGenerator(CandidateGenerator):
# Prepare the kwargs for the assistant model
assistant_kwargs = {}
for key, value in model_kwargs.items(): # deepcopy crashes if we attempt to copy encoder outputs with grads
if key not in ("encoder_outputs", "assistant_encoder_outputs", "past_key_values"):
if key not in ("encoder_outputs", "past_key_values"):
assistant_kwargs[key] = (
value.detach().to(device) if isinstance(value, torch.Tensor) else copy.deepcopy(value)
)
@@ -133,9 +133,8 @@ class AssistedCandidateGenerator(CandidateGenerator):
if "logits_to_keep" in assistant_kwargs.keys() and not assistant_model._supports_logits_to_keep():
del assistant_kwargs["logits_to_keep"]
if "assistant_encoder_outputs" in model_kwargs:
assistant_kwargs["encoder_outputs"] = model_kwargs["assistant_encoder_outputs"]
elif assistant_model.config.is_encoder_decoder:
# If the assistant is an encoder-decoder model, assume the encoder is different on the assistant.
if assistant_model.config.is_encoder_decoder:
inputs_tensor, model_input_name, assistant_kwargs = assistant_model._prepare_model_inputs(
inputs_tensor, assistant_model.generation_config.bos_token_id, assistant_kwargs
)

View File

@@ -396,7 +396,7 @@ class FlaxGenerationMixin:
"(https://huggingface.co/docs/transformers/main/en/main_classes/text_generation)"
)
generation_config.max_length = generation_config.max_new_tokens + input_ids_seq_length
else: # by default let's always generate 10 new tokens
else: # by default let's always generate 20 new tokens
if generation_config.max_length == GenerationConfig().max_length:
generation_config.max_length = generation_config.max_length + input_ids_seq_length
max_position_embeddings = getattr(self.config, "max_position_embeddings", None)

View File

@@ -1385,10 +1385,6 @@ class GenerationMixin:
decoder_model_args = set(inspect.signature(decoder.forward).parameters)
model_args |= {f"decoder_{x}" for x in decoder_model_args}
# allow assistant_encoder_outputs to be passed if we're doing assisted generating
if "assistant_encoder_outputs" in model_kwargs:
model_args |= {"assistant_encoder_outputs"}
for key, value in model_kwargs.items():
if value is not None and key not in model_args:
unused_model_args.append(key)

View File

@@ -63,7 +63,6 @@ if is_torch_available():
AutoModelForVision2Seq,
AutoProcessor,
AutoTokenizer,
BartForCausalLM,
BartForConditionalGeneration,
BartTokenizer,
GPT2LMHeadModel,
@@ -3629,150 +3628,6 @@ class GenerationIntegrationTests(unittest.TestCase, GenerationIntegrationTestsMi
)
self.assertListEqual(outputs_assisted.tolist(), outputs_tti.tolist())
def test_model_kwarg_assisted_decoding_encoder_decoder(self):
"""
Tests that the following scenario is compatible with assisted generation:
1. encoder-decoder main model
2. encoder-decoder assistant model
3. both have a custom input
(e.g. Whisper)
"""
# PT-only test: TF doesn't support assisted decoding yet.
# Bart subclass with a kwarg that distorts the output
class FakeBart(BartForConditionalGeneration):
def forward(self, input_ids, past_key_values, foo=False, **kwargs):
outs = super().forward(input_ids, past_key_values=past_key_values, **kwargs)
if foo:
outs["logits"][:, :, :] = 0.0
return outs
def prepare_inputs_for_generation(self, *args, foo=False, encoder_outputs=None, **kwargs):
kwargs["encoder_outputs"] = encoder_outputs
inputs = super().prepare_inputs_for_generation(*args, **kwargs)
inputs["foo"] = foo
return inputs
model = FakeBart.from_pretrained("hf-internal-testing/tiny-random-BartForConditionalGeneration").to(
torch_device
)
tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-BartForConditionalGeneration")
text = "Hello world"
tokenized_inputs = tokenizer([text], return_tensors="pt")
input_ids = tokenized_inputs.input_ids.to(torch_device)
# Traditional way of generating text
outputs_normal = model.generate(input_ids)
self.assertEqual(outputs_normal.shape, (1, 20))
# Should be different with foo
outputs_foo = model.generate(input_ids, foo=True)
with self.assertRaises(AssertionError):
self.assertListEqual(outputs_foo.tolist(), outputs_normal.tolist())
# Assistant model
assistant = FakeBart.from_pretrained("hf-internal-testing/tiny-random-BartForConditionalGeneration").to(
torch_device
)
# If assisted generation passes model_kwargs correctly, should be same as previous
outputs_assisted = model.generate(
input_ids,
foo=True,
assistant_model=assistant,
)
self.assertListEqual(outputs_assisted.tolist(), outputs_foo.tolist())
# Check that passing encoder_outputs directly also works as expected
encoder_outputs = assistant.get_encoder()(input_ids)
outputs_assisted = model.generate(
foo=True,
assistant_model=assistant,
encoder_outputs=encoder_outputs,
assistant_encoder_outputs=encoder_outputs,
)
self.assertListEqual(outputs_assisted.tolist(), outputs_foo.tolist())
def test_assisted_decoding_encoder_decoder_shared_encoder(self):
"""
Tests that the following scenario is compatible with assisted generation:
1. encoder-decoder main model
2. decoder-only assistant model
3. both have a custom input
(e.g. DistilWhisper)
"""
# PT-only test: TF doesn't support assisted decoding yet.
# Bart subclass with a kwarg called foo that distorts the output
class FakeBartSeq2Seq(BartForConditionalGeneration):
def forward(self, input_ids, foo=False, **kwargs):
outs = super().forward(input_ids, **kwargs)
if foo:
outs["logits"][:, :, :] = 0.0
return outs
def prepare_inputs_for_generation(self, *args, foo=False, encoder_outputs=None, **kwargs):
kwargs["encoder_outputs"] = encoder_outputs
inputs = super().prepare_inputs_for_generation(*args, **kwargs)
inputs["foo"] = foo
return inputs
class FakeBartCausalLM(BartForCausalLM):
def forward(self, input_ids, attention_mask, past_key_values, foo=False, **kwargs):
outs = super().forward(input_ids, attention_mask, past_key_values=past_key_values, **kwargs)
if foo:
outs["logits"][:, :, :] = 0.0
return outs
def prepare_inputs_for_generation(self, *args, foo=False, encoder_outputs=None, **kwargs):
kwargs["encoder_outputs"] = encoder_outputs
inputs = super().prepare_inputs_for_generation(*args, **kwargs)
inputs["foo"] = foo
return inputs
model = FakeBartSeq2Seq.from_pretrained("hf-internal-testing/tiny-random-BartForConditionalGeneration").to(
torch_device
)
tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-BartForConditionalGeneration")
text = "Hello world"
tokenized_inputs = tokenizer([text], return_tensors="pt")
input_ids = tokenized_inputs.input_ids.to(torch_device)
# Traditional way of generating text
outputs_normal = model.generate(input_ids)
self.assertEqual(outputs_normal.shape, (1, 20))
# Should be different with foo
outputs_foo = model.generate(input_ids, foo=True)
with self.assertRaises(AssertionError):
self.assertListEqual(outputs_foo.tolist(), outputs_normal.tolist())
# Assistant model
assistant = FakeBartCausalLM.from_pretrained(
"hf-internal-testing/tiny-random-BartForConditionalGeneration"
).to(torch_device)
# If assisted generation passes model_kwargs correctly, should be same as previous
outputs_assisted = model.generate(
input_ids,
foo=True,
assistant_model=assistant,
)
self.assertListEqual(outputs_assisted.tolist(), outputs_foo.tolist())
# Check that passing encoder_outputs directly also works as expected
encoder_outputs = model.get_encoder()(input_ids)
outputs_assisted = model.generate(
foo=True,
assistant_model=assistant,
encoder_outputs=encoder_outputs,
)
self.assertListEqual(outputs_assisted.tolist(), outputs_foo.tolist())
def test_assisted_decoding_num_assistant_tokens_heuristic_schedule(self):
# This test ensures that the assisted generation num_assistant_tokens 'heuristic' schedule works properly.