In assisted decoding, pass model_kwargs to model's forward call (fix prepare_inputs_for_generation in all models) (#25242)
* In assisted decoding, pass model_kwargs to model's forward call

  Previously, assisted decoding would ignore any additional kwargs that it didn't explicitly handle. This was inconsistent with other generation methods, which pass the model_kwargs through prepare_inputs_for_generation and forward the returned dict to the model's forward call.

  The prepare_inputs_for_generation method needs to be amended in all models, as previously it only kept the last input ID when past_key_values was passed.

* Improve variable names in _extend_attention_mask
* Refactor extending token_type_ids into a function
* Replace deepcopy with copy to optimize performance
* Update new persimmon model with llama changes for assisted generation
* Update new mistral model for assisted generation with prepare_inputs_for_generation
* Update position_ids creation in falcon prepare_inputs_for_generation to support assisted generation
This commit is contained in:
@@ -2906,3 +2906,89 @@ class GenerationIntegrationTests(unittest.TestCase, GenerationIntegrationTestsMi
|
||||
model.generation_config.max_length = 10
|
||||
model.generate(input_ids)
|
||||
self.assertEqual(len(warning_list), 0)
|
||||
|
||||
def test_model_kwarg_assisted_decoding_decoder_only(self):
    """Check that assisted decoding forwards extra model kwargs (decoder-only model).

    Generates three times from the same prompt: without extra kwargs, with
    `token_type_ids`, and with `token_type_ids` plus an assistant model. The
    `token_type_ids` run must differ from the plain run, and the assisted run
    must match the unassisted `token_type_ids` run.
    """
    # PT-only test: TF doesn't support assisted decoding yet.
    model = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-gpt2").to(torch_device)
    tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-gpt2")
    model.config.pad_token_id = tokenizer.eos_token_id

    text = "Hello world"
    tokenized_inputs = tokenizer([text], return_tensors="pt")
    input_ids = tokenized_inputs.input_ids.to(torch_device)

    # Traditional way of generating text
    outputs_normal = model.generate(input_ids)
    self.assertEqual(outputs_normal.shape, (1, 20))

    # Should be different with token_type_ids
    outputs_tti = model.generate(
        input_ids,
        token_type_ids=torch.zeros(input_ids.shape, dtype=torch.long).to(torch_device),
    )
    # Direct inequality check: on failure it reports the offending values instead of
    # the opaque "AssertionError not raised" of an inverted assertListEqual.
    self.assertNotEqual(outputs_tti.tolist(), outputs_normal.tolist())

    # Assistant model
    assistant = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-gpt2").to(torch_device)
    assistant.config.pad_token_id = tokenizer.eos_token_id

    # If assisted generation passes model_kwargs correctly, should be same as previous
    outputs_assisted = model.generate(
        input_ids,
        token_type_ids=torch.zeros(input_ids.shape, dtype=torch.long).to(torch_device),
        assistant_model=assistant,
    )
    self.assertListEqual(outputs_assisted.tolist(), outputs_tti.tolist())
|
||||
|
||||
def test_model_kwarg_assisted_decoding_encoder_decoder(self):
    """Check that assisted decoding forwards extra model kwargs (encoder-decoder model).

    Uses a Bart subclass whose custom `foo` kwarg zeroes the logits. Generates
    without `foo`, with `foo=True`, and with `foo=True` plus an assistant model.
    The `foo` run must differ from the plain run, and the assisted run must
    match the unassisted `foo` run.
    """
    # PT-only test: TF doesn't support assisted decoding yet.

    # Bart subclass with a kwarg that distorts the output
    class FakeBart(BartForConditionalGeneration):
        def forward(self, input_ids, foo=False, **kwargs):
            outs = super().forward(input_ids, **kwargs)

            if foo:
                # Overwrite every logit in place so the effect of `foo` is observable.
                outs["logits"][:, :, :] = 0.0

            return outs

        def prepare_inputs_for_generation(self, *args, foo=False, **kwargs):
            # `foo` is captured here rather than passed to the parent implementation,
            # then added to the returned inputs dict.
            inputs = super().prepare_inputs_for_generation(*args, **kwargs)

            inputs["foo"] = foo
            return inputs

    model = FakeBart.from_pretrained("hf-internal-testing/tiny-random-BartForConditionalGeneration").to(
        torch_device
    )
    tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-BartForConditionalGeneration")

    text = "Hello world"
    tokenized_inputs = tokenizer([text], return_tensors="pt")
    input_ids = tokenized_inputs.input_ids.to(torch_device)

    # Traditional way of generating text
    outputs_normal = model.generate(input_ids)
    self.assertEqual(outputs_normal.shape, (1, 20))

    # Should be different with foo
    outputs_foo = model.generate(
        input_ids,
        foo=True,
    )
    # Direct inequality check: on failure it reports the offending values instead of
    # the opaque "AssertionError not raised" of an inverted assertListEqual.
    self.assertNotEqual(outputs_foo.tolist(), outputs_normal.tolist())

    # Assistant model
    assistant = AutoModelForSeq2SeqLM.from_pretrained(
        "hf-internal-testing/tiny-random-BartForConditionalGeneration"
    ).to(torch_device)

    # If assisted generation passes model_kwargs correctly, should be same as previous
    outputs_assisted = model.generate(
        input_ids,
        foo=True,
        assistant_model=assistant,
    )
    self.assertListEqual(outputs_assisted.tolist(), outputs_foo.tolist())
|
||||
|
||||
Reference in New Issue
Block a user