Cache: new Cache format in decoder-only models (#31421)
* draft bart with new cache * add cache for decoder-only models * revert utils * modify docstring * revert bart * minor fixes * fix copies (not related) * revert tests * remove enc-dec related code * remove bloom * remove opt (enc-dec) * update docstring * git, codegen, gpt_neo, gpt_neox, gptj * clean up * copied from statements * revert * tmp * update warning msg * forgot git * add more flags * run-slow git,codegen,gpt_neo,gpt_neox,gptj * add cache flag to VLMs * remove files * style * video LLMs also need a flag * style * llava will go in another PR * style * [run-slow] codegen, falcon, git, gpt_neo, gpt_neox, gptj, idefics * Update src/transformers/models/gpt_neo/modeling_gpt_neo.py Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com> * copy from * deprecate until v4.45 and warn if not training * nit * fix test * test static cache * add more tests and fix models * fix copies * return sliding window mask * run slow tests & fix + codestyle * one more falcon fix for alibi --------- Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
This commit is contained in:
committed by
GitHub
parent
6af0854efa
commit
a30c865f99
@@ -4587,6 +4587,44 @@ class ModelTesterMixin:
|
||||
normalized_1 = F.softmax(out_shared_prefix_last_tokens)
|
||||
torch.testing.assert_close(normalized_0, normalized_1, rtol=1e-3, atol=1e-4)
|
||||
|
||||
def test_static_cache_matches_dynamic(self):
    """
    Tests that generating with a static cache gives almost the same results as generating with a
    dynamic cache. This test does not compile the model; it only checks that the logits are similar,
    in order to catch numerical precision errors.
    """
    if len(self.all_generative_model_classes) == 0:
        self.skipTest(
            reason="Model architecture has no generative classes, and thus not necessarily supporting 4D masks"
        )

    for model_class in self.all_generative_model_classes:
        if not model_class._supports_static_cache:
            self.skipTest(f"{model_class.__name__} does not support static cache")

        if not model_class._supports_cache_class:
            self.skipTest(f"{model_class.__name__} does not support cache class")

        config, inputs = self.model_tester.prepare_config_and_inputs_for_common()
        # The attribute may exist but be explicitly set to None; the `getattr` default only
        # covers a *missing* attribute, and `None > 0` raises TypeError, so guard for None first.
        sliding_window = getattr(config, "sliding_window", None)
        if sliding_window is not None and sliding_window > 0:
            self.skipTest(f"{model_class.__name__} with sliding window attention is not supported by this test")

        model = model_class(config).to(device=torch_device, dtype=torch.float32)
        model.eval()

        # Greedy decoding with the default cache (labelled "dynamic" by this test).
        dynamic_out = model.generate(
            **inputs, do_sample=False, max_new_tokens=10, output_logits=True, return_dict_in_generate=True
        )
        # Same greedy decoding, but with the static cache implementation requested.
        static_out = model.generate(
            **inputs,
            do_sample=False,
            max_new_tokens=10,
            cache_implementation="static",
            output_logits=True,
            return_dict_in_generate=True,
        )
        # Only the logits of the first generated step are compared.
        self.assertTrue(torch.allclose(dynamic_out.logits[0], static_out.logits[0], rtol=1e-3, atol=1e-4))
|
||||
|
||||
# For now, let's focus only on GPU for `torch.compile`
|
||||
@slow
|
||||
@require_torch_gpu
|
||||
|
||||
Reference in New Issue
Block a user