[Cache] Don't initialize the cache on meta device (#36543)
This commit is contained in:
@@ -20,6 +20,7 @@ from parameterized import parameterized
|
||||
|
||||
from transformers import set_seed
|
||||
from transformers.testing_utils import (
|
||||
CaptureStderr,
|
||||
get_gpu_count,
|
||||
is_torch_available,
|
||||
require_gptq,
|
||||
@@ -654,3 +655,42 @@ class CacheIntegrationTest(unittest.TestCase):
|
||||
torch.testing.assert_close(
|
||||
actual=parallelism_cache[layer_idx][kv_idx], expected=no_parallelism_cache[layer_idx][kv_idx]
|
||||
)
|
||||
|
||||
@require_torch_gpu
|
||||
def test_static_cache_no_cuda_graph_skips(self):
|
||||
"""
|
||||
Tests generating with static cache and compilation doesn't skip cuda graphs. Regression test for #36543.
|
||||
|
||||
(? We set `fullgraph=True`, which according to torch docs means it should raise an exception. Instead,
|
||||
messages are being thrown to stderr?)
|
||||
"""
|
||||
model_repo = "hf-internal-testing/tiny-random-MistralForCausalLM"
|
||||
model = AutoModelForCausalLM.from_pretrained(model_repo).to(torch_device)
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_repo)
|
||||
inputs = tokenizer(["foo bar"], return_tensors="pt").to(torch_device)
|
||||
|
||||
# on `main`, prior to #36543, this would send stderr messages about cuda graphs being skipped.
|
||||
with CaptureStderr() as cap:
|
||||
model.generate(**inputs, max_new_tokens=2, cache_implementation="static")
|
||||
self.assertEqual(cap.err, "")
|
||||
|
||||
@require_torch_multi_gpu
|
||||
def test_static_cache_multi_gpu(self):
|
||||
"""Regression test for #35164: static cache with multi-gpu"""
|
||||
|
||||
model_id = "google/gemma-2-2b-it"
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_id)
|
||||
|
||||
device_map = {"model.embed_tokens": 0, "model.norm": 1, "model.rotary_emb": 1, "lm_head": 0}
|
||||
num_hidden_layers = 26
|
||||
for i in range(num_hidden_layers):
|
||||
device_map[f"model.layers.{i}"] = 0 if i < 13 else 1
|
||||
|
||||
model = AutoModelForCausalLM.from_pretrained(
|
||||
model_id,
|
||||
torch_dtype="bfloat16",
|
||||
device_map=device_map,
|
||||
)
|
||||
inputs = tokenizer("Today is a beautiful day!", return_tensors="pt").to(0)
|
||||
_ = model(**inputs)
|
||||
_ = model.generate(**inputs, max_new_tokens=2, cache_implementation="hybrid")
|
||||
|
||||
Reference in New Issue
Block a user