>3-5x faster torch.compile forward compilation for autoregressive decoder models (#32227)
* draft * apply changes to all relevant archs * rerun ci - check_docstrings.py failing? * fix docstring * move 2D->4D mask creation to modeling file * repo consistency * fix the batch size = 1 case - calling contiguous is not enough * nit * style * propagate to gemma/gemma-2 * prepare inputs for gemma generation * implement test and tiny fix in gemma2 * Update src/transformers/models/bloom/modeling_bloom.py Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com> * fix copies * ci pass * fix gemma's test_compile_static_cache tests * flacky * retrigger ci --------- Co-authored-by: sanchit-gandhi <sanchit@huggingface.co> Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
This commit is contained in:
@@ -22,6 +22,7 @@ import os.path
|
||||
import random
|
||||
import re
|
||||
import tempfile
|
||||
import time
|
||||
import warnings
|
||||
from collections import defaultdict
|
||||
from typing import Dict, List, Tuple
|
||||
@@ -37,6 +38,7 @@ from transformers import (
|
||||
AutoModelForCausalLM,
|
||||
AutoModelForSequenceClassification,
|
||||
AutoTokenizer,
|
||||
GenerationConfig,
|
||||
PretrainedConfig,
|
||||
PreTrainedModel,
|
||||
is_torch_available,
|
||||
@@ -4605,7 +4607,6 @@ class ModelTesterMixin:
|
||||
tokenizer = AutoTokenizer.from_pretrained(ckpt)
|
||||
model = AutoModelForCausalLM.from_pretrained(ckpt, torch_dtype=torch.float16).to(torch_device)
|
||||
|
||||
model.generation_config.max_new_tokens = 4
|
||||
model.generation_config.max_new_tokens = 4
|
||||
|
||||
model.generation_config.cache_implementation = "static"
|
||||
@@ -4617,6 +4618,66 @@ class ModelTesterMixin:
|
||||
for i in range(n_iter):
|
||||
_ = model.generate(**input_ids, do_sample=False)
|
||||
|
||||
@slow
|
||||
@require_torch_gpu # Testing cuda graphs.
|
||||
@require_read_token
|
||||
def test_compile_cuda_graph_time(self):
|
||||
if version.parse(torch.__version__) < version.parse("2.3"):
|
||||
self.skipTest(reason="This test requires torch >= 2.3 to run.")
|
||||
|
||||
# TODO felix: All models supporting `StaticCache` or `torch.compile` should be tested.
|
||||
# At the moment, only llama, gemma and gemma2 are tested here!
|
||||
if not hasattr(self, "_torch_compile_test_ckpt"):
|
||||
self.skipTest(f"{self.__class__.__name__} doesn't have the attribute `_torch_compile_test_ckpt`.")
|
||||
ckpt = self._torch_compile_test_ckpt
|
||||
|
||||
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained(ckpt)
|
||||
model = AutoModelForCausalLM.from_pretrained(ckpt, torch_dtype=torch.float16).to(torch_device)
|
||||
|
||||
cache_implementation = "static"
|
||||
if model.config.model_type == "gemma2":
|
||||
cache_implementation = "hybrid"
|
||||
|
||||
new_tokens = 50
|
||||
gen_config = GenerationConfig(
|
||||
max_new_tokens=new_tokens,
|
||||
min_new_tokens=new_tokens,
|
||||
use_cache=True,
|
||||
pad_token_id=tokenizer.pad_token_id,
|
||||
num_beams=1,
|
||||
do_sample=False,
|
||||
eos_token_id=None, # This is required for min_new_tokens to actually have an effect.
|
||||
)
|
||||
model.generation_config.eos_token_id = None # greedy_search falls back on this eos_token_id that we need to set to None as well for min_new_tokens to have an effect.
|
||||
|
||||
model.forward = torch.compile(model.forward, mode="reduce-overhead", fullgraph=True)
|
||||
|
||||
inp = tokenizer("Why cats are cute?", return_tensors="pt").to(torch_device)
|
||||
|
||||
# First run: the first run warms up each graph, which does things like CuBlas or Triton benchmarking
|
||||
start = time.perf_counter()
|
||||
_ = model.generate(**inp, generation_config=gen_config, cache_implementation=cache_implementation)
|
||||
end = time.perf_counter()
|
||||
graph_warmup_time = end - start
|
||||
|
||||
# Second run: CUDA Graph recording, and replays it
|
||||
start = time.perf_counter()
|
||||
_ = model.generate(**inp, generation_config=gen_config, cache_implementation=cache_implementation)
|
||||
end = time.perf_counter()
|
||||
record_time = end - start
|
||||
|
||||
# Finally: we hit the optimized, CUDA Graph replay path
|
||||
start = time.perf_counter()
|
||||
_ = model.generate(**inp, generation_config=gen_config, cache_implementation=cache_implementation)
|
||||
end = time.perf_counter()
|
||||
opt_time = end - start
|
||||
|
||||
# For the recording step, we expect only two cuda graphs and this step should be much faster than the first.
|
||||
self.assertTrue(record_time < 0.15 * graph_warmup_time)
|
||||
self.assertTrue(opt_time < record_time)
|
||||
|
||||
|
||||
global_rng = random.Random()
|
||||
|
||||
|
||||
Reference in New Issue
Block a user