Automatic compilation in generate: do not rely on inner function (#34923)

* compiled forward in PreTrainedModel

* update

* style

* update name

* trigger CIs

* Add way to use custom compile args

* style

* switch parameterization to generation_config

* Add to inits

* Update configuration_utils.py

* inits

* style

* docs

* style

* Update configuration_utils.py

* back without dataclass for repo consistency

* Update configuration_utils.py

* style

* style

* style once again

* add config serialization

* update

* true dataclass

* trigger CIs

* merge compile methods + remove serialization of compile config
This commit is contained in:
Cyril Vallez
2024-12-03 11:20:31 +01:00
committed by GitHub
parent f9c7e6021e
commit ee37bf0d95
6 changed files with 99 additions and 12 deletions

View File

@@ -436,3 +436,9 @@ A [`Constraint`] can be used to force the generation to include specific tokens
[[autodoc]] SynthIDTextWatermarkDetector [[autodoc]] SynthIDTextWatermarkDetector
- __call__ - __call__
## Compile Utils
[[autodoc]] CompileConfig
- __call__

View File

@@ -122,6 +122,7 @@ _import_structure = {
"feature_extraction_utils": ["BatchFeature", "FeatureExtractionMixin"], "feature_extraction_utils": ["BatchFeature", "FeatureExtractionMixin"],
"file_utils": [], "file_utils": [],
"generation": [ "generation": [
"CompileConfig",
"GenerationConfig", "GenerationConfig",
"TextIteratorStreamer", "TextIteratorStreamer",
"TextStreamer", "TextStreamer",
@@ -4981,7 +4982,7 @@ if TYPE_CHECKING:
from .feature_extraction_utils import BatchFeature, FeatureExtractionMixin from .feature_extraction_utils import BatchFeature, FeatureExtractionMixin
# Generation # Generation
from .generation import GenerationConfig, TextIteratorStreamer, TextStreamer, WatermarkingConfig from .generation import CompileConfig, GenerationConfig, TextIteratorStreamer, TextStreamer, WatermarkingConfig
from .hf_argparser import HfArgumentParser from .hf_argparser import HfArgumentParser
# Integrations # Integrations

View File

@@ -20,6 +20,7 @@ from ..utils import OptionalDependencyNotAvailable, _LazyModule, is_flax_availab
_import_structure = { _import_structure = {
"configuration_utils": [ "configuration_utils": [
"BaseWatermarkingConfig", "BaseWatermarkingConfig",
"CompileConfig",
"GenerationConfig", "GenerationConfig",
"GenerationMode", "GenerationMode",
"SynthIDTextWatermarkingConfig", "SynthIDTextWatermarkingConfig",
@@ -192,6 +193,7 @@ else:
if TYPE_CHECKING: if TYPE_CHECKING:
from .configuration_utils import ( from .configuration_utils import (
BaseWatermarkingConfig, BaseWatermarkingConfig,
CompileConfig,
GenerationConfig, GenerationConfig,
GenerationMode, GenerationMode,
SynthIDTextWatermarkingConfig, SynthIDTextWatermarkingConfig,

View File

@@ -20,7 +20,7 @@ import os
import warnings import warnings
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from dataclasses import dataclass, is_dataclass from dataclasses import dataclass, is_dataclass
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Union
from .. import __version__ from .. import __version__
from ..configuration_utils import PretrainedConfig from ..configuration_utils import PretrainedConfig
@@ -371,6 +371,12 @@ class GenerationConfig(PushToHubMixin):
to correctly align tokens. Can only be used with different tokenizers in speculative decoding. to correctly align tokens. Can only be used with different tokenizers in speculative decoding.
See this [blog](https://huggingface.co/blog/universal_assisted_generation) for more details. See this [blog](https://huggingface.co/blog/universal_assisted_generation) for more details.
> Parameters related to performance and compilation
compile_config (CompileConfig, *optional*):
If using a static cache, this controls how `generate` will `compile` the forward pass for performance
gains.
> Wild card > Wild card
generation_kwargs: generation_kwargs:
@@ -474,6 +480,9 @@ class GenerationConfig(PushToHubMixin):
self.assistant_lookbehind = kwargs.pop("assistant_lookbehind", 10) self.assistant_lookbehind = kwargs.pop("assistant_lookbehind", 10)
self.target_lookbehind = kwargs.pop("target_lookbehind", 10) self.target_lookbehind = kwargs.pop("target_lookbehind", 10)
# Performances
self.compile_config = kwargs.pop("compile_config", CompileConfig())
# Wild card # Wild card
self.generation_kwargs = kwargs.pop("generation_kwargs", {}) self.generation_kwargs = kwargs.pop("generation_kwargs", {})
@@ -794,7 +803,13 @@ class GenerationConfig(PushToHubMixin):
self.watermarking_config = WatermarkingConfig.from_dict(self.watermarking_config) self.watermarking_config = WatermarkingConfig.from_dict(self.watermarking_config)
self.watermarking_config.validate() self.watermarking_config.validate()
# 7. other incorrect combinations # 7. performances arguments
if not isinstance(self.compile_config, CompileConfig):
raise ValueError(
f"You provided `compile_config` as an instance of {type(self.compile_config)}, but it must be an instance of `CompileConfig`."
)
# 8. other incorrect combinations
if self.return_dict_in_generate is not True: if self.return_dict_in_generate is not True:
for extra_output_flag in self.extra_output_flags: for extra_output_flag in self.extra_output_flags:
if getattr(self, extra_output_flag) is True: if getattr(self, extra_output_flag) is True:
@@ -1175,6 +1190,8 @@ class GenerationConfig(PushToHubMixin):
del output["_commit_hash"] del output["_commit_hash"]
if "_original_object_hash" in output: if "_original_object_hash" in output:
del output["_original_object_hash"] del output["_original_object_hash"]
if "compile_config" in output:
del output["compile_config"]
# Transformers version when serializing this file # Transformers version when serializing this file
output["transformers_version"] = __version__ output["transformers_version"] = __version__
@@ -1559,3 +1576,51 @@ class SynthIDTextWatermarkingConfig(BaseWatermarkingConfig):
skip_first_ngram_calls=self.skip_first_ngram_calls, skip_first_ngram_calls=self.skip_first_ngram_calls,
debug_mode=self.debug_mode, debug_mode=self.debug_mode,
) )
@dataclass
class CompileConfig:
    """
    Class that holds arguments relative to `torch.compile` behavior, when using automatic compilation in `generate`.
    See [`torch.compile`](https://pytorch.org/docs/stable/generated/torch.compile.html) for more details on the arguments.

    Args:
        fullgraph (`bool`, *optional*, defaults to `True`):
            If `True`, requires that the whole forward be capturable in a single graph.
        dynamic (`bool` or `None`, *optional*):
            Whether to try to use dynamic shape graphs.
        backend (`str` or `Callable`, *optional*, defaults to `"inductor"`):
            Backend to be used.
        mode (`str`, *optional*, defaults to `"reduce-overhead"`):
            Controls balance between performance and overhead.
        options (`dict`, *optional*):
            A dictionary of options to pass to the backend.

    Examples:
    ```python
    >>> from transformers import AutoModelForCausalLM, AutoTokenizer, CompileConfig

    >>> tokenizer = AutoTokenizer.from_pretrained('google/gemma-2-2b')
    >>> model = AutoModelForCausalLM.from_pretrained('google/gemma-2-2b').cuda()

    >>> # Automatic compile configuration, used with static cache
    >>> compile_config = CompileConfig(dynamic=True)

    >>> # Generation with static cache and compile config
    >>> input_ids = tokenizer.encode("Hello there, how", return_tensors="pt").cuda()
    >>> output = model.generate(
    ...     input_ids, do_sample=False, max_new_tokens=300, cache_implementation="static", compile_config=compile_config
    ... )
    >>> output_text = tokenizer.batch_decode(output, skip_special_tokens=True)[0]
    ```
    """

    fullgraph: bool = True
    dynamic: Optional[bool] = None
    backend: Union[str, Callable] = "inductor"
    mode: str = "reduce-overhead"
    # NOTE(review): `torch.compile` appears to reject `mode` and `options` being set together; since `mode`
    # has a non-None default here, passing `options` may also require `mode=None` -- confirm against torch docs.
    options: Optional[dict] = None

    def to_dict(self) -> Dict[str, Any]:
        """Serializes this instance to a Python dictionary.

        Returns:
            `Dict[str, Any]`: A deep copy of the instance attributes, keyed by field name. Because it is a copy,
            mutating the returned dictionary (or nested values such as `options`) does not alter this instance.
        """
        return copy.deepcopy(self.__dict__)

View File

@@ -3230,16 +3230,14 @@ class GenerationMixin:
unfinished_sequences = torch.ones(batch_size, dtype=torch.long, device=input_ids.device) unfinished_sequences = torch.ones(batch_size, dtype=torch.long, device=input_ids.device)
model_kwargs = self._get_initial_cache_position(input_ids, model_kwargs) model_kwargs = self._get_initial_cache_position(input_ids, model_kwargs)
def model_forward(model, *args, **kwargs): model_forward = self.__call__
return model.forward(*args, **kwargs)
if isinstance(model_kwargs.get("past_key_values"), StaticCache): if isinstance(model_kwargs.get("past_key_values"), StaticCache):
if self.device.type == "cuda": if self.device.type == "cuda":
logger.warning_once("Using `torch.compile`.") logger.warning_once("Using `torch.compile`.")
os.environ["TOKENIZERS_PARALLELISM"] = "0" os.environ["TOKENIZERS_PARALLELISM"] = "0"
model_forward = torch.compile(model_forward, mode="reduce-overhead", fullgraph=True) model_forward = self.get_compiled_call(generation_config.compile_config)
i = 0 is_prefill = True
while self._has_unfinished_sequences( while self._has_unfinished_sequences(
this_peer_finished, synced_gpus, device=input_ids.device, cur_len=cur_len, max_length=max_length this_peer_finished, synced_gpus, device=input_ids.device, cur_len=cur_len, max_length=max_length
): ):
@@ -3250,11 +3248,11 @@ class GenerationMixin:
model_inputs.update({"output_attentions": output_attentions} if output_attentions else {}) model_inputs.update({"output_attentions": output_attentions} if output_attentions else {})
model_inputs.update({"output_hidden_states": output_hidden_states} if output_hidden_states else {}) model_inputs.update({"output_hidden_states": output_hidden_states} if output_hidden_states else {})
if i == 0: if is_prefill:
outputs = self(**model_inputs, return_dict=True) outputs = self(**model_inputs, return_dict=True)
i += 1 is_prefill = False
else: else:
outputs = model_forward(self, return_dict=True, **model_inputs) outputs = model_forward(**model_inputs, return_dict=True)
# synced_gpus: don't waste resources running the code we don't need; kwargs must be updated before skipping # synced_gpus: don't waste resources running the code we don't need; kwargs must be updated before skipping
model_kwargs = self._update_model_kwargs_for_generation( model_kwargs = self._update_model_kwargs_for_generation(

View File

@@ -43,7 +43,7 @@ from torch.utils.checkpoint import checkpoint
from .activations import get_activation from .activations import get_activation
from .configuration_utils import PretrainedConfig from .configuration_utils import PretrainedConfig
from .dynamic_module_utils import custom_object_save from .dynamic_module_utils import custom_object_save
from .generation import GenerationConfig, GenerationMixin from .generation import CompileConfig, GenerationConfig, GenerationMixin
from .integrations import PeftAdapterMixin, deepspeed_config, is_deepspeed_zero3_enabled from .integrations import PeftAdapterMixin, deepspeed_config, is_deepspeed_zero3_enabled
from .loss.loss_utils import LOSS_MAPPING from .loss.loss_utils import LOSS_MAPPING
from .pytorch_utils import ( # noqa: F401 from .pytorch_utils import ( # noqa: F401
@@ -5094,6 +5094,21 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
loss_type = "ForCausalLM" loss_type = "ForCausalLM"
return LOSS_MAPPING[loss_type] return LOSS_MAPPING[loss_type]
def get_compiled_call(self, compile_config: CompileConfig):
    """Return a `torch.compile`'d version of `self.__call__`, cached on the instance.

    This is useful to dynamically choose between non-compiled/compiled `forward` during inference, especially to
    switch between prefill (where we don't want to use compiled version to avoid recomputing the graph with new
    shapes) and iterative decoding (where we want the speed-ups of compiled version with static shapes).

    Args:
        compile_config (`CompileConfig`):
            Arguments forwarded to `torch.compile` through `compile_config.to_dict()`.

    Returns:
        The compiled callable. The same object is returned on subsequent calls as long as `compile_config`
        compares equal to the configuration used for the previous compilation.
    """
    import copy  # local import: keeps this method self-contained regardless of module-level imports

    # `_compiled_call` and `_last_compile_config` are always written together below, so a missing
    # `_last_compile_config` simply means nothing has been compiled yet (`None` never equals a config).
    if not hasattr(self, "_compiled_call") or getattr(self, "_last_compile_config", None) != compile_config:
        # Keep a snapshot rather than a reference: if the caller later mutates the very same config object
        # in place, comparing it against itself would wrongly report "unchanged" and reuse a stale compile.
        self._last_compile_config = copy.deepcopy(compile_config)
        self._compiled_call = torch.compile(self.__call__, **compile_config.to_dict())
    return self._compiled_call
PreTrainedModel.push_to_hub = copy_func(PreTrainedModel.push_to_hub) PreTrainedModel.push_to_hub = copy_func(PreTrainedModel.push_to_hub)
if PreTrainedModel.push_to_hub.__doc__ is not None: if PreTrainedModel.push_to_hub.__doc__ is not None: