Offloaded cache: fix generate (#34921)

* fix cache impl

* require_torch_gpu

* fix mamba

* fix copies
This commit is contained in:
Raushan Turganbay
2024-11-28 15:05:56 +01:00
committed by GitHub
parent 57ca9e6d2f
commit 5e8c1d713d
6 changed files with 91 additions and 19 deletions

View File

@@ -1140,13 +1140,13 @@ class StaticCache(Cache):
layer_device_map: Optional[Dict[int, Union[str, torch.device, int]]] = None,
) -> None:
super().__init__()
if max_batch_size is not None:
if batch_size is not None:
logger.warning_once(
f"The 'max_batch_size' argument of {self.__class__.__name__} is deprecated and will be removed in "
"v4.46. Use the more precisely named 'batch_size' argument instead."
f"The 'batch_size' argument of {self.__class__.__name__} is deprecated and will be removed in "
"v4.49. Use the more precisely named 'max_batch_size' argument instead."
)
self.batch_size = batch_size or max_batch_size
self.max_batch_size = batch_size or max_batch_size
self.max_cache_len = config.max_position_embeddings if max_cache_len is None else max_cache_len
# Some model define a custom `head_dim` != config.hidden_size // config.num_attention_heads
@@ -1254,6 +1254,14 @@ class StaticCache(Cache):
self.key_cache[layer_idx].zero_()
self.value_cache[layer_idx].zero_()
@property
def batch_size(self):
logger.warning_once(
f"The 'batch_size' attribute of {self.__class__.__name__} is deprecated and will be removed in "
"v4.49. Use the more precisely named 'self.max_batch_size' attribute instead."
)
return self.max_batch_size
class SlidingWindowCache(StaticCache):
"""
@@ -1626,10 +1634,10 @@ class HybridCache(Cache):
layer_device_map: Optional[Dict[int, Union[str, torch.device, int]]] = None,
) -> None:
super().__init__()
if max_batch_size is not None:
if batch_size is not None:
logger.warning_once(
f"The 'max_batch_size' argument of {self.__class__.__name__} is deprecated and will be removed in "
"v4.46. Use the more precisely named 'batch_size' argument instead."
f"The 'batch_size' argument of {self.__class__.__name__} is deprecated and will be removed in "
"v4.49. Use the more precisely named 'max_batch_size' argument instead."
)
if not hasattr(config, "sliding_window") or config.sliding_window is None:
raise ValueError(
@@ -1638,7 +1646,7 @@ class HybridCache(Cache):
"config and it's not set to None."
)
self.max_cache_len = max_cache_len
self.batch_size = batch_size or max_batch_size
self.max_batch_size = batch_size or max_batch_size
# Some model define a custom `head_dim` != config.hidden_size // config.num_attention_heads
self.head_dim = (
config.head_dim if hasattr(config, "head_dim") else config.hidden_size // config.num_attention_heads
@@ -1758,6 +1766,14 @@ class HybridCache(Cache):
self.key_cache[layer_idx].zero_()
self.value_cache[layer_idx].zero_()
@property
def batch_size(self):
logger.warning_once(
f"The 'batch_size' attribute of {self.__class__.__name__} is deprecated and will be removed in "
"v4.49. Use the more precisely named 'self.max_batch_size' attribute instead."
)
return self.max_batch_size
class MambaCache:
"""
@@ -1815,20 +1831,20 @@ class MambaCache:
device: Optional[Union[torch.device, str]] = None,
max_batch_size: Optional[int] = None,
):
if max_batch_size is not None:
if batch_size is not None:
logger.warning_once(
f"The 'max_batch_size' argument of {self.__class__.__name__} is deprecated and will be removed in "
"v4.46. Use the more precisely named 'batch_size' argument instead."
f"The 'batch_size' argument of {self.__class__.__name__} is deprecated and will be removed in "
"v4.49. Use the more precisely named 'max_batch_size' argument instead."
)
self.dtype = dtype
self.batch_size = batch_size or max_batch_size
self.max_batch_size = batch_size or max_batch_size
self.intermediate_size = config.intermediate_size
self.ssm_state_size = config.state_size
self.conv_kernel_size = config.conv_kernel
self.conv_states: torch.Tensor = torch.zeros(
config.num_hidden_layers,
self.batch_size,
self.max_batch_size,
self.intermediate_size,
self.conv_kernel_size,
device=device,
@@ -1836,7 +1852,7 @@ class MambaCache:
)
self.ssm_states: torch.Tensor = torch.zeros(
config.num_hidden_layers,
self.batch_size,
self.max_batch_size,
self.intermediate_size,
self.ssm_state_size,
device=device,
@@ -1866,6 +1882,14 @@ class MambaCache:
self.conv_states.zero_()
self.ssm_states.zero_()
@property
def batch_size(self):
logger.warning_once(
f"The 'batch_size' attribute of {self.__class__.__name__} is deprecated and will be removed in "
"v4.49. Use the more precisely named 'self.max_batch_size' attribute instead."
)
return self.max_batch_size
class OffloadedStaticCache(StaticCache):
"""
@@ -1887,6 +1911,9 @@ class OffloadedStaticCache(StaticCache):
The default `dtype` to use when initializing the cache.
offload_device (`Union[str, torch.device]`, *optional*, defaults to `cpu`):
The device to offload to. Defaults to CPU.
layer_device_map (`Dict[int, Union[str, torch.device, int]]`, *optional*):
Mapping between the layers and its device. This is required when you are manually initializing the cache and the model is splitted between differents gpus.
You can know which layers mapped to which device by checking the associated device_map: `model.hf_device_map`.
Attributes:
key_cache (`List[torch.Tensor]`):
@@ -1933,10 +1960,11 @@ class OffloadedStaticCache(StaticCache):
device: Union[str, torch.device],
dtype: Optional[torch.dtype] = None,
offload_device: Union[str, torch.device] = torch.device("cpu"),
layer_device_map: Optional[Dict[int, Union[str, torch.device, int]]] = None,
) -> None:
self.max_batch_size = max_batch_size
self.max_cache_len = config.max_position_embeddings if max_cache_len is None else max_cache_len
self.device = torch.device(device)
self.device = torch.device(device) if layer_device_map is None else layer_device_map[0]
self.offload_device = torch.device(offload_device)
self.dtype = dtype if dtype is not None else torch.float32
@@ -1944,7 +1972,9 @@ class OffloadedStaticCache(StaticCache):
head_dim = config.head_dim if hasattr(config, "head_dim") else config.hidden_size // config.num_attention_heads
num_key_value_heads = (
config.num_attention_heads if config.num_key_value_heads is None else config.num_key_value_heads
config.num_attention_heads
if getattr(config, "num_key_value_heads", None) is None
else config.num_key_value_heads
)
cache_shape = (max_batch_size, num_key_value_heads, self.max_cache_len, head_dim)

View File

@@ -72,7 +72,9 @@ if is_torch_available():
"mamba": MambaCache,
}
QUANT_BACKEND_CLASSES_MAPPING = {"quanto": QuantoQuantizedCache, "HQQ": HQQQuantizedCache}
ALL_CACHE_IMPLEMENTATIONS = list(NEED_SETUP_CACHE_CLASSES_MAPPING.keys()) + list(NEEDS_CACHE_CONFIG.keys())
ALL_CACHE_IMPLEMENTATIONS = (
list(NEED_SETUP_CACHE_CLASSES_MAPPING.keys()) + list(NEEDS_CACHE_CONFIG.keys()) + ["offloaded"]
)
class GenerationMode(ExplicitEnum):

View File

@@ -1610,7 +1610,7 @@ class GenerationMixin:
need_new_cache = (
not hasattr(self, "_cache")
or (not isinstance(cache_to_check, cache_cls))
or cache_to_check.batch_size != batch_size
or cache_to_check.max_batch_size != batch_size
)
if cache_implementation != "mamba":
need_new_cache = need_new_cache or cache_to_check.max_cache_len < max_cache_len
@@ -1666,7 +1666,7 @@ class GenerationMixin:
cache_kwargs = {
"config": self.config.get_text_config(),
"batch_size": batch_size,
"max_batch_size": batch_size,
"max_cache_len": max_cache_len,
"device": device,
"dtype": cache_dtype,

View File

@@ -1880,6 +1880,32 @@ class GenerationTesterMixin:
)
)
@parameterized.expand([("offloaded",)]) # ("offloaded_static",) TODO: @raushan fixme in some models (eg T5)
@require_torch_gpu
@pytest.mark.generate
def test_offloaded_cache_implementation(self, cache_implementation):
"""Tests we can generate by indicating `cache_implementation` for each possible cache class"""
for model_class in self.all_generative_model_classes:
if not model_class._supports_cache_class:
self.skipTest(reason="This model does not support the new cache format")
config, inputs_dict = self.prepare_config_and_inputs_for_generate()
model = model_class(config).to(torch_device).eval()
generation_kwargs = {
"max_new_tokens": 5,
"use_cache": True,
"cache_implementation": cache_implementation,
}
legacy_results = model.generate(**generation_kwargs, **inputs_dict)
# Most cache classes have their own tests except for some that are tested here
# The ones here do not need special treatment when passing `cache_implementation`
# and are not bound to specific models only
new_results = model.generate(**generation_kwargs, **inputs_dict)
self.assertListEqual(legacy_results.tolist(), new_results.tolist())
@pytest.mark.generate
def test_generate_with_static_cache(self):
"""

View File

@@ -16,7 +16,9 @@
import unittest
import pytest
import requests
from parameterized import parameterized
from transformers import (
AutoProcessor,
@@ -365,6 +367,12 @@ class MllamaForConditionalGenerationModelTest(ModelTesterMixin, GenerationTester
def test_model_parallelism(self):
pass
@parameterized.expand([("offloaded",)])
@pytest.mark.generate
@unittest.skip(reason="Offloaded cache seems to not work with mllama's kv cache type")
def test_offloaded_cache_implementation(self, cache_implementation):
pass
def test_generate_text_only_with_cache(self):
"""
Tests that our cached generation with text-only inputs works. When mllama was introduced, this feature

View File

@@ -567,6 +567,12 @@ class WhisperModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMi
def test_generate_with_head_masking(self):
pass
@parameterized.expand([("offloaded",)])
@pytest.mark.generate
@unittest.skip(reason="Whisper doesnt work with offloaded cache implementation yet")
def test_offloaded_cache_implementation(self, cache_implementation):
pass
@require_torch_fp16
def test_generate_fp16(self):
config, input_dict = self.model_tester.prepare_config_and_inputs()