Make Gemma work with torch.compile (#30775)
* fix * [run-slow] gemma * add test * add `test_compile_static_cache` * fix * style * remove subprocess * use attribute * fix * style * update * [run-slow] dbrx,gemma,jetmoe,phi3,recurrent_gemma --------- Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
This commit is contained in:
@@ -27,6 +27,7 @@ from collections import defaultdict
|
||||
from typing import Dict, List, Tuple
|
||||
|
||||
import numpy as np
|
||||
from packaging import version
|
||||
from parameterized import parameterized
|
||||
from pytest import mark
|
||||
|
||||
@@ -35,6 +36,7 @@ from transformers import (
|
||||
AutoModel,
|
||||
AutoModelForCausalLM,
|
||||
AutoModelForSequenceClassification,
|
||||
AutoTokenizer,
|
||||
PretrainedConfig,
|
||||
PreTrainedModel,
|
||||
is_torch_available,
|
||||
@@ -71,6 +73,7 @@ from transformers.testing_utils import (
|
||||
require_accelerate,
|
||||
require_bitsandbytes,
|
||||
require_flash_attn,
|
||||
require_read_token,
|
||||
require_safetensors,
|
||||
require_torch,
|
||||
require_torch_gpu,
|
||||
@@ -4399,6 +4402,38 @@ class ModelTesterMixin:
|
||||
normalized_1 = F.softmax(out_shared_prefix_last_tokens)
|
||||
torch.testing.assert_close(normalized_0, normalized_1, rtol=1e-3, atol=1e-4)
|
||||
|
||||
# For now, Let's focus only on GPU for `torch.compile`
|
||||
@slow
|
||||
@require_torch_gpu
|
||||
@require_read_token
|
||||
def test_torch_compile(self):
|
||||
if version.parse(torch.__version__) < version.parse("2.3"):
|
||||
self.skipTest("This test requires torch >= 2.3 to run.")
|
||||
|
||||
if not hasattr(self, "_torch_compile_test_ckpt"):
|
||||
self.skipTest(f"{self.__class__.__name__} doesn't have the attribute `_torch_compile_test_ckpt`.")
|
||||
ckpt = self._torch_compile_test_ckpt
|
||||
|
||||
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
||||
|
||||
batch_size = 1
|
||||
n_iter = 3
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained(ckpt)
|
||||
model = AutoModelForCausalLM.from_pretrained(ckpt, torch_dtype=torch.float16).to(torch_device)
|
||||
|
||||
model.generation_config.max_new_tokens = 4
|
||||
model.generation_config.max_new_tokens = 4
|
||||
|
||||
model.generation_config.cache_implementation = "static"
|
||||
model.forward = torch.compile(model.forward, mode="reduce-overhead", fullgraph=True)
|
||||
|
||||
input_text = "Why dogs are cute?"
|
||||
input_ids = tokenizer([input_text] * batch_size, return_tensors="pt").to(torch_device)
|
||||
|
||||
for i in range(n_iter):
|
||||
_ = model.generate(**input_ids, do_sample=False)
|
||||
|
||||
|
||||
global_rng = random.Random()
|
||||
|
||||
|
||||
Reference in New Issue
Block a user