[Cache] Don't initialize the cache on meta device (#36543)
This commit is contained in:
@@ -2304,45 +2304,6 @@ class GenerationTesterMixin:
|
||||
without_all_logits = model.generate(**inputs_dict, **generation_kwargs)
|
||||
self.assertEqual(with_all_logits.tolist(), without_all_logits.tolist())
|
||||
|
||||
@pytest.mark.generate
|
||||
@is_flaky
|
||||
def test_assisted_decoding_with_logits_to_keep(self):
|
||||
for model_class in self.all_generative_model_classes:
|
||||
if "logits_to_keep" not in set(inspect.signature(model_class.forward).parameters.keys()):
|
||||
self.skipTest(reason="This model does not support `logits_to_keep` argument.")
|
||||
if model_class._is_stateful:
|
||||
self.skipTest(reason="Stateful models don't support assisted generation")
|
||||
|
||||
config, inputs_dict = self.prepare_config_and_inputs_for_generate(batch_size=1)
|
||||
# NOTE: assisted generation only works with cache on at the moment.
|
||||
if not hasattr(config.get_text_config(), "use_cache"):
|
||||
self.skipTest(reason=f"{model_class.__name__} doesn't support caching")
|
||||
config.use_cache = True
|
||||
config.is_decoder = True
|
||||
|
||||
model = model_class(config).to(torch_device).eval()
|
||||
assistant_model = model
|
||||
# All generation methods (except assisted decoding) rely on always extracting the last token logits of the
|
||||
# full logits matrix, so testing out only greedy search and assisted decoding is enough (if it works,
|
||||
# other methods will work as well)
|
||||
generation_kwargs = {
|
||||
"max_new_tokens": 10,
|
||||
"do_sample": False,
|
||||
"assistant_model": assistant_model,
|
||||
"return_dict_in_generate": True,
|
||||
"output_scores": True,
|
||||
}
|
||||
logits_processor_kwargs = self._get_logits_processor_kwargs(config=model.config)
|
||||
|
||||
# Setting logits_to_keep at 0 keeps all logits (old behavior)
|
||||
with_all_logits = model.generate(
|
||||
**generation_kwargs, **inputs_dict, **logits_processor_kwargs, logits_to_keep=0
|
||||
)
|
||||
# By default, logits_to_keep is automatically set to 1 if not provided (new behavior)
|
||||
without_all_logits = model.generate(**inputs_dict, **generation_kwargs, **logits_processor_kwargs)
|
||||
|
||||
self._check_similar_generate_outputs(with_all_logits, without_all_logits)
|
||||
|
||||
@pytest.mark.generate
|
||||
def test_inherits_generation_mixin(self):
|
||||
"""
|
||||
|
||||
@@ -20,6 +20,7 @@ from parameterized import parameterized
|
||||
|
||||
from transformers import set_seed
|
||||
from transformers.testing_utils import (
|
||||
CaptureStderr,
|
||||
get_gpu_count,
|
||||
is_torch_available,
|
||||
require_gptq,
|
||||
@@ -654,3 +655,42 @@ class CacheIntegrationTest(unittest.TestCase):
|
||||
torch.testing.assert_close(
|
||||
actual=parallelism_cache[layer_idx][kv_idx], expected=no_parallelism_cache[layer_idx][kv_idx]
|
||||
)
|
||||
|
||||
@require_torch_gpu
|
||||
def test_static_cache_no_cuda_graph_skips(self):
|
||||
"""
|
||||
Tests generating with static cache and compilation doesn't skip cuda graphs. Regression test for #36543.
|
||||
|
||||
(? We set `fullgraph=True`, which according to torch docs means it should raise an exception. Instead,
|
||||
messages are being thrown to stderr?)
|
||||
"""
|
||||
model_repo = "hf-internal-testing/tiny-random-MistralForCausalLM"
|
||||
model = AutoModelForCausalLM.from_pretrained(model_repo).to(torch_device)
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_repo)
|
||||
inputs = tokenizer(["foo bar"], return_tensors="pt").to(torch_device)
|
||||
|
||||
# on `main`, prior to #36543, this would send stderr messages about cuda graphs being skipped.
|
||||
with CaptureStderr() as cap:
|
||||
model.generate(**inputs, max_new_tokens=2, cache_implementation="static")
|
||||
self.assertEqual(cap.err, "")
|
||||
|
||||
@require_torch_multi_gpu
|
||||
def test_static_cache_multi_gpu(self):
|
||||
"""Regression test for #35164: static cache with multi-gpu"""
|
||||
|
||||
model_id = "google/gemma-2-2b-it"
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_id)
|
||||
|
||||
device_map = {"model.embed_tokens": 0, "model.norm": 1, "model.rotary_emb": 1, "lm_head": 0}
|
||||
num_hidden_layers = 26
|
||||
for i in range(num_hidden_layers):
|
||||
device_map[f"model.layers.{i}"] = 0 if i < 13 else 1
|
||||
|
||||
model = AutoModelForCausalLM.from_pretrained(
|
||||
model_id,
|
||||
torch_dtype="bfloat16",
|
||||
device_map=device_map,
|
||||
)
|
||||
inputs = tokenizer("Today is a beautiful day!", return_tensors="pt").to(0)
|
||||
_ = model(**inputs)
|
||||
_ = model.generate(**inputs, max_new_tokens=2, cache_implementation="hybrid")
|
||||
|
||||
Reference in New Issue
Block a user