[WhisperForCausalLM] Add WhisperForCausalLM for speculative decoding (#27195)

* finish

* add tests

* fix all tests

* [Assistant Decoding] Add test

* fix more

* better

* finish

* Apply suggestions from code review

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* finish

---------

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
This commit is contained in:
Patrick von Platen
2023-11-01 16:01:53 +01:00
committed by GitHub
parent f9b4bea0a6
commit 391d14e810
10 changed files with 601 additions and 5 deletions

View File

@@ -43,6 +43,7 @@ if is_torch_available():
AutoModelForSpeechSeq2Seq,
AutoModelForVision2Seq,
AutoTokenizer,
BartForCausalLM,
BartForConditionalGeneration,
BartTokenizer,
GPT2LMHeadModel,
@@ -3010,3 +3011,63 @@ class GenerationIntegrationTests(unittest.TestCase, GenerationIntegrationTestsMi
assistant_encoder_outputs=encoder_outputs,
)
self.assertListEqual(outputs_assisted.tolist(), outputs_foo.tolist())
def test_assisted_decoding_encoder_decoder_shared_encoder(self):
# PT-only test: TF doesn't support assisted decoding yet.
# Bart subclass with a kwarg called foo that distorts the output
class FakeBart(BartForConditionalGeneration):
def forward(self, input_ids, foo=False, **kwargs):
outs = super().forward(input_ids, **kwargs)
if foo:
outs["logits"][:, :, :] = 0.0
return outs
def prepare_inputs_for_generation(self, *args, foo=False, encoder_outputs=None, **kwargs):
kwargs["encoder_outputs"] = encoder_outputs
inputs = super().prepare_inputs_for_generation(*args, **kwargs)
inputs["foo"] = foo
return inputs
model = FakeBart.from_pretrained("hf-internal-testing/tiny-random-BartForConditionalGeneration").to(
torch_device
)
tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-BartForConditionalGeneration")
text = "Hello world"
tokenized_inputs = tokenizer([text], return_tensors="pt")
input_ids = tokenized_inputs.input_ids.to(torch_device)
# Traditional way of generating text
outputs_normal = model.generate(input_ids)
self.assertEqual(outputs_normal.shape, (1, 20))
# Should be different with foo
outputs_foo = model.generate(input_ids, foo=True)
with self.assertRaises(AssertionError):
self.assertListEqual(outputs_foo.tolist(), outputs_normal.tolist())
# Assistant model
assistant = BartForCausalLM.from_pretrained("hf-internal-testing/tiny-random-BartForConditionalGeneration").to(
torch_device
)
# If assisted generation passes model_kwargs correctly, should be same as previous
outputs_assisted = model.generate(
input_ids,
foo=True,
assistant_model=assistant,
)
self.assertListEqual(outputs_assisted.tolist(), outputs_foo.tolist())
# Check that passing encoder_outputs directly also works as expected
encoder_outputs = model.get_encoder()(input_ids)
outputs_assisted = model.generate(
foo=True,
assistant_model=assistant,
encoder_outputs=encoder_outputs,
)
self.assertListEqual(outputs_assisted.tolist(), outputs_foo.tolist())