Make StaticCache configurable at model construct time (#32830)
* Make StaticCache configurable at model construct time * integrations import structure * add new doc file to toc --------- Co-authored-by: Guang Yang <guangyang@fb.com> Co-authored-by: Joao Gante <joao@huggingface.co>
This commit is contained in:
@@ -296,6 +296,8 @@
|
||||
title: Trainer
|
||||
- local: main_classes/deepspeed
|
||||
title: DeepSpeed
|
||||
- local: main_classes/executorch
|
||||
title: ExecuTorch
|
||||
- local: main_classes/feature_extractor
|
||||
title: Feature Extractor
|
||||
- local: main_classes/image_processor
|
||||
|
||||
33
docs/source/en/main_classes/executorch.md
Normal file
33
docs/source/en/main_classes/executorch.md
Normal file
@@ -0,0 +1,33 @@
|
||||
<!--Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
All rights reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
|
||||
the License. You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
|
||||
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
|
||||
specific language governing permissions and limitations under the License.
|
||||
|
||||
⚠️ Note that this file is in Markdown but contain specific syntax for our doc-builder (similar to MDX) that may not be
|
||||
rendered properly in your Markdown viewer.
|
||||
|
||||
-->
|
||||
|
||||
|
||||
# ExecuTorch
|
||||
|
||||
[`ExecuTorch`](https://github.com/pytorch/executorch) is an end-to-end solution for enabling on-device inference capabilities across mobile and edge devices including wearables, embedded devices and microcontrollers. It is part of the PyTorch ecosystem and supports the deployment of PyTorch models with a focus on portability, productivity, and performance.
|
||||
|
||||
ExecuTorch introduces well defined entry points to perform model, device, and/or use-case specific optimizations such as backend delegation, user-defined compiler transformations, memory planning, and more. The first step in preparing a PyTorch model for execution on an edge device using ExecuTorch is to export the model. This is achieved through the use of a PyTorch API called [`torch.export`](https://pytorch.org/docs/stable/export.html).
|
||||
|
||||
|
||||
## ExecuTorch Integration
|
||||
|
||||
An integration point is being developed to ensure that 🤗 Transformers can be exported using `torch.export`. The goal of this integration is not only to enable export but also to ensure that the exported artifact can be further lowered and optimized to run efficiently in `ExecuTorch`, particularly for mobile and edge use cases.
|
||||
|
||||
[[autodoc]] integrations.executorch.TorchExportableModuleWithStaticCache
|
||||
- forward
|
||||
|
||||
[[autodoc]] integrations.executorch.convert_and_export_with_cache
|
||||
@@ -1323,6 +1323,13 @@ else:
|
||||
"WhisperTimeStampLogitsProcessor",
|
||||
]
|
||||
)
|
||||
|
||||
# PyTorch domain libraries integration
|
||||
_import_structure["integrations.executorch"] = [
|
||||
"TorchExportableModuleWithStaticCache",
|
||||
"convert_and_export_with_cache",
|
||||
]
|
||||
|
||||
_import_structure["modeling_flash_attention_utils"] = []
|
||||
_import_structure["modeling_outputs"] = []
|
||||
_import_structure["modeling_rope_utils"] = ["ROPE_INIT_FUNCTIONS"]
|
||||
@@ -6121,6 +6128,10 @@ if TYPE_CHECKING:
|
||||
WatermarkLogitsProcessor,
|
||||
WhisperTimeStampLogitsProcessor,
|
||||
)
|
||||
from .integrations.executorch import (
|
||||
TorchExportableModuleWithStaticCache,
|
||||
convert_and_export_with_cache,
|
||||
)
|
||||
from .modeling_rope_utils import ROPE_INIT_FUNCTIONS
|
||||
from .modeling_utils import PreTrainedModel
|
||||
from .models.albert import (
|
||||
|
||||
@@ -293,6 +293,46 @@ class QuantizedCacheConfig(CacheConfig):
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class StaticCacheConfig(CacheConfig):
|
||||
"""
|
||||
Configuration class for static cache settings.
|
||||
"""
|
||||
|
||||
cache_implementation = "static"
|
||||
|
||||
def __init__(self, batch_size: int, max_cache_len: int, device="cpu"):
|
||||
self.batch_size = batch_size
|
||||
self.max_cache_len = max_cache_len
|
||||
self.device = device
|
||||
|
||||
def validate(self):
|
||||
"""Validates if the arguments passed are correct"""
|
||||
|
||||
incorrect_arg_msg = (
|
||||
"Some of the keys in `cache_config` are defined incorrectly. `{key}` should be {correct_value}` "
|
||||
"but found {found_value}"
|
||||
)
|
||||
|
||||
if self.batch_size <= 0:
|
||||
raise ValueError(
|
||||
incorrect_arg_msg.format(
|
||||
key="batch_size",
|
||||
correct_value="> 0",
|
||||
found_value=self.batch_size,
|
||||
),
|
||||
)
|
||||
|
||||
if self.max_cache_len <= 0:
|
||||
raise ValueError(
|
||||
incorrect_arg_msg.format(
|
||||
key="max_cache_len",
|
||||
correct_value="> 0",
|
||||
found_value=self.max_cache_len,
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
class DynamicCache(Cache):
|
||||
"""
|
||||
A cache that grows dynamically as more tokens are generated. This is the default for generative models.
|
||||
|
||||
@@ -57,9 +57,11 @@ if is_torch_available():
|
||||
QuantoQuantizedCache,
|
||||
SlidingWindowCache,
|
||||
StaticCache,
|
||||
StaticCacheConfig,
|
||||
)
|
||||
|
||||
NEEDS_CACHE_CONFIG["quantized"] = QuantizedCacheConfig
|
||||
NEEDS_CACHE_CONFIG["static"] = StaticCacheConfig
|
||||
NEED_SETUP_CACHE_CLASSES_MAPPING = {
|
||||
"static": StaticCache,
|
||||
"offloaded_static": OffloadedStaticCache,
|
||||
|
||||
@@ -13,7 +13,7 @@
|
||||
# limitations under the License.
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from ..utils import _LazyModule
|
||||
from ..utils import OptionalDependencyNotAvailable, _LazyModule, is_torch_available
|
||||
|
||||
|
||||
_import_structure = {
|
||||
@@ -98,6 +98,17 @@ _import_structure = {
|
||||
"quanto": ["replace_with_quanto_layers"],
|
||||
}
|
||||
|
||||
try:
|
||||
if not is_torch_available():
|
||||
raise OptionalDependencyNotAvailable()
|
||||
except OptionalDependencyNotAvailable:
|
||||
pass
|
||||
else:
|
||||
_import_structure["executorch"] = [
|
||||
"TorchExportableModuleWithStaticCache",
|
||||
"convert_and_export_with_cache",
|
||||
]
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .aqlm import replace_with_aqlm_linear
|
||||
from .awq import (
|
||||
@@ -178,6 +189,15 @@ if TYPE_CHECKING:
|
||||
)
|
||||
from .peft import PeftAdapterMixin
|
||||
from .quanto import replace_with_quanto_layers
|
||||
|
||||
try:
|
||||
if not is_torch_available():
|
||||
raise OptionalDependencyNotAvailable()
|
||||
except OptionalDependencyNotAvailable:
|
||||
pass
|
||||
else:
|
||||
from .executorch import TorchExportableModuleWithStaticCache, convert_and_export_with_cache
|
||||
|
||||
else:
|
||||
import sys
|
||||
|
||||
|
||||
159
src/transformers/integrations/executorch.py
Normal file
159
src/transformers/integrations/executorch.py
Normal file
@@ -0,0 +1,159 @@
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
|
||||
# the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
|
||||
# an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations under the License.
|
||||
|
||||
import torch
|
||||
|
||||
from transformers import (
|
||||
PreTrainedModel,
|
||||
StaticCache,
|
||||
)
|
||||
from transformers.pytorch_utils import is_torch_greater_or_equal_than_2_3
|
||||
|
||||
|
||||
class TorchExportableModuleWithStaticCache(torch.nn.Module):
|
||||
"""
|
||||
A wrapper module designed to make a `PreTrainedModel` exportable with `torch.export`,
|
||||
specifically for use with static caching. This module ensures that the exported model
|
||||
is compatible with further lowering and execution in `ExecuTorch`.
|
||||
|
||||
Note:
|
||||
This class is specifically designed to support export process using `torch.export`
|
||||
in a way that ensures the model can be further lowered and run efficiently in `ExecuTorch`.
|
||||
"""
|
||||
|
||||
def __init__(self, model: PreTrainedModel):
|
||||
"""
|
||||
Initializes the wrapper module with the pretrained model.
|
||||
|
||||
Args:
|
||||
model (`PreTrainedModel`): The pretrained model to wrap. The model must have caching
|
||||
enabled and use a 'static' caching implementation.
|
||||
|
||||
Raises:
|
||||
AssertionError: If the pretrained model does not have caching enabled or if it does
|
||||
not use a 'static' caching implementation in `model.generation_config`.
|
||||
"""
|
||||
super().__init__()
|
||||
|
||||
# Sanity checks
|
||||
if model.generation_config is None:
|
||||
raise AssertionError(
|
||||
"The model must have a generation config to be exported with static caching. "
|
||||
"Please set `generation_config`."
|
||||
)
|
||||
|
||||
if not model.generation_config.use_cache:
|
||||
raise AssertionError(
|
||||
"The model must have caching enabled to be exported with static caching. "
|
||||
"Please set `generation_config.use_cache=True`."
|
||||
)
|
||||
|
||||
if model.generation_config.cache_implementation != "static":
|
||||
raise AssertionError(
|
||||
"The model must use a 'static' caching implementation to be exported with static caching. "
|
||||
"Please set `generation_config.cache_implementation='static'`."
|
||||
)
|
||||
|
||||
self.model = model
|
||||
self.static_cache = StaticCache(
|
||||
config=self.model.config,
|
||||
batch_size=self.model.generation_config.cache_config.batch_size,
|
||||
max_cache_len=self.model.generation_config.cache_config.max_cache_len,
|
||||
dtype=self.model.config.torch_dtype,
|
||||
)
|
||||
self.is_causal = any("CausalLM" in arch for arch in self.model.config.architectures)
|
||||
if self.is_causal:
|
||||
causal_mask = torch.tril(
|
||||
torch.ones(
|
||||
self.static_cache.max_cache_len,
|
||||
self.static_cache.max_cache_len,
|
||||
dtype=torch.bool,
|
||||
)
|
||||
)
|
||||
self.register_buffer("mask", causal_mask, persistent=False)
|
||||
|
||||
def forward(self, input_ids: torch.Tensor, cache_position: torch.Tensor):
|
||||
"""
|
||||
Forward pass of the module, which is compatible with the ExecuTorch runtime.
|
||||
|
||||
Args:
|
||||
input_ids (`torch.Tensor`): Tensor representing current input token id to the module.
|
||||
cache_position (`torch.Tensor`): Tensor representing current input position in the cache.
|
||||
|
||||
Returns:
|
||||
torch.Tensor: Logits output from the model.
|
||||
|
||||
This forward adapter serves two primary purposes:
|
||||
|
||||
1. **Making the Model `torch.export`-Compatible**:
|
||||
The adapter hides unsupported objects, such as the `Cache`, from the graph inputs and outputs,
|
||||
enabling the model to be exportable using `torch.export` without encountering issues.
|
||||
|
||||
2. **Ensuring Compatibility with `ExecuTorch` runtime**:
|
||||
The adapter matches the model's forward signature with that in `executorch/extension/llm/runner`,
|
||||
ensuring that the exported model can be executed in `ExecuTorch` out-of-the-box.
|
||||
"""
|
||||
_, seqlen = input_ids.shape
|
||||
attn_mask = self.mask[cache_position, :seqlen] if self.is_causal else None
|
||||
outs = self.model(
|
||||
input_ids=input_ids,
|
||||
attention_mask=attn_mask,
|
||||
position_ids=cache_position.unsqueeze(0),
|
||||
cache_position=cache_position,
|
||||
past_key_values=self.static_cache,
|
||||
use_cache=True,
|
||||
)
|
||||
return outs.logits
|
||||
|
||||
|
||||
def convert_and_export_with_cache(
|
||||
model: PreTrainedModel,
|
||||
example_input_ids: torch.Tensor = None,
|
||||
example_cache_position: torch.Tensor = None,
|
||||
):
|
||||
"""
|
||||
Convert a `PreTrainedModel` into an exportable module and export it using `torch.export`,
|
||||
ensuring the exported model is compatible with `ExecuTorch`.
|
||||
|
||||
Args:
|
||||
model (`PreTrainedModel`): The pretrained model to be exported.
|
||||
example_input_ids (`torch.Tensor`): Example input token id used by `torch.export`.
|
||||
example_cache_position (`torch.Tensor`): Example current cache position used by `torch.export`.
|
||||
|
||||
Returns:
|
||||
Exported program (`torch.export.ExportedProgram`): The exported program generated via `torch.export`.
|
||||
"""
|
||||
|
||||
if not is_torch_greater_or_equal_than_2_3:
|
||||
raise ImportError("torch >= 2.3 is required.")
|
||||
|
||||
import torch.export._trace
|
||||
|
||||
with torch.no_grad():
|
||||
# TODO: The default inputs only work for text models. We need to add support for vision/audio models.
|
||||
example_input_ids = (
|
||||
example_input_ids if example_input_ids is not None else torch.tensor([[1]], dtype=torch.long)
|
||||
)
|
||||
example_cache_position = (
|
||||
example_cache_position if example_cache_position is not None else torch.tensor([0], dtype=torch.long)
|
||||
)
|
||||
|
||||
# Due to issue https://github.com/pytorch/pytorch/issues/128394, we need to switch to use an internal
|
||||
# export API and pre_dispatch=False. Switch to use the public API once the issue is included in 2.5 release.
|
||||
exported_program = torch.export._trace._export(
|
||||
TorchExportableModuleWithStaticCache(model),
|
||||
args=(example_input_ids,),
|
||||
kwargs={"cache_position": example_cache_position},
|
||||
pre_dispatch=False,
|
||||
strict=True,
|
||||
)
|
||||
return exported_program
|
||||
@@ -3223,6 +3223,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
||||
adapter_kwargs = kwargs.pop("adapter_kwargs", {})
|
||||
adapter_name = kwargs.pop("adapter_name", "default")
|
||||
use_flash_attention_2 = kwargs.pop("use_flash_attention_2", False)
|
||||
generation_config = kwargs.pop("generation_config", None)
|
||||
|
||||
gguf_file = kwargs.pop("gguf_file", None)
|
||||
# Cache path to the GGUF file
|
||||
@@ -3998,7 +3999,10 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
||||
model.eval()
|
||||
|
||||
# If it is a model with generation capabilities, attempt to load the generation config
|
||||
if model.can_generate() and pretrained_model_name_or_path is not None:
|
||||
if model.can_generate() and generation_config is not None:
|
||||
logger.info("The user-defined `generation_config` will be used to override the default generation config.")
|
||||
model.generation_config = model.generation_config.from_dict(generation_config.to_dict())
|
||||
elif model.can_generate() and pretrained_model_name_or_path is not None:
|
||||
try:
|
||||
model.generation_config = GenerationConfig.from_pretrained(
|
||||
pretrained_model_name_or_path,
|
||||
|
||||
@@ -513,6 +513,17 @@ class WhisperTimeStampLogitsProcessor(metaclass=DummyObject):
|
||||
requires_backends(self, ["torch"])
|
||||
|
||||
|
||||
class TorchExportableModuleWithStaticCache(metaclass=DummyObject):
|
||||
_backends = ["torch"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch"])
|
||||
|
||||
|
||||
def convert_and_export_with_cache(*args, **kwargs):
|
||||
requires_backends(convert_and_export_with_cache, ["torch"])
|
||||
|
||||
|
||||
ROPE_INIT_FUNCTIONS = None
|
||||
|
||||
|
||||
|
||||
@@ -16,7 +16,6 @@
|
||||
import copy
|
||||
import unittest
|
||||
|
||||
from packaging import version
|
||||
from parameterized import parameterized
|
||||
|
||||
from transformers import set_seed
|
||||
@@ -35,7 +34,6 @@ if is_torch_available():
|
||||
import torch
|
||||
|
||||
from transformers import (
|
||||
AutoConfig,
|
||||
AutoModelForCausalLM,
|
||||
AutoTokenizer,
|
||||
DynamicCache,
|
||||
@@ -44,7 +42,9 @@ if is_torch_available():
|
||||
LlamaConfig,
|
||||
SinkCache,
|
||||
StaticCache,
|
||||
convert_and_export_with_cache,
|
||||
)
|
||||
from transformers.pytorch_utils import is_torch_greater_or_equal_than_2_3
|
||||
|
||||
|
||||
@require_torch
|
||||
@@ -175,61 +175,54 @@ class CacheTest(unittest.TestCase):
|
||||
"""
|
||||
Tests that static cache works with `torch.export()`
|
||||
"""
|
||||
import torch
|
||||
|
||||
if version.parse(torch.__version__) < version.parse("2.3"):
|
||||
if not is_torch_greater_or_equal_than_2_3:
|
||||
self.skipTest(reason="This test requires torch >= 2.3 to run.")
|
||||
|
||||
set_seed(0)
|
||||
device = "cpu"
|
||||
dtype = torch.float32
|
||||
cache_implementation = "static"
|
||||
attn_implementation = "sdpa" # Export and ExecuTorch only works for SdpaAttention
|
||||
batch_size = 1
|
||||
|
||||
config = AutoConfig.from_pretrained(
|
||||
max_cache_len = 1234
|
||||
model = AutoModelForCausalLM.from_pretrained(
|
||||
"google/gemma-2b",
|
||||
device_map=device,
|
||||
torch_dtype=dtype,
|
||||
use_cache=True,
|
||||
attn_implementation=attn_implementation,
|
||||
generation_config=GenerationConfig(
|
||||
use_cache=True,
|
||||
cache_implementation=cache_implementation,
|
||||
max_length=max_cache_len,
|
||||
cache_config={
|
||||
"batch_size": batch_size,
|
||||
"max_cache_len": max_cache_len,
|
||||
},
|
||||
),
|
||||
)
|
||||
m = AutoModelForCausalLM.from_pretrained(
|
||||
"google/gemma-2b",
|
||||
config=config,
|
||||
torch_dtype=dtype,
|
||||
attn_implementation="sdpa", # Export and ExecuTorch only works for SdpaAttention
|
||||
).to(device)
|
||||
tokenizer = AutoTokenizer.from_pretrained("google/gemma-2b")
|
||||
inputs = tokenizer(["The best color is"], return_tensors="pt").to(device)["input_ids"]
|
||||
# Check if cache config is passed through correctly
|
||||
self.assertEqual(model.generation_config.use_cache, True)
|
||||
self.assertEqual(model.generation_config.cache_implementation, cache_implementation)
|
||||
self.assertEqual(model.generation_config.max_length, max_cache_len)
|
||||
self.assertTrue(model.generation_config.cache_config is not None)
|
||||
self.assertEqual(model.generation_config.cache_config.batch_size, batch_size)
|
||||
self.assertEqual(model.generation_config.cache_config.max_cache_len, max_cache_len)
|
||||
|
||||
class ExportatibleModelWithStaticCache(torch.nn.Module):
|
||||
def __init__(self, config, model):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.model = model
|
||||
self.static_cache = StaticCache(
|
||||
config=config, batch_size=batch_size, max_cache_len=config.max_length, device=device
|
||||
)
|
||||
exported_program = convert_and_export_with_cache(model)
|
||||
|
||||
def forward(self, tokens: torch.Tensor, input_pos: torch.Tensor):
|
||||
outs = self.model(
|
||||
input_ids=tokens,
|
||||
attention_mask=None,
|
||||
position_ids=input_pos.unsqueeze(0),
|
||||
cache_position=input_pos,
|
||||
past_key_values=self.static_cache,
|
||||
use_cache=True,
|
||||
)
|
||||
return outs.logits
|
||||
|
||||
set_seed(0)
|
||||
with torch.no_grad():
|
||||
import torch.export._trace
|
||||
from torch.export import ExportedProgram
|
||||
|
||||
model = ExportatibleModelWithStaticCache(config, m)
|
||||
# Due to issue https://github.com/pytorch/pytorch/issues/128394, we need to switch to use an internal
|
||||
# export API and pre_dispatch=False. Switch to use the public API once the issue is included in 2.4.1+ release.
|
||||
exported_program = torch.export._trace._export(
|
||||
model, args=(inputs,), kwargs={"input_pos": torch.arange(1)}, pre_dispatch=False, strict=True
|
||||
)
|
||||
self.assertTrue(isinstance(exported_program, ExportedProgram))
|
||||
# Check if the exported model is configured with the `StaticCache` correctly
|
||||
n_static_key_caches = n_static_value_caches = 0
|
||||
for buffer_name, buffer in exported_program.named_buffers():
|
||||
if buffer_name.startswith("static_cache.key_cache"):
|
||||
self.assertTrue(buffer.shape[0] == batch_size)
|
||||
self.assertTrue(buffer.shape[2] == max_cache_len)
|
||||
n_static_key_caches = n_static_key_caches + 1
|
||||
if buffer_name.startswith("static_cache.value_cache"):
|
||||
self.assertTrue(buffer.shape[0] == batch_size)
|
||||
self.assertTrue(buffer.shape[2] == max_cache_len)
|
||||
n_static_value_caches = n_static_value_caches + 1
|
||||
self.assertEqual(n_static_key_caches, model.config.num_hidden_layers)
|
||||
self.assertEqual(n_static_value_caches, model.config.num_hidden_layers)
|
||||
|
||||
|
||||
@require_torch_gpu
|
||||
|
||||
Reference in New Issue
Block a user