Generate: remove most decoder-only LLMs prepare_inputs_for_generation (#33870)

This commit is contained in:
Joao Gante
2024-10-09 12:15:48 +01:00
committed by GitHub
parent cdee5285ca
commit 295a90cb40
68 changed files with 235 additions and 1457 deletions

View File

@@ -24,7 +24,7 @@ import numpy as np
import pytest
from parameterized import parameterized
from transformers import is_torch_available, pipeline, set_seed
from transformers import AutoConfig, is_torch_available, pipeline, set_seed
from transformers.testing_utils import (
is_flaky,
require_accelerate,
@@ -110,6 +110,8 @@ class GenerationTesterMixin:
"decoder_attention_mask",
# we'll set cache use in each test differently
"use_cache",
# Ignore labels if it is in the input dict
"labels",
# model-specific exceptions should overload/overwrite this function
]
filtered_inputs_dict = {
@@ -1564,6 +1566,7 @@ class GenerationTesterMixin:
continue
# Skip models without explicit support
config.is_decoder = True
model = model_class(config).to(torch_device).eval()
if "inputs_embeds" not in inspect.signature(model.prepare_inputs_for_generation).parameters.keys():
continue
@@ -3725,6 +3728,90 @@ class GenerationIntegrationTests(unittest.TestCase, GenerationIntegrationTestsMi
self.assertEqual(generated_text_no_padding, generated_text_with_padding)
self.assertEqual(generated_text_no_padding, "Ich muss diese Aufgabe vor Ende des Tages beenden.")
def test_prepare_inputs_for_generation_decoder_llm(self):
    """
    Tests GenerationMixin.prepare_inputs_for_generation against expected usage with decoder-only llms.

    The checks run against the `hf-internal-testing/tiny-random-LlamaForCausalLM` checkpoint and cover, in order:
    bare `input_ids`, `attention_mask` (+ derived `position_ids`), kwargs forwarding, slicing with a
    `DynamicCache`, 4D mask preparation with a `StaticCache`, and `inputs_embeds` as the embedded prompt.
    """
    config = AutoConfig.from_pretrained("hf-internal-testing/tiny-random-LlamaForCausalLM")
    model = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-LlamaForCausalLM")
    model = model.to(torch_device)

    # 1. Sanity check: the model's `prepare_inputs_for_generation` comes from `GenerationMixin`
    self.assertIn("GenerationMixin", str(model.prepare_inputs_for_generation))

    # 2. If we pass input ids by themselves, we should get back the same input ids
    input_ids = torch.tensor([[1, 2, 3], [4, 5, 6]]).to(torch_device)
    model_inputs = model.prepare_inputs_for_generation(input_ids)
    self.assertTrue(torch.all(model_inputs["input_ids"] == input_ids))

    # 3. If we pass the attention mask too, we will get back the attention mask and position ids built from it
    attention_mask = torch.tensor([[1, 1, 1], [1, 1, 1]]).to(torch_device)
    model_inputs = model.prepare_inputs_for_generation(input_ids, attention_mask=attention_mask)
    self.assertTrue(torch.all(model_inputs["attention_mask"] == attention_mask))
    self.assertEqual(model_inputs["position_ids"].shape, input_ids.shape)

    # 4. `use_cache` (and other kwargs) are forwarded
    self.assertNotIn("use_cache", model_inputs)  # From the previous input, there is no `use_cache`
    model_inputs = model.prepare_inputs_for_generation(input_ids, use_cache=True, foo="bar")
    self.assertIs(model_inputs["use_cache"], True)
    self.assertEqual(model_inputs["foo"], "bar")

    # 5. When we pass a cache, we discard data related to already seen tokens in some tensors. We are now also
    # forced to pass a correctly prepared `cache_position` to slice the data accordingly.
    init_input_ids = input_ids[:, :2]
    dynamic_cache = DynamicCache()
    dynamic_cache = model(init_input_ids, past_key_values=dynamic_cache).past_key_values
    with self.assertRaises(AttributeError):  # past_key_values + no cache_position -> exception
        # The call is expected to raise, so its result is intentionally not bound to a name
        model.prepare_inputs_for_generation(input_ids, past_key_values=dynamic_cache)

    # Positions of the tokens that are not in the cache yet
    cache_position = torch.arange(input_ids.shape[-1], dtype=torch.long).to(torch_device)
    cache_position = cache_position[dynamic_cache.get_seq_length() :]
    model_inputs = model.prepare_inputs_for_generation(
        input_ids, past_key_values=dynamic_cache, cache_position=cache_position, attention_mask=attention_mask
    )
    self.assertIn("past_key_values", model_inputs)
    self.assertTrue(torch.all(model_inputs["cache_position"] == cache_position))
    self.assertEqual(model_inputs["input_ids"].shape[-1], 1)  # 1 = 3 fed tokens - 2 tokens in the cache
    self.assertEqual(model_inputs["position_ids"].shape[-1], 1)
    self.assertEqual(model_inputs["attention_mask"].shape[-1], 3)  # we still need the full attention mask!

    # 6. If we pass a `static_cache`, the attention mask will be prepared as a static shape 4D mask
    max_cache_len = 10
    batch_size = 2
    query_length = input_ids.shape[-1] - init_input_ids.shape[-1]
    static_cache = StaticCache(
        config=config, batch_size=batch_size, max_cache_len=max_cache_len, device=torch_device, dtype=torch.float32
    )
    static_cache = model(init_input_ids, past_key_values=static_cache).past_key_values
    model_inputs = model.prepare_inputs_for_generation(
        input_ids, past_key_values=static_cache, cache_position=cache_position, attention_mask=attention_mask
    )
    self.assertIn("past_key_values", model_inputs)
    self.assertListEqual(
        list(model_inputs["attention_mask"].shape), [batch_size, 1, query_length, max_cache_len]
    )

    # 7. We can also pass `inputs_embeds` as the embedded prompt. Because `generate` will append its result to
    # `input_ids` and the models will only accept one of the two inputs (`input_ids` or `inputs_embeds`), we
    # a) must use the cache b) must expect `input_ids` after the prompt is processed
    init_inputs_embeds = model.get_input_embeddings()(init_input_ids)
    init_cache_positions = torch.arange(init_input_ids.shape[-1], dtype=torch.long).to(torch_device)
    empty_cache = DynamicCache()

    # Prompt processing
    model_inputs = model.prepare_inputs_for_generation(
        init_input_ids,
        past_key_values=empty_cache,
        inputs_embeds=init_inputs_embeds,
        cache_position=init_cache_positions,
    )
    self.assertIsNone(model_inputs["input_ids"])
    self.assertIsNotNone(model_inputs["inputs_embeds"])

    # After prompt processing
    model_inputs = model.prepare_inputs_for_generation(
        input_ids, past_key_values=dynamic_cache, inputs_embeds=init_inputs_embeds, cache_position=cache_position
    )
    self.assertIsNotNone(model_inputs["input_ids"])
    self.assertIsNone(model_inputs["inputs_embeds"])
def test_generate_compile_fullgraph_tiny(self):
"""
Tests that we can call end-to-end generation with a tiny model (i.e. doesn't crash)