Cache: revert DynamicCache init for BC (#33861)

* tmp commit

* tmp commit

* make fixup

* missing removal

* fix condition

* fix end-to-end compilation

* if -> elif

* BC

* BC

* use @deprecate_kwarg("num_hidden_layers", version="4.47.0")

* wups the import

* 🥴

---------

Co-authored-by: Arthur Zucker <arthur.zucker@gmail.com>
This commit is contained in:
Joao Gante
2024-10-04 21:47:08 +01:00
committed by GitHub
parent f92d354823
commit 38f9f10dd9
5 changed files with 113 additions and 56 deletions

View File

@@ -1776,13 +1776,13 @@ class GenerationTesterMixin:
set_seed(seed)
legacy_results = model.generate(**generation_kwargs, **inputs_dict)
set_seed(seed)
num_hidden_layers = config.get_text_config().num_hidden_layers
if config.is_encoder_decoder:
cache_cls = EncoderDecoderCache
past_key_values = cache_cls(DynamicCache(num_hidden_layers), DynamicCache(num_hidden_layers))
past_key_values = cache_cls(DynamicCache(), DynamicCache())
else:
cache_cls = DynamicCache
past_key_values = cache_cls()
new_results = model.generate(past_key_values=past_key_values, **generation_kwargs, **inputs_dict)
# The two sets of generated sequences must match, despite the cache format between forward passes being
@@ -3725,6 +3725,29 @@ class GenerationIntegrationTests(unittest.TestCase, GenerationIntegrationTestsMi
self.assertEqual(generated_text_no_padding, generated_text_with_padding)
self.assertEqual(generated_text_no_padding, "Ich muss diese Aufgabe vor Ende des Tages beenden.")
def test_generate_compile_fullgraph_tiny(self):
"""
Tests that we can call end-to-end generation with a tiny model (i.e. doesn't crash)
NOTE: this test is quite slow (~20s on a consumer desktop), but it is important that we keep it as part of the
non-slow tests to prevent regressions!
"""
model = AutoModelForCausalLM.from_pretrained(
"hf-internal-testing/tiny-random-LlamaForCausalLM", torch_dtype=torch.bfloat16, device_map="auto"
)
tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-LlamaForCausalLM")
# compile generate
compiled_generate = torch.compile(model.generate, fullgraph=True, mode="reduce-overhead")
# compiled generate does NOT accept parameterization except a) model inputs b) a generation config
generation_config = copy.deepcopy(model.generation_config)
generation_config.pad_token_id = model.config.eos_token_id
model_inputs = tokenizer(["Write a poem about the market crashing in summer"], return_tensors="pt")
model_inputs = model_inputs.to(model.device)
gen_out = compiled_generate(**model_inputs, generation_config=generation_config)
self.assertTrue(gen_out.shape[1] > model_inputs["input_ids"].shape[1]) # some text was generated
@require_torch
class TokenHealingTestCase(unittest.TestCase):