Make Gemma work with torch.compile (#30775)
* fix
* [run-slow] gemma
* add test
* add `test_compile_static_cache`
* fix
* style
* remove subprocess
* use attribute
* fix
* style
* update
* [run-slow] dbrx,gemma,jetmoe,phi3,recurrent_gemma

---------

Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
This commit is contained in:
@@ -17,6 +17,7 @@ import tempfile
|
||||
import unittest
|
||||
|
||||
import pytest
|
||||
from packaging import version
|
||||
|
||||
from transformers import AutoModelForCausalLM, AutoTokenizer, GemmaConfig, is_torch_available
|
||||
from transformers.testing_utils import (
|
||||
@@ -40,7 +41,7 @@ from ...test_pipeline_mixin import PipelineTesterMixin
|
||||
if is_torch_available():
|
||||
import torch
|
||||
|
||||
from transformers import GemmaForCausalLM, GemmaForSequenceClassification, GemmaModel
|
||||
from transformers import GemmaForCausalLM, GemmaForSequenceClassification, GemmaModel, GemmaTokenizer
|
||||
|
||||
|
||||
class GemmaModelTester:
|
||||
@@ -302,6 +303,9 @@ class GemmaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixi
|
||||
# This is because we are hitting edge cases with the causal_mask buffer
|
||||
model_split_percents = [0.5, 0.6]
|
||||
|
||||
# used in `test_torch_compile`
|
||||
_torch_compile_test_ckpt = "google/gemma-2b"
|
||||
|
||||
# TODO (ydshieh): Check this. See https://app.circleci.com/pipelines/github/huggingface/transformers/79245/workflows/9490ef58-79c2-410d-8f51-e3495156cf9c/jobs/1012146
|
||||
def is_pipeline_test_to_skip(
|
||||
self, pipeline_test_casse_name, config_class, model_architecture, tokenizer_name, processor_name
|
||||
@@ -801,3 +805,51 @@ class GemmaIntegrationTest(unittest.TestCase):
|
||||
output_text = tokenizer.batch_decode(output, skip_special_tokens=True)
|
||||
|
||||
self.assertEqual(output_text, EXPECTED_TEXTS[self.cuda_compute_capability_major_version])
|
||||
|
||||
@slow
@require_torch_gpu
@require_read_token
def test_compile_static_cache(self):
    """Compare greedy generations of `google/gemma-2b` across three cache setups.

    The same two prompts are generated three times: with the default `generate` call,
    with `cache_implementation="static"`, and with the static cache after `model.forward`
    has been wrapped by `torch.compile`. Each decoded batch is compared against the
    hard-coded reference completions.
    """
    # `torch==2.2` raises on this test (as it does on other compilation tests), whereas
    # torch==2.1.2 and torch>2.2 behave as intended.
    # See https://github.com/pytorch/pytorch/issues/121943
    minimum_torch = version.parse("2.3.0")
    if version.parse(torch.__version__) < minimum_torch:
        self.skipTest("This test requires torch >= 2.3 to run.")

    NUM_TOKENS_TO_GENERATE = 40

    # Note on `EXPECTED_TEXT_COMPLETION`'s diff: the current value matches the original test if the
    # original test was changed to have a cache of 53 tokens (as opposed to 4096), on Ampere GPUs.
    # Keys are the values looked up through `self.cuda_compute_capability_major_version`.
    EXPECTED_TEXT_COMPLETION = {
        8: [
            "Hello I am doing a project on the 1990s and I need to know what the most popular music was in the 1990s. I have looked on the internet and I have found",
            "Hi today\nI have a problem with my 2007 1.9 tdi 105bhp.\nI have a problem with the engine management light on.\nI have checked the",
        ],
        7: [
            "Hello I am doing a project on the 1990s and I need to know what the most popular music was in the 1990s. I have looked on the internet and I have found",
            "Hi today\nI have a problem with my 2007 1.9 tdi 105bhp.\nI have a problem with the engine management light on.\nI have checked the",
        ],
    }

    prompts = ["Hello I am doing", "Hi today"]
    tokenizer = GemmaTokenizer.from_pretrained("google/gemma-2b", pad_token="</s>", padding_side="right")
    model = GemmaForCausalLM.from_pretrained("google/gemma-2b", device_map="sequential", torch_dtype=torch.float16)
    inputs = tokenizer(prompts, return_tensors="pt", padding=True).to(model.device)

    # Greedy decoding settings shared by every `generate` call below.
    greedy_kwargs = {"max_new_tokens": NUM_TOKENS_TO_GENERATE, "do_sample": False}

    # Dynamic Cache
    dynamic_ids = model.generate(**inputs, **greedy_kwargs)
    dynamic_text = tokenizer.batch_decode(dynamic_ids, skip_special_tokens=True)
    # Both GPU architectures have the same output, so the `8` entry is used directly.
    self.assertEqual(EXPECTED_TEXT_COMPLETION[8], dynamic_text)

    # Static Cache
    static_ids = model.generate(**inputs, **greedy_kwargs, cache_implementation="static")
    static_text = tokenizer.batch_decode(static_ids, skip_special_tokens=True)
    self.assertEqual(EXPECTED_TEXT_COMPLETION[self.cuda_compute_capability_major_version], static_text)

    # Static Cache + compile
    model.forward = torch.compile(model.forward, mode="reduce-overhead", fullgraph=True)
    compiled_ids = model.generate(**inputs, **greedy_kwargs, cache_implementation="static")
    static_compiled_text = tokenizer.batch_decode(compiled_ids, skip_special_tokens=True)
    self.assertEqual(EXPECTED_TEXT_COMPLETION[self.cuda_compute_capability_major_version], static_compiled_text)
|
||||
|
||||
Reference in New Issue
Block a user