In assisted decoding, pass model_kwargs to model's forward call (fix prepare_input_for_generation in all models) (#25242)

* In assisted decoding, pass model_kwargs to model's forward call

Previously, assisted decoding would ignore any additional kwargs
that it doesn't explicitly handle. This was inconsistent with other
generation methods, which pass the model_kwargs through
prepare_inputs_for_generation and forward the returned dict to the
model's forward call.

The prepare_inputs_for_generation method needs to be amended in all
models, as previously it only kept the last input ID when a past_key_values
was passed.

* Improve variable names in _extend_attention_mask

* Refactor extending token_type_ids into a function

* Replace deepcopy with copy to optimize performance

* Update new persimmon model with llama changes for assisted generation

* Update new mistral model for assisted generation with prepare_inputs_for_generation

* Update position_ids creation in falcon prepare_inputs_for_generation to support assisted generation
This commit is contained in:
Billy Bradley
2023-10-11 12:18:42 +01:00
committed by GitHub
parent 1e3c9ddacc
commit dcc49d8a7e
63 changed files with 911 additions and 179 deletions

View File

@@ -2906,3 +2906,89 @@ class GenerationIntegrationTests(unittest.TestCase, GenerationIntegrationTestsMi
model.generation_config.max_length = 10
model.generate(input_ids)
self.assertEqual(len(warning_list), 0)
def test_model_kwarg_assisted_decoding_decoder_only(self):
# PT-only test: TF doesn't support assisted decoding yet.
model = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-gpt2").to(torch_device)
tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-gpt2")
model.config.pad_token_id = tokenizer.eos_token_id
text = "Hello world"
tokenized_inputs = tokenizer([text], return_tensors="pt")
input_ids = tokenized_inputs.input_ids.to(torch_device)
# Traditional way of generating text
outputs_normal = model.generate(input_ids)
self.assertEqual(outputs_normal.shape, (1, 20))
# Should be different with token_type_ids
outputs_tti = model.generate(
input_ids,
token_type_ids=torch.zeros(input_ids.shape, dtype=torch.long).to(torch_device),
)
with self.assertRaises(AssertionError):
self.assertListEqual(outputs_tti.tolist(), outputs_normal.tolist())
# Assistant model
assistant = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-gpt2").to(torch_device)
assistant.config.pad_token_id = tokenizer.eos_token_id
# If assisted generation passes model_kwargs correctly, should be same as previous
outputs_assisted = model.generate(
input_ids,
token_type_ids=torch.zeros(input_ids.shape, dtype=torch.long).to(torch_device),
assistant_model=assistant,
)
self.assertListEqual(outputs_assisted.tolist(), outputs_tti.tolist())
def test_model_kwarg_assisted_decoding_encoder_decoder(self):
# PT-only test: TF doesn't support assisted decoding yet.
# Bart subclass with a kwarg that distorts the output
class FakeBart(BartForConditionalGeneration):
def forward(self, input_ids, foo=False, **kwargs):
outs = super().forward(input_ids, **kwargs)
if foo:
outs["logits"][:, :, :] = 0.0
return outs
def prepare_inputs_for_generation(self, *args, foo=False, **kwargs):
inputs = super().prepare_inputs_for_generation(*args, **kwargs)
inputs["foo"] = foo
return inputs
model = FakeBart.from_pretrained("hf-internal-testing/tiny-random-BartForConditionalGeneration").to(
torch_device
)
tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-BartForConditionalGeneration")
text = "Hello world"
tokenized_inputs = tokenizer([text], return_tensors="pt")
input_ids = tokenized_inputs.input_ids.to(torch_device)
# Traditional way of generating text
outputs_normal = model.generate(input_ids)
self.assertEqual(outputs_normal.shape, (1, 20))
# Should be different with foo
outputs_foo = model.generate(
input_ids,
foo=True,
)
with self.assertRaises(AssertionError):
self.assertListEqual(outputs_foo.tolist(), outputs_normal.tolist())
# Assistant model
assistant = AutoModelForSeq2SeqLM.from_pretrained(
"hf-internal-testing/tiny-random-BartForConditionalGeneration"
).to(torch_device)
# If assisted generation passes model_kwargs correctly, should be same as previous
outputs_assisted = model.generate(
input_ids,
foo=True,
assistant_model=assistant,
)
self.assertListEqual(outputs_assisted.tolist(), outputs_foo.tolist())