Whisper: remove redundant assisted generation tests (#34814)
* remove redundant test * delete another test * revert default max_length * (wrong place, moving)
This commit is contained in:
@@ -63,7 +63,6 @@ if is_torch_available():
|
||||
AutoModelForVision2Seq,
|
||||
AutoProcessor,
|
||||
AutoTokenizer,
|
||||
BartForCausalLM,
|
||||
BartForConditionalGeneration,
|
||||
BartTokenizer,
|
||||
GPT2LMHeadModel,
|
||||
@@ -3629,150 +3628,6 @@ class GenerationIntegrationTests(unittest.TestCase, GenerationIntegrationTestsMi
|
||||
)
|
||||
self.assertListEqual(outputs_assisted.tolist(), outputs_tti.tolist())
|
||||
|
||||
def test_model_kwarg_assisted_decoding_encoder_decoder(self):
|
||||
"""
|
||||
Tests that the following scenario is compatible with assisted generation:
|
||||
1. encoder-decoder main model
|
||||
2. encoder-decoder assistant model
|
||||
3. both have a custom input
|
||||
(e.g. Whisper)
|
||||
"""
|
||||
|
||||
# PT-only test: TF doesn't support assisted decoding yet.
|
||||
# Bart subclass with a kwarg that distorts the output
|
||||
class FakeBart(BartForConditionalGeneration):
|
||||
def forward(self, input_ids, past_key_values, foo=False, **kwargs):
|
||||
outs = super().forward(input_ids, past_key_values=past_key_values, **kwargs)
|
||||
if foo:
|
||||
outs["logits"][:, :, :] = 0.0
|
||||
return outs
|
||||
|
||||
def prepare_inputs_for_generation(self, *args, foo=False, encoder_outputs=None, **kwargs):
|
||||
kwargs["encoder_outputs"] = encoder_outputs
|
||||
inputs = super().prepare_inputs_for_generation(*args, **kwargs)
|
||||
inputs["foo"] = foo
|
||||
return inputs
|
||||
|
||||
model = FakeBart.from_pretrained("hf-internal-testing/tiny-random-BartForConditionalGeneration").to(
|
||||
torch_device
|
||||
)
|
||||
tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-BartForConditionalGeneration")
|
||||
|
||||
text = "Hello world"
|
||||
tokenized_inputs = tokenizer([text], return_tensors="pt")
|
||||
input_ids = tokenized_inputs.input_ids.to(torch_device)
|
||||
|
||||
# Traditional way of generating text
|
||||
outputs_normal = model.generate(input_ids)
|
||||
self.assertEqual(outputs_normal.shape, (1, 20))
|
||||
|
||||
# Should be different with foo
|
||||
outputs_foo = model.generate(input_ids, foo=True)
|
||||
with self.assertRaises(AssertionError):
|
||||
self.assertListEqual(outputs_foo.tolist(), outputs_normal.tolist())
|
||||
|
||||
# Assistant model
|
||||
assistant = FakeBart.from_pretrained("hf-internal-testing/tiny-random-BartForConditionalGeneration").to(
|
||||
torch_device
|
||||
)
|
||||
|
||||
# If assisted generation passes model_kwargs correctly, should be same as previous
|
||||
outputs_assisted = model.generate(
|
||||
input_ids,
|
||||
foo=True,
|
||||
assistant_model=assistant,
|
||||
)
|
||||
self.assertListEqual(outputs_assisted.tolist(), outputs_foo.tolist())
|
||||
|
||||
# Check that passing encoder_outputs directly also works as expected
|
||||
encoder_outputs = assistant.get_encoder()(input_ids)
|
||||
|
||||
outputs_assisted = model.generate(
|
||||
foo=True,
|
||||
assistant_model=assistant,
|
||||
encoder_outputs=encoder_outputs,
|
||||
assistant_encoder_outputs=encoder_outputs,
|
||||
)
|
||||
self.assertListEqual(outputs_assisted.tolist(), outputs_foo.tolist())
|
||||
|
||||
def test_assisted_decoding_encoder_decoder_shared_encoder(self):
|
||||
"""
|
||||
Tests that the following scenario is compatible with assisted generation:
|
||||
1. encoder-decoder main model
|
||||
2. decoder-only assistant model
|
||||
3. both have a custom input
|
||||
(e.g. DistilWhisper)
|
||||
"""
|
||||
|
||||
# PT-only test: TF doesn't support assisted decoding yet.
|
||||
# Bart subclass with a kwarg called foo that distorts the output
|
||||
class FakeBartSeq2Seq(BartForConditionalGeneration):
|
||||
def forward(self, input_ids, foo=False, **kwargs):
|
||||
outs = super().forward(input_ids, **kwargs)
|
||||
if foo:
|
||||
outs["logits"][:, :, :] = 0.0
|
||||
return outs
|
||||
|
||||
def prepare_inputs_for_generation(self, *args, foo=False, encoder_outputs=None, **kwargs):
|
||||
kwargs["encoder_outputs"] = encoder_outputs
|
||||
inputs = super().prepare_inputs_for_generation(*args, **kwargs)
|
||||
inputs["foo"] = foo
|
||||
return inputs
|
||||
|
||||
class FakeBartCausalLM(BartForCausalLM):
|
||||
def forward(self, input_ids, attention_mask, past_key_values, foo=False, **kwargs):
|
||||
outs = super().forward(input_ids, attention_mask, past_key_values=past_key_values, **kwargs)
|
||||
if foo:
|
||||
outs["logits"][:, :, :] = 0.0
|
||||
return outs
|
||||
|
||||
def prepare_inputs_for_generation(self, *args, foo=False, encoder_outputs=None, **kwargs):
|
||||
kwargs["encoder_outputs"] = encoder_outputs
|
||||
inputs = super().prepare_inputs_for_generation(*args, **kwargs)
|
||||
inputs["foo"] = foo
|
||||
return inputs
|
||||
|
||||
model = FakeBartSeq2Seq.from_pretrained("hf-internal-testing/tiny-random-BartForConditionalGeneration").to(
|
||||
torch_device
|
||||
)
|
||||
tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-BartForConditionalGeneration")
|
||||
|
||||
text = "Hello world"
|
||||
tokenized_inputs = tokenizer([text], return_tensors="pt")
|
||||
input_ids = tokenized_inputs.input_ids.to(torch_device)
|
||||
|
||||
# Traditional way of generating text
|
||||
outputs_normal = model.generate(input_ids)
|
||||
self.assertEqual(outputs_normal.shape, (1, 20))
|
||||
|
||||
# Should be different with foo
|
||||
outputs_foo = model.generate(input_ids, foo=True)
|
||||
with self.assertRaises(AssertionError):
|
||||
self.assertListEqual(outputs_foo.tolist(), outputs_normal.tolist())
|
||||
|
||||
# Assistant model
|
||||
assistant = FakeBartCausalLM.from_pretrained(
|
||||
"hf-internal-testing/tiny-random-BartForConditionalGeneration"
|
||||
).to(torch_device)
|
||||
|
||||
# If assisted generation passes model_kwargs correctly, should be same as previous
|
||||
outputs_assisted = model.generate(
|
||||
input_ids,
|
||||
foo=True,
|
||||
assistant_model=assistant,
|
||||
)
|
||||
self.assertListEqual(outputs_assisted.tolist(), outputs_foo.tolist())
|
||||
|
||||
# Check that passing encoder_outputs directly also works as expected
|
||||
encoder_outputs = model.get_encoder()(input_ids)
|
||||
|
||||
outputs_assisted = model.generate(
|
||||
foo=True,
|
||||
assistant_model=assistant,
|
||||
encoder_outputs=encoder_outputs,
|
||||
)
|
||||
self.assertListEqual(outputs_assisted.tolist(), outputs_foo.tolist())
|
||||
|
||||
def test_assisted_decoding_num_assistant_tokens_heuristic_schedule(self):
|
||||
# This test ensures that the assisted generation num_assistant_tokens 'heuristic' schedule works properly.
|
||||
|
||||
|
||||
Reference in New Issue
Block a user