Make Gemma work with torch.compile (#30775)

* fix

* [run-slow] gemma

* add test

* add `test_compile_static_cache`

* fix

* style

* remove subprocess

* use attribute

* fix

* style

* update

* [run-slow] dbrx,gemma,jetmoe,phi3,recurrent_gemma

---------

Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
This commit is contained in:
Yih-Dar
2024-05-16 13:41:33 +02:00
committed by GitHub
parent 0753134f4d
commit 1b3dba9417
8 changed files with 110 additions and 26 deletions

View File

@@ -27,6 +27,7 @@ from collections import defaultdict
from typing import Dict, List, Tuple
import numpy as np
from packaging import version
from parameterized import parameterized
from pytest import mark
@@ -35,6 +36,7 @@ from transformers import (
AutoModel,
AutoModelForCausalLM,
AutoModelForSequenceClassification,
AutoTokenizer,
PretrainedConfig,
PreTrainedModel,
is_torch_available,
@@ -71,6 +73,7 @@ from transformers.testing_utils import (
require_accelerate,
require_bitsandbytes,
require_flash_attn,
require_read_token,
require_safetensors,
require_torch,
require_torch_gpu,
@@ -4399,6 +4402,38 @@ class ModelTesterMixin:
normalized_1 = F.softmax(out_shared_prefix_last_tokens)
torch.testing.assert_close(normalized_0, normalized_1, rtol=1e-3, atol=1e-4)
# For now, Let's focus only on GPU for `torch.compile`
@slow
@require_torch_gpu
@require_read_token
def test_torch_compile(self):
if version.parse(torch.__version__) < version.parse("2.3"):
self.skipTest("This test requires torch >= 2.3 to run.")
if not hasattr(self, "_torch_compile_test_ckpt"):
self.skipTest(f"{self.__class__.__name__} doesn't have the attribute `_torch_compile_test_ckpt`.")
ckpt = self._torch_compile_test_ckpt
os.environ["TOKENIZERS_PARALLELISM"] = "false"
batch_size = 1
n_iter = 3
tokenizer = AutoTokenizer.from_pretrained(ckpt)
model = AutoModelForCausalLM.from_pretrained(ckpt, torch_dtype=torch.float16).to(torch_device)
model.generation_config.max_new_tokens = 4
model.generation_config.max_new_tokens = 4
model.generation_config.cache_implementation = "static"
model.forward = torch.compile(model.forward, mode="reduce-overhead", fullgraph=True)
input_text = "Why dogs are cute?"
input_ids = tokenizer([input_text] * batch_size, return_tensors="pt").to(torch_device)
for i in range(n_iter):
_ = model.generate(**input_ids, do_sample=False)
global_rng = random.Random()