Cache: revert DynamicCache init for BC (#33861)
* tmp commit
* tmp commit
* make fixup
* missing removal
* fix condition
* fix end-to-end compilation
* if -> elif
* BC
* BC
* use @deprecate_kwarg("num_hidden_layers", version="4.47.0")
* wups the import
* 🥴
---------
Co-authored-by: Arthur Zucker <arthur.zucker@gmail.com>
This commit is contained in:
@@ -1776,13 +1776,13 @@ class GenerationTesterMixin:
|
||||
set_seed(seed)
|
||||
legacy_results = model.generate(**generation_kwargs, **inputs_dict)
|
||||
set_seed(seed)
|
||||
num_hidden_layers = config.get_text_config().num_hidden_layers
|
||||
if config.is_encoder_decoder:
|
||||
cache_cls = EncoderDecoderCache
|
||||
past_key_values = cache_cls(DynamicCache(num_hidden_layers), DynamicCache(num_hidden_layers))
|
||||
past_key_values = cache_cls(DynamicCache(), DynamicCache())
|
||||
else:
|
||||
cache_cls = DynamicCache
|
||||
past_key_values = cache_cls()
|
||||
|
||||
new_results = model.generate(past_key_values=past_key_values, **generation_kwargs, **inputs_dict)
|
||||
|
||||
# The two sets of generated sequences must match, despite the cache format between forward passes being
|
||||
@@ -3725,6 +3725,29 @@ class GenerationIntegrationTests(unittest.TestCase, GenerationIntegrationTestsMi
|
||||
self.assertEqual(generated_text_no_padding, generated_text_with_padding)
|
||||
self.assertEqual(generated_text_no_padding, "Ich muss diese Aufgabe vor Ende des Tages beenden.")
|
||||
|
||||
def test_generate_compile_fullgraph_tiny(self):
|
||||
"""
|
||||
Tests that we can call end-to-end generation with a tiny model (i.e. doesn't crash)
|
||||
NOTE: this test is quite slow (~20s on a consumer desktop), but it is important that we keep it as part of the
|
||||
non-slow tests to prevent regressions!
|
||||
"""
|
||||
model = AutoModelForCausalLM.from_pretrained(
|
||||
"hf-internal-testing/tiny-random-LlamaForCausalLM", torch_dtype=torch.bfloat16, device_map="auto"
|
||||
)
|
||||
tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-LlamaForCausalLM")
|
||||
|
||||
# compile generate
|
||||
compiled_generate = torch.compile(model.generate, fullgraph=True, mode="reduce-overhead")
|
||||
|
||||
# compiled generate does NOT accept parameterization except a) model inputs b) a generation config
|
||||
generation_config = copy.deepcopy(model.generation_config)
|
||||
generation_config.pad_token_id = model.config.eos_token_id
|
||||
|
||||
model_inputs = tokenizer(["Write a poem about the market crashing in summer"], return_tensors="pt")
|
||||
model_inputs = model_inputs.to(model.device)
|
||||
gen_out = compiled_generate(**model_inputs, generation_config=generation_config)
|
||||
self.assertTrue(gen_out.shape[1] > model_inputs["input_ids"].shape[1]) # some text was generated
|
||||
|
||||
|
||||
@require_torch
|
||||
class TokenHealingTestCase(unittest.TestCase):
|
||||
|
||||
Reference in New Issue
Block a user