Whisper: remove redundant assisted generation tests (#34814)
* remove redundant test
* delete another test
* revert default max_length
* (wrong place, moving)
This commit is contained in:
@@ -124,7 +124,7 @@ class AssistedCandidateGenerator(CandidateGenerator):
|
||||
# Prepare the kwargs for the assistant model
|
||||
assistant_kwargs = {}
|
||||
for key, value in model_kwargs.items(): # deepcopy crashes if we attempt to copy encoder outputs with grads
|
||||
if key not in ("encoder_outputs", "assistant_encoder_outputs", "past_key_values"):
|
||||
if key not in ("encoder_outputs", "past_key_values"):
|
||||
assistant_kwargs[key] = (
|
||||
value.detach().to(device) if isinstance(value, torch.Tensor) else copy.deepcopy(value)
|
||||
)
|
||||
@@ -133,9 +133,8 @@ class AssistedCandidateGenerator(CandidateGenerator):
|
||||
if "logits_to_keep" in assistant_kwargs.keys() and not assistant_model._supports_logits_to_keep():
|
||||
del assistant_kwargs["logits_to_keep"]
|
||||
|
||||
if "assistant_encoder_outputs" in model_kwargs:
|
||||
assistant_kwargs["encoder_outputs"] = model_kwargs["assistant_encoder_outputs"]
|
||||
elif assistant_model.config.is_encoder_decoder:
|
||||
# If the assistant is an encoder-decoder model, assume the encoder is different on the assistant.
|
||||
if assistant_model.config.is_encoder_decoder:
|
||||
inputs_tensor, model_input_name, assistant_kwargs = assistant_model._prepare_model_inputs(
|
||||
inputs_tensor, assistant_model.generation_config.bos_token_id, assistant_kwargs
|
||||
)
|
||||
|
||||
@@ -396,7 +396,7 @@ class FlaxGenerationMixin:
|
||||
"(https://huggingface.co/docs/transformers/main/en/main_classes/text_generation)"
|
||||
)
|
||||
generation_config.max_length = generation_config.max_new_tokens + input_ids_seq_length
|
||||
else: # by default let's always generate 10 new tokens
|
||||
else: # by default let's always generate 20 new tokens
|
||||
if generation_config.max_length == GenerationConfig().max_length:
|
||||
generation_config.max_length = generation_config.max_length + input_ids_seq_length
|
||||
max_position_embeddings = getattr(self.config, "max_position_embeddings", None)
|
||||
|
||||
@@ -1385,10 +1385,6 @@ class GenerationMixin:
|
||||
decoder_model_args = set(inspect.signature(decoder.forward).parameters)
|
||||
model_args |= {f"decoder_{x}" for x in decoder_model_args}
|
||||
|
||||
# allow assistant_encoder_outputs to be passed if we're doing assisted generating
|
||||
if "assistant_encoder_outputs" in model_kwargs:
|
||||
model_args |= {"assistant_encoder_outputs"}
|
||||
|
||||
for key, value in model_kwargs.items():
|
||||
if value is not None and key not in model_args:
|
||||
unused_model_args.append(key)
|
||||
|
||||
@@ -63,7 +63,6 @@ if is_torch_available():
|
||||
AutoModelForVision2Seq,
|
||||
AutoProcessor,
|
||||
AutoTokenizer,
|
||||
BartForCausalLM,
|
||||
BartForConditionalGeneration,
|
||||
BartTokenizer,
|
||||
GPT2LMHeadModel,
|
||||
@@ -3629,150 +3628,6 @@ class GenerationIntegrationTests(unittest.TestCase, GenerationIntegrationTestsMi
|
||||
)
|
||||
self.assertListEqual(outputs_assisted.tolist(), outputs_tti.tolist())
|
||||
|
||||
def test_model_kwarg_assisted_decoding_encoder_decoder(self):
|
||||
"""
|
||||
Tests that the following scenario is compatible with assisted generation:
|
||||
1. encoder-decoder main model
|
||||
2. encoder-decoder assistant model
|
||||
3. both have a custom input
|
||||
(e.g. Whisper)
|
||||
"""
|
||||
|
||||
# PT-only test: TF doesn't support assisted decoding yet.
|
||||
# Bart subclass with a kwarg that distorts the output
|
||||
class FakeBart(BartForConditionalGeneration):
|
||||
def forward(self, input_ids, past_key_values, foo=False, **kwargs):
|
||||
outs = super().forward(input_ids, past_key_values=past_key_values, **kwargs)
|
||||
if foo:
|
||||
outs["logits"][:, :, :] = 0.0
|
||||
return outs
|
||||
|
||||
def prepare_inputs_for_generation(self, *args, foo=False, encoder_outputs=None, **kwargs):
|
||||
kwargs["encoder_outputs"] = encoder_outputs
|
||||
inputs = super().prepare_inputs_for_generation(*args, **kwargs)
|
||||
inputs["foo"] = foo
|
||||
return inputs
|
||||
|
||||
model = FakeBart.from_pretrained("hf-internal-testing/tiny-random-BartForConditionalGeneration").to(
|
||||
torch_device
|
||||
)
|
||||
tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-BartForConditionalGeneration")
|
||||
|
||||
text = "Hello world"
|
||||
tokenized_inputs = tokenizer([text], return_tensors="pt")
|
||||
input_ids = tokenized_inputs.input_ids.to(torch_device)
|
||||
|
||||
# Traditional way of generating text
|
||||
outputs_normal = model.generate(input_ids)
|
||||
self.assertEqual(outputs_normal.shape, (1, 20))
|
||||
|
||||
# Should be different with foo
|
||||
outputs_foo = model.generate(input_ids, foo=True)
|
||||
with self.assertRaises(AssertionError):
|
||||
self.assertListEqual(outputs_foo.tolist(), outputs_normal.tolist())
|
||||
|
||||
# Assistant model
|
||||
assistant = FakeBart.from_pretrained("hf-internal-testing/tiny-random-BartForConditionalGeneration").to(
|
||||
torch_device
|
||||
)
|
||||
|
||||
# If assisted generation passes model_kwargs correctly, should be same as previous
|
||||
outputs_assisted = model.generate(
|
||||
input_ids,
|
||||
foo=True,
|
||||
assistant_model=assistant,
|
||||
)
|
||||
self.assertListEqual(outputs_assisted.tolist(), outputs_foo.tolist())
|
||||
|
||||
# Check that passing encoder_outputs directly also works as expected
|
||||
encoder_outputs = assistant.get_encoder()(input_ids)
|
||||
|
||||
outputs_assisted = model.generate(
|
||||
foo=True,
|
||||
assistant_model=assistant,
|
||||
encoder_outputs=encoder_outputs,
|
||||
assistant_encoder_outputs=encoder_outputs,
|
||||
)
|
||||
self.assertListEqual(outputs_assisted.tolist(), outputs_foo.tolist())
|
||||
|
||||
def test_assisted_decoding_encoder_decoder_shared_encoder(self):
|
||||
"""
|
||||
Tests that the following scenario is compatible with assisted generation:
|
||||
1. encoder-decoder main model
|
||||
2. decoder-only assistant model
|
||||
3. both have a custom input
|
||||
(e.g. DistilWhisper)
|
||||
"""
|
||||
|
||||
# PT-only test: TF doesn't support assisted decoding yet.
|
||||
# Bart subclass with a kwarg called foo that distorts the output
|
||||
class FakeBartSeq2Seq(BartForConditionalGeneration):
|
||||
def forward(self, input_ids, foo=False, **kwargs):
|
||||
outs = super().forward(input_ids, **kwargs)
|
||||
if foo:
|
||||
outs["logits"][:, :, :] = 0.0
|
||||
return outs
|
||||
|
||||
def prepare_inputs_for_generation(self, *args, foo=False, encoder_outputs=None, **kwargs):
|
||||
kwargs["encoder_outputs"] = encoder_outputs
|
||||
inputs = super().prepare_inputs_for_generation(*args, **kwargs)
|
||||
inputs["foo"] = foo
|
||||
return inputs
|
||||
|
||||
class FakeBartCausalLM(BartForCausalLM):
|
||||
def forward(self, input_ids, attention_mask, past_key_values, foo=False, **kwargs):
|
||||
outs = super().forward(input_ids, attention_mask, past_key_values=past_key_values, **kwargs)
|
||||
if foo:
|
||||
outs["logits"][:, :, :] = 0.0
|
||||
return outs
|
||||
|
||||
def prepare_inputs_for_generation(self, *args, foo=False, encoder_outputs=None, **kwargs):
|
||||
kwargs["encoder_outputs"] = encoder_outputs
|
||||
inputs = super().prepare_inputs_for_generation(*args, **kwargs)
|
||||
inputs["foo"] = foo
|
||||
return inputs
|
||||
|
||||
model = FakeBartSeq2Seq.from_pretrained("hf-internal-testing/tiny-random-BartForConditionalGeneration").to(
|
||||
torch_device
|
||||
)
|
||||
tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-BartForConditionalGeneration")
|
||||
|
||||
text = "Hello world"
|
||||
tokenized_inputs = tokenizer([text], return_tensors="pt")
|
||||
input_ids = tokenized_inputs.input_ids.to(torch_device)
|
||||
|
||||
# Traditional way of generating text
|
||||
outputs_normal = model.generate(input_ids)
|
||||
self.assertEqual(outputs_normal.shape, (1, 20))
|
||||
|
||||
# Should be different with foo
|
||||
outputs_foo = model.generate(input_ids, foo=True)
|
||||
with self.assertRaises(AssertionError):
|
||||
self.assertListEqual(outputs_foo.tolist(), outputs_normal.tolist())
|
||||
|
||||
# Assistant model
|
||||
assistant = FakeBartCausalLM.from_pretrained(
|
||||
"hf-internal-testing/tiny-random-BartForConditionalGeneration"
|
||||
).to(torch_device)
|
||||
|
||||
# If assisted generation passes model_kwargs correctly, should be same as previous
|
||||
outputs_assisted = model.generate(
|
||||
input_ids,
|
||||
foo=True,
|
||||
assistant_model=assistant,
|
||||
)
|
||||
self.assertListEqual(outputs_assisted.tolist(), outputs_foo.tolist())
|
||||
|
||||
# Check that passing encoder_outputs directly also works as expected
|
||||
encoder_outputs = model.get_encoder()(input_ids)
|
||||
|
||||
outputs_assisted = model.generate(
|
||||
foo=True,
|
||||
assistant_model=assistant,
|
||||
encoder_outputs=encoder_outputs,
|
||||
)
|
||||
self.assertListEqual(outputs_assisted.tolist(), outputs_foo.tolist())
|
||||
|
||||
def test_assisted_decoding_num_assistant_tokens_heuristic_schedule(self):
|
||||
# This test ensures that the assisted generation num_assistant_tokens 'heuristic' schedule works properly.
|
||||
|
||||
|
||||
Reference in New Issue
Block a user