[Cache] Don't initialize the cache on meta device (#36543)

This commit is contained in:
Joao Gante
2025-03-13 10:13:29 +00:00
committed by GitHub
parent 79254c9b61
commit c4161238bd
9 changed files with 138 additions and 147 deletions

View File

@@ -2304,45 +2304,6 @@ class GenerationTesterMixin:
without_all_logits = model.generate(**inputs_dict, **generation_kwargs)
self.assertEqual(with_all_logits.tolist(), without_all_logits.tolist())
@pytest.mark.generate
@is_flaky
def test_assisted_decoding_with_logits_to_keep(self):
for model_class in self.all_generative_model_classes:
if "logits_to_keep" not in set(inspect.signature(model_class.forward).parameters.keys()):
self.skipTest(reason="This model does not support `logits_to_keep` argument.")
if model_class._is_stateful:
self.skipTest(reason="Stateful models don't support assisted generation")
config, inputs_dict = self.prepare_config_and_inputs_for_generate(batch_size=1)
# NOTE: assisted generation only works with cache on at the moment.
if not hasattr(config.get_text_config(), "use_cache"):
self.skipTest(reason=f"{model_class.__name__} doesn't support caching")
config.use_cache = True
config.is_decoder = True
model = model_class(config).to(torch_device).eval()
assistant_model = model
# All generation methods (except assisted decoding) rely on always extracting the last token logits of the
# full logits matrix, so testing out only greedy search and assisted decoding is enough (if it works,
# other methods will work as well)
generation_kwargs = {
"max_new_tokens": 10,
"do_sample": False,
"assistant_model": assistant_model,
"return_dict_in_generate": True,
"output_scores": True,
}
logits_processor_kwargs = self._get_logits_processor_kwargs(config=model.config)
# Setting logits_to_keep at 0 keeps all logits (old behavior)
with_all_logits = model.generate(
**generation_kwargs, **inputs_dict, **logits_processor_kwargs, logits_to_keep=0
)
# By default, logits_to_keep is automatically set to 1 if not provided (new behavior)
without_all_logits = model.generate(**inputs_dict, **generation_kwargs, **logits_processor_kwargs)
self._check_similar_generate_outputs(with_all_logits, without_all_logits)
@pytest.mark.generate
def test_inherits_generation_mixin(self):
"""