Generate: general test for decoder-only generation from inputs_embeds (#25687)
Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
This commit is contained in:
@@ -1750,6 +1750,56 @@ class GenerationTesterMixin:
|
|||||||
past_kv[i][1].shape, (batch_size, num_attention_heads, seq_length, per_head_embed_dim)
|
past_kv[i][1].shape, (batch_size, num_attention_heads, seq_length, per_head_embed_dim)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def test_generate_from_inputs_embeds_decoder_only(self):
|
||||||
|
# When supported, tests that the decoder model can generate from `inputs_embeds` instead of `input_ids`
|
||||||
|
# if fails, you should probably update the `prepare_inputs_for_generation` function
|
||||||
|
for model_class in self.all_generative_model_classes:
|
||||||
|
config, input_ids, _, _ = self._get_input_ids_and_config()
|
||||||
|
|
||||||
|
# Ignore:
|
||||||
|
# a) eos (to always output 20 tokens) and pad (so we don't try to infer the attn mask from the input_ids,
|
||||||
|
# which would cause a mismatch),
|
||||||
|
config.pad_token_id = config.eos_token_id = -1
|
||||||
|
# b) embedding scaling, the scaling factor applied after embeding from input_ids (requires knowledge of the
|
||||||
|
# variable that holds the scaling factor, which is model-dependent)
|
||||||
|
if hasattr(config, "scale_embedding"):
|
||||||
|
config.scale_embedding = False
|
||||||
|
|
||||||
|
# This test is for decoder-only models (encoder-decoder models have native input embeddings support in the
|
||||||
|
# decoder)
|
||||||
|
if config.is_encoder_decoder:
|
||||||
|
continue
|
||||||
|
|
||||||
|
# Skip models without explicit support
|
||||||
|
model = model_class(config).to(torch_device).eval()
|
||||||
|
if "inputs_embeds" not in inspect.signature(model.prepare_inputs_for_generation).parameters.keys():
|
||||||
|
continue
|
||||||
|
|
||||||
|
# Traditional way of generating text
|
||||||
|
outputs_from_ids = model.generate(input_ids)
|
||||||
|
self.assertEqual(outputs_from_ids.shape, (2, 20))
|
||||||
|
|
||||||
|
# Same thing, but from input embeddings (`input_ids` is passed so the prompt is present in the output)
|
||||||
|
inputs_embeds = model.get_input_embeddings()(input_ids)
|
||||||
|
outputs_from_embeds = model.generate(input_ids, inputs_embeds=inputs_embeds)
|
||||||
|
self.assertListEqual(outputs_from_ids.tolist(), outputs_from_embeds.tolist())
|
||||||
|
|
||||||
|
# But if we pass different inputs_embeds, we should get different outputs
|
||||||
|
torch.manual_seed(0)
|
||||||
|
random_embeds = torch.rand_like(inputs_embeds)
|
||||||
|
outputs_from_rand_embeds = model.generate(input_ids, inputs_embeds=random_embeds)
|
||||||
|
with self.assertRaises(AssertionError):
|
||||||
|
self.assertListEqual(outputs_from_rand_embeds.tolist(), outputs_from_embeds.tolist())
|
||||||
|
|
||||||
|
# input_ids is not a required input -- if we don't pass it, the newly generated tokens will be the same
|
||||||
|
outputs_from_embeds_wo_ids = model.generate(
|
||||||
|
inputs_embeds=inputs_embeds, max_new_tokens=20 - inputs_embeds.shape[1]
|
||||||
|
)
|
||||||
|
self.assertListEqual(
|
||||||
|
outputs_from_embeds[:, inputs_embeds.shape[1] :].tolist(),
|
||||||
|
outputs_from_embeds_wo_ids[:, 1:].tolist(),
|
||||||
|
)
|
||||||
|
|
||||||
def _check_outputs(self, output, input_ids, config, use_cache=False, num_return_sequences=1):
|
def _check_outputs(self, output, input_ids, config, use_cache=False, num_return_sequences=1):
|
||||||
batch_size, seq_length = input_ids.shape
|
batch_size, seq_length = input_ids.shape
|
||||||
num_sequences_in_output = batch_size * num_return_sequences
|
num_sequences_in_output = batch_size * num_return_sequences
|
||||||
@@ -2773,42 +2823,6 @@ class GenerationIntegrationTests(unittest.TestCase, GenerationIntegrationTestsMi
|
|||||||
generated_tokens = model.generate(**tokens, eos_token_id=eos_token_id, **generation_kwargs)
|
generated_tokens = model.generate(**tokens, eos_token_id=eos_token_id, **generation_kwargs)
|
||||||
self.assertTrue(expectation == len(generated_tokens[0]))
|
self.assertTrue(expectation == len(generated_tokens[0]))
|
||||||
|
|
||||||
def test_generate_from_inputs_embeds_decoder_only(self):
|
|
||||||
# PT-only test: TF doesn't have a model with support to generate from input embeds (yet ;))
|
|
||||||
# Note: the model must support generation from input embeddings
|
|
||||||
model = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-gpt2").to(torch_device)
|
|
||||||
tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-gpt2")
|
|
||||||
model.config.pad_token_id = tokenizer.eos_token_id
|
|
||||||
|
|
||||||
text = "Hello world"
|
|
||||||
tokenized_inputs = tokenizer([text, text], return_tensors="pt")
|
|
||||||
input_ids = tokenized_inputs.input_ids.to(torch_device)
|
|
||||||
|
|
||||||
# Traditional way of generating text
|
|
||||||
outputs_from_ids = model.generate(input_ids)
|
|
||||||
self.assertEqual(outputs_from_ids.shape, (2, 20))
|
|
||||||
|
|
||||||
# Same thing, but from input embeddings
|
|
||||||
inputs_embeds = model.transformer.wte(input_ids)
|
|
||||||
outputs_from_embeds = model.generate(input_ids, inputs_embeds=inputs_embeds)
|
|
||||||
self.assertListEqual(outputs_from_ids.tolist(), outputs_from_embeds.tolist())
|
|
||||||
|
|
||||||
# But if we pass different inputs_embeds, we should get different outputs
|
|
||||||
torch.manual_seed(0)
|
|
||||||
random_embeds = torch.rand_like(inputs_embeds)
|
|
||||||
outputs_from_rand_embeds = model.generate(input_ids, inputs_embeds=random_embeds)
|
|
||||||
with self.assertRaises(AssertionError):
|
|
||||||
self.assertListEqual(outputs_from_rand_embeds.tolist(), outputs_from_embeds.tolist())
|
|
||||||
|
|
||||||
# input_ids is not a required input -- if we don't pass it, the newly generated tokens will be the same
|
|
||||||
outputs_from_embeds_wo_ids = model.generate(
|
|
||||||
inputs_embeds=inputs_embeds, max_new_tokens=20 - inputs_embeds.shape[1]
|
|
||||||
)
|
|
||||||
self.assertListEqual(
|
|
||||||
outputs_from_embeds[:, inputs_embeds.shape[1] :].tolist(),
|
|
||||||
outputs_from_embeds_wo_ids[:, 1:].tolist(),
|
|
||||||
)
|
|
||||||
|
|
||||||
def test_model_kwarg_encoder_signature_filtering(self):
|
def test_model_kwarg_encoder_signature_filtering(self):
|
||||||
# Has TF equivalent: ample use of framework-specific code
|
# Has TF equivalent: ample use of framework-specific code
|
||||||
bart_tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-bart")
|
bart_tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-bart")
|
||||||
|
|||||||
Reference in New Issue
Block a user