Generate: general test for decoder-only generation from inputs_embeds (#25687)
Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
This commit is contained in:
@@ -1750,6 +1750,56 @@ class GenerationTesterMixin:
|
||||
past_kv[i][1].shape, (batch_size, num_attention_heads, seq_length, per_head_embed_dim)
|
||||
)
|
||||
|
||||
def test_generate_from_inputs_embeds_decoder_only(self):
|
||||
# When supported, tests that the decoder model can generate from `inputs_embeds` instead of `input_ids`
|
||||
# if fails, you should probably update the `prepare_inputs_for_generation` function
|
||||
for model_class in self.all_generative_model_classes:
|
||||
config, input_ids, _, _ = self._get_input_ids_and_config()
|
||||
|
||||
# Ignore:
|
||||
# a) eos (to always output 20 tokens) and pad (so we don't try to infer the attn mask from the input_ids,
|
||||
# which would cause a mismatch),
|
||||
config.pad_token_id = config.eos_token_id = -1
|
||||
# b) embedding scaling, the scaling factor applied after embeding from input_ids (requires knowledge of the
|
||||
# variable that holds the scaling factor, which is model-dependent)
|
||||
if hasattr(config, "scale_embedding"):
|
||||
config.scale_embedding = False
|
||||
|
||||
# This test is for decoder-only models (encoder-decoder models have native input embeddings support in the
|
||||
# decoder)
|
||||
if config.is_encoder_decoder:
|
||||
continue
|
||||
|
||||
# Skip models without explicit support
|
||||
model = model_class(config).to(torch_device).eval()
|
||||
if "inputs_embeds" not in inspect.signature(model.prepare_inputs_for_generation).parameters.keys():
|
||||
continue
|
||||
|
||||
# Traditional way of generating text
|
||||
outputs_from_ids = model.generate(input_ids)
|
||||
self.assertEqual(outputs_from_ids.shape, (2, 20))
|
||||
|
||||
# Same thing, but from input embeddings (`input_ids` is passed so the prompt is present in the output)
|
||||
inputs_embeds = model.get_input_embeddings()(input_ids)
|
||||
outputs_from_embeds = model.generate(input_ids, inputs_embeds=inputs_embeds)
|
||||
self.assertListEqual(outputs_from_ids.tolist(), outputs_from_embeds.tolist())
|
||||
|
||||
# But if we pass different inputs_embeds, we should get different outputs
|
||||
torch.manual_seed(0)
|
||||
random_embeds = torch.rand_like(inputs_embeds)
|
||||
outputs_from_rand_embeds = model.generate(input_ids, inputs_embeds=random_embeds)
|
||||
with self.assertRaises(AssertionError):
|
||||
self.assertListEqual(outputs_from_rand_embeds.tolist(), outputs_from_embeds.tolist())
|
||||
|
||||
# input_ids is not a required input -- if we don't pass it, the newly generated tokens will be the same
|
||||
outputs_from_embeds_wo_ids = model.generate(
|
||||
inputs_embeds=inputs_embeds, max_new_tokens=20 - inputs_embeds.shape[1]
|
||||
)
|
||||
self.assertListEqual(
|
||||
outputs_from_embeds[:, inputs_embeds.shape[1] :].tolist(),
|
||||
outputs_from_embeds_wo_ids[:, 1:].tolist(),
|
||||
)
|
||||
|
||||
def _check_outputs(self, output, input_ids, config, use_cache=False, num_return_sequences=1):
|
||||
batch_size, seq_length = input_ids.shape
|
||||
num_sequences_in_output = batch_size * num_return_sequences
|
||||
@@ -2773,42 +2823,6 @@ class GenerationIntegrationTests(unittest.TestCase, GenerationIntegrationTestsMi
|
||||
generated_tokens = model.generate(**tokens, eos_token_id=eos_token_id, **generation_kwargs)
|
||||
self.assertTrue(expectation == len(generated_tokens[0]))
|
||||
|
||||
def test_generate_from_inputs_embeds_decoder_only(self):
|
||||
# PT-only test: TF doesn't have a model with support to generate from input embeds (yet ;))
|
||||
# Note: the model must support generation from input embeddings
|
||||
model = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-gpt2").to(torch_device)
|
||||
tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-gpt2")
|
||||
model.config.pad_token_id = tokenizer.eos_token_id
|
||||
|
||||
text = "Hello world"
|
||||
tokenized_inputs = tokenizer([text, text], return_tensors="pt")
|
||||
input_ids = tokenized_inputs.input_ids.to(torch_device)
|
||||
|
||||
# Traditional way of generating text
|
||||
outputs_from_ids = model.generate(input_ids)
|
||||
self.assertEqual(outputs_from_ids.shape, (2, 20))
|
||||
|
||||
# Same thing, but from input embeddings
|
||||
inputs_embeds = model.transformer.wte(input_ids)
|
||||
outputs_from_embeds = model.generate(input_ids, inputs_embeds=inputs_embeds)
|
||||
self.assertListEqual(outputs_from_ids.tolist(), outputs_from_embeds.tolist())
|
||||
|
||||
# But if we pass different inputs_embeds, we should get different outputs
|
||||
torch.manual_seed(0)
|
||||
random_embeds = torch.rand_like(inputs_embeds)
|
||||
outputs_from_rand_embeds = model.generate(input_ids, inputs_embeds=random_embeds)
|
||||
with self.assertRaises(AssertionError):
|
||||
self.assertListEqual(outputs_from_rand_embeds.tolist(), outputs_from_embeds.tolist())
|
||||
|
||||
# input_ids is not a required input -- if we don't pass it, the newly generated tokens will be the same
|
||||
outputs_from_embeds_wo_ids = model.generate(
|
||||
inputs_embeds=inputs_embeds, max_new_tokens=20 - inputs_embeds.shape[1]
|
||||
)
|
||||
self.assertListEqual(
|
||||
outputs_from_embeds[:, inputs_embeds.shape[1] :].tolist(),
|
||||
outputs_from_embeds_wo_ids[:, 1:].tolist(),
|
||||
)
|
||||
|
||||
def test_model_kwarg_encoder_signature_filtering(self):
|
||||
# Has TF equivalent: ample use of framework-specific code
|
||||
bart_tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-bart")
|
||||
|
||||
Reference in New Issue
Block a user