Offloaded cache: fix generate (#34921)
* fix cache impl * require_torch_gpu * fix mamba * fix copies
This commit is contained in:
committed by
GitHub
parent
57ca9e6d2f
commit
5e8c1d713d
@@ -1140,13 +1140,13 @@ class StaticCache(Cache):
|
||||
layer_device_map: Optional[Dict[int, Union[str, torch.device, int]]] = None,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
if max_batch_size is not None:
|
||||
if batch_size is not None:
|
||||
logger.warning_once(
|
||||
f"The 'max_batch_size' argument of {self.__class__.__name__} is deprecated and will be removed in "
|
||||
"v4.46. Use the more precisely named 'batch_size' argument instead."
|
||||
f"The 'batch_size' argument of {self.__class__.__name__} is deprecated and will be removed in "
|
||||
"v4.49. Use the more precisely named 'max_batch_size' argument instead."
|
||||
)
|
||||
|
||||
self.batch_size = batch_size or max_batch_size
|
||||
self.max_batch_size = batch_size or max_batch_size
|
||||
self.max_cache_len = config.max_position_embeddings if max_cache_len is None else max_cache_len
|
||||
|
||||
# Some models define a custom `head_dim` != config.hidden_size // config.num_attention_heads
|
||||
@@ -1254,6 +1254,14 @@ class StaticCache(Cache):
|
||||
self.key_cache[layer_idx].zero_()
|
||||
self.value_cache[layer_idx].zero_()
|
||||
|
||||
@property
|
||||
def batch_size(self):
|
||||
logger.warning_once(
|
||||
f"The 'batch_size' attribute of {self.__class__.__name__} is deprecated and will be removed in "
|
||||
"v4.49. Use the more precisely named 'self.max_batch_size' attribute instead."
|
||||
)
|
||||
return self.max_batch_size
|
||||
|
||||
|
||||
class SlidingWindowCache(StaticCache):
|
||||
"""
|
||||
@@ -1626,10 +1634,10 @@ class HybridCache(Cache):
|
||||
layer_device_map: Optional[Dict[int, Union[str, torch.device, int]]] = None,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
if max_batch_size is not None:
|
||||
if batch_size is not None:
|
||||
logger.warning_once(
|
||||
f"The 'max_batch_size' argument of {self.__class__.__name__} is deprecated and will be removed in "
|
||||
"v4.46. Use the more precisely named 'batch_size' argument instead."
|
||||
f"The 'batch_size' argument of {self.__class__.__name__} is deprecated and will be removed in "
|
||||
"v4.49. Use the more precisely named 'max_batch_size' argument instead."
|
||||
)
|
||||
if not hasattr(config, "sliding_window") or config.sliding_window is None:
|
||||
raise ValueError(
|
||||
@@ -1638,7 +1646,7 @@ class HybridCache(Cache):
|
||||
"config and it's not set to None."
|
||||
)
|
||||
self.max_cache_len = max_cache_len
|
||||
self.batch_size = batch_size or max_batch_size
|
||||
self.max_batch_size = batch_size or max_batch_size
|
||||
# Some models define a custom `head_dim` != config.hidden_size // config.num_attention_heads
|
||||
self.head_dim = (
|
||||
config.head_dim if hasattr(config, "head_dim") else config.hidden_size // config.num_attention_heads
|
||||
@@ -1758,6 +1766,14 @@ class HybridCache(Cache):
|
||||
self.key_cache[layer_idx].zero_()
|
||||
self.value_cache[layer_idx].zero_()
|
||||
|
||||
@property
|
||||
def batch_size(self):
|
||||
logger.warning_once(
|
||||
f"The 'batch_size' attribute of {self.__class__.__name__} is deprecated and will be removed in "
|
||||
"v4.49. Use the more precisely named 'self.max_batch_size' attribute instead."
|
||||
)
|
||||
return self.max_batch_size
|
||||
|
||||
|
||||
class MambaCache:
|
||||
"""
|
||||
@@ -1815,20 +1831,20 @@ class MambaCache:
|
||||
device: Optional[Union[torch.device, str]] = None,
|
||||
max_batch_size: Optional[int] = None,
|
||||
):
|
||||
if max_batch_size is not None:
|
||||
if batch_size is not None:
|
||||
logger.warning_once(
|
||||
f"The 'max_batch_size' argument of {self.__class__.__name__} is deprecated and will be removed in "
|
||||
"v4.46. Use the more precisely named 'batch_size' argument instead."
|
||||
f"The 'batch_size' argument of {self.__class__.__name__} is deprecated and will be removed in "
|
||||
"v4.49. Use the more precisely named 'max_batch_size' argument instead."
|
||||
)
|
||||
self.dtype = dtype
|
||||
self.batch_size = batch_size or max_batch_size
|
||||
self.max_batch_size = batch_size or max_batch_size
|
||||
self.intermediate_size = config.intermediate_size
|
||||
self.ssm_state_size = config.state_size
|
||||
self.conv_kernel_size = config.conv_kernel
|
||||
|
||||
self.conv_states: torch.Tensor = torch.zeros(
|
||||
config.num_hidden_layers,
|
||||
self.batch_size,
|
||||
self.max_batch_size,
|
||||
self.intermediate_size,
|
||||
self.conv_kernel_size,
|
||||
device=device,
|
||||
@@ -1836,7 +1852,7 @@ class MambaCache:
|
||||
)
|
||||
self.ssm_states: torch.Tensor = torch.zeros(
|
||||
config.num_hidden_layers,
|
||||
self.batch_size,
|
||||
self.max_batch_size,
|
||||
self.intermediate_size,
|
||||
self.ssm_state_size,
|
||||
device=device,
|
||||
@@ -1866,6 +1882,14 @@ class MambaCache:
|
||||
self.conv_states.zero_()
|
||||
self.ssm_states.zero_()
|
||||
|
||||
@property
|
||||
def batch_size(self):
|
||||
logger.warning_once(
|
||||
f"The 'batch_size' attribute of {self.__class__.__name__} is deprecated and will be removed in "
|
||||
"v4.49. Use the more precisely named 'self.max_batch_size' attribute instead."
|
||||
)
|
||||
return self.max_batch_size
|
||||
|
||||
|
||||
class OffloadedStaticCache(StaticCache):
|
||||
"""
|
||||
@@ -1887,6 +1911,9 @@ class OffloadedStaticCache(StaticCache):
|
||||
The default `dtype` to use when initializing the cache.
|
||||
offload_device (`Union[str, torch.device]`, *optional*, defaults to `cpu`):
|
||||
The device to offload to. Defaults to CPU.
|
||||
layer_device_map (`Dict[int, Union[str, torch.device, int]]`, *optional*):
|
||||
Mapping between the layers and their devices. This is required when you are manually initializing the cache and the model is split between different GPUs.
|
||||
You can find out which layers are mapped to which device by checking the associated device_map: `model.hf_device_map`.
|
||||
|
||||
Attributes:
|
||||
key_cache (`List[torch.Tensor]`):
|
||||
@@ -1933,10 +1960,11 @@ class OffloadedStaticCache(StaticCache):
|
||||
device: Union[str, torch.device],
|
||||
dtype: Optional[torch.dtype] = None,
|
||||
offload_device: Union[str, torch.device] = torch.device("cpu"),
|
||||
layer_device_map: Optional[Dict[int, Union[str, torch.device, int]]] = None,
|
||||
) -> None:
|
||||
self.max_batch_size = max_batch_size
|
||||
self.max_cache_len = config.max_position_embeddings if max_cache_len is None else max_cache_len
|
||||
self.device = torch.device(device)
|
||||
self.device = torch.device(device) if layer_device_map is None else layer_device_map[0]
|
||||
self.offload_device = torch.device(offload_device)
|
||||
self.dtype = dtype if dtype is not None else torch.float32
|
||||
|
||||
@@ -1944,7 +1972,9 @@ class OffloadedStaticCache(StaticCache):
|
||||
head_dim = config.head_dim if hasattr(config, "head_dim") else config.hidden_size // config.num_attention_heads
|
||||
|
||||
num_key_value_heads = (
|
||||
config.num_attention_heads if config.num_key_value_heads is None else config.num_key_value_heads
|
||||
config.num_attention_heads
|
||||
if getattr(config, "num_key_value_heads", None) is None
|
||||
else config.num_key_value_heads
|
||||
)
|
||||
|
||||
cache_shape = (max_batch_size, num_key_value_heads, self.max_cache_len, head_dim)
|
||||
|
||||
@@ -72,7 +72,9 @@ if is_torch_available():
|
||||
"mamba": MambaCache,
|
||||
}
|
||||
QUANT_BACKEND_CLASSES_MAPPING = {"quanto": QuantoQuantizedCache, "HQQ": HQQQuantizedCache}
|
||||
ALL_CACHE_IMPLEMENTATIONS = list(NEED_SETUP_CACHE_CLASSES_MAPPING.keys()) + list(NEEDS_CACHE_CONFIG.keys())
|
||||
ALL_CACHE_IMPLEMENTATIONS = (
|
||||
list(NEED_SETUP_CACHE_CLASSES_MAPPING.keys()) + list(NEEDS_CACHE_CONFIG.keys()) + ["offloaded"]
|
||||
)
|
||||
|
||||
|
||||
class GenerationMode(ExplicitEnum):
|
||||
|
||||
@@ -1610,7 +1610,7 @@ class GenerationMixin:
|
||||
need_new_cache = (
|
||||
not hasattr(self, "_cache")
|
||||
or (not isinstance(cache_to_check, cache_cls))
|
||||
or cache_to_check.batch_size != batch_size
|
||||
or cache_to_check.max_batch_size != batch_size
|
||||
)
|
||||
if cache_implementation != "mamba":
|
||||
need_new_cache = need_new_cache or cache_to_check.max_cache_len < max_cache_len
|
||||
@@ -1666,7 +1666,7 @@ class GenerationMixin:
|
||||
|
||||
cache_kwargs = {
|
||||
"config": self.config.get_text_config(),
|
||||
"batch_size": batch_size,
|
||||
"max_batch_size": batch_size,
|
||||
"max_cache_len": max_cache_len,
|
||||
"device": device,
|
||||
"dtype": cache_dtype,
|
||||
|
||||
@@ -1880,6 +1880,32 @@ class GenerationTesterMixin:
|
||||
)
|
||||
)
|
||||
|
||||
@parameterized.expand([("offloaded",)]) # ("offloaded_static",) TODO: @raushan fixme in some models (eg T5)
|
||||
@require_torch_gpu
|
||||
@pytest.mark.generate
|
||||
def test_offloaded_cache_implementation(self, cache_implementation):
|
||||
"""Tests we can generate by indicating `cache_implementation` for each possible cache class"""
|
||||
for model_class in self.all_generative_model_classes:
|
||||
if not model_class._supports_cache_class:
|
||||
self.skipTest(reason="This model does not support the new cache format")
|
||||
|
||||
config, inputs_dict = self.prepare_config_and_inputs_for_generate()
|
||||
|
||||
model = model_class(config).to(torch_device).eval()
|
||||
generation_kwargs = {
|
||||
"max_new_tokens": 5,
|
||||
"use_cache": True,
|
||||
"cache_implementation": cache_implementation,
|
||||
}
|
||||
|
||||
legacy_results = model.generate(**generation_kwargs, **inputs_dict)
|
||||
|
||||
# Most cache classes have their own tests except for some that are tested here
|
||||
# The ones here do not need special treatment when passing `cache_implementation`
|
||||
# and are not bound to specific models only
|
||||
new_results = model.generate(**generation_kwargs, **inputs_dict)
|
||||
self.assertListEqual(legacy_results.tolist(), new_results.tolist())
|
||||
|
||||
@pytest.mark.generate
|
||||
def test_generate_with_static_cache(self):
|
||||
"""
|
||||
|
||||
@@ -16,7 +16,9 @@
|
||||
|
||||
import unittest
|
||||
|
||||
import pytest
|
||||
import requests
|
||||
from parameterized import parameterized
|
||||
|
||||
from transformers import (
|
||||
AutoProcessor,
|
||||
@@ -365,6 +367,12 @@ class MllamaForConditionalGenerationModelTest(ModelTesterMixin, GenerationTester
|
||||
def test_model_parallelism(self):
|
||||
pass
|
||||
|
||||
@parameterized.expand([("offloaded",)])
|
||||
@pytest.mark.generate
|
||||
@unittest.skip(reason="Offloaded cache seems to not work with mllama's kv cache type")
|
||||
def test_offloaded_cache_implementation(self, cache_implementation):
|
||||
pass
|
||||
|
||||
def test_generate_text_only_with_cache(self):
|
||||
"""
|
||||
Tests that our cached generation with text-only inputs works. When mllama was introduced, this feature
|
||||
|
||||
@@ -567,6 +567,12 @@ class WhisperModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMi
|
||||
def test_generate_with_head_masking(self):
|
||||
pass
|
||||
|
||||
@parameterized.expand([("offloaded",)])
|
||||
@pytest.mark.generate
|
||||
@unittest.skip(reason="Whisper doesnt work with offloaded cache implementation yet")
|
||||
def test_offloaded_cache_implementation(self, cache_implementation):
|
||||
pass
|
||||
|
||||
@require_torch_fp16
|
||||
def test_generate_fp16(self):
|
||||
config, input_dict = self.model_tester.prepare_config_and_inputs()
|
||||
|
||||
Reference in New Issue
Block a user