Generate: correct default model input creation for decoder-only models (#21580)
This commit is contained in:
@@ -2488,3 +2488,38 @@ class GenerationIntegrationTests(unittest.TestCase, GenerationIntegrationTestsMi
|
||||
eos_token_id = [846, 198]
|
||||
generated_tokens = model.generate(**tokens, eos_token_id=eos_token_id, **generation_kwargs)
|
||||
self.assertTrue(expectation == len(generated_tokens[0]))
|
||||
|
||||
def test_generate_from_inputs_embeds_decoder_only(self):
|
||||
# Note: the model must support generation from input embeddings
|
||||
model = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-gpt2").to(torch_device)
|
||||
tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-gpt2")
|
||||
model.config.pad_token_id = tokenizer.eos_token_id
|
||||
|
||||
text = "Hello world"
|
||||
tokenized_inputs = tokenizer([text, text], return_tensors="pt")
|
||||
input_ids = tokenized_inputs.input_ids.to(torch_device)
|
||||
|
||||
# Traditional way of generating text
|
||||
outputs_from_ids = model.generate(input_ids)
|
||||
self.assertEqual(outputs_from_ids.shape, (2, 20))
|
||||
|
||||
# Same thing, but from input embeddings
|
||||
inputs_embeds = model.transformer.wte(input_ids)
|
||||
outputs_from_embeds = model.generate(input_ids, inputs_embeds=inputs_embeds)
|
||||
self.assertListEqual(outputs_from_ids.tolist(), outputs_from_embeds.tolist())
|
||||
|
||||
# But if we pass different inputs_embeds, we should get different outputs
|
||||
torch.manual_seed(0)
|
||||
random_embeds = torch.rand_like(inputs_embeds)
|
||||
outputs_from_rand_embeds = model.generate(input_ids, inputs_embeds=random_embeds)
|
||||
with self.assertRaises(AssertionError):
|
||||
self.assertListEqual(outputs_from_rand_embeds.tolist(), outputs_from_embeds.tolist())
|
||||
|
||||
# input_ids is not a required input -- if we don't pass it, the newly generated tokens will be the same
|
||||
outputs_from_embeds_wo_ids = model.generate(
|
||||
inputs_embeds=inputs_embeds, max_new_tokens=20 - inputs_embeds.shape[1]
|
||||
)
|
||||
self.assertListEqual(
|
||||
outputs_from_embeds[:, inputs_embeds.shape[1] :].tolist(),
|
||||
outputs_from_embeds_wo_ids[:, 1:].tolist(),
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user