Offloaded cache: fix generate (#34921)

* fix cache impl

* require_torch_gpu

* fix mamba

* fix copies
This commit is contained in:
Raushan Turganbay
2024-11-28 15:05:56 +01:00
committed by GitHub
parent 57ca9e6d2f
commit 5e8c1d713d
6 changed files with 91 additions and 19 deletions

View File

@@ -1140,13 +1140,13 @@ class StaticCache(Cache):
layer_device_map: Optional[Dict[int, Union[str, torch.device, int]]] = None, layer_device_map: Optional[Dict[int, Union[str, torch.device, int]]] = None,
) -> None: ) -> None:
super().__init__() super().__init__()
if max_batch_size is not None: if batch_size is not None:
logger.warning_once( logger.warning_once(
f"The 'max_batch_size' argument of {self.__class__.__name__} is deprecated and will be removed in " f"The 'batch_size' argument of {self.__class__.__name__} is deprecated and will be removed in "
"v4.46. Use the more precisely named 'batch_size' argument instead." "v4.49. Use the more precisely named 'max_batch_size' argument instead."
) )
self.batch_size = batch_size or max_batch_size self.max_batch_size = batch_size or max_batch_size
self.max_cache_len = config.max_position_embeddings if max_cache_len is None else max_cache_len self.max_cache_len = config.max_position_embeddings if max_cache_len is None else max_cache_len
# Some model define a custom `head_dim` != config.hidden_size // config.num_attention_heads # Some model define a custom `head_dim` != config.hidden_size // config.num_attention_heads
@@ -1254,6 +1254,14 @@ class StaticCache(Cache):
self.key_cache[layer_idx].zero_() self.key_cache[layer_idx].zero_()
self.value_cache[layer_idx].zero_() self.value_cache[layer_idx].zero_()
@property
def batch_size(self):
    """Deprecated alias for `max_batch_size`.

    Reports the deprecation through `logger.warning_once` and then returns
    the value stored in `self.max_batch_size`.
    """
    deprecation_msg = (
        f"The 'batch_size' attribute of {self.__class__.__name__} is deprecated and will be removed in "
        "v4.49. Use the more precisely named 'self.max_batch_size' attribute instead."
    )
    logger.warning_once(deprecation_msg)
    return self.max_batch_size
class SlidingWindowCache(StaticCache): class SlidingWindowCache(StaticCache):
""" """
@@ -1626,10 +1634,10 @@ class HybridCache(Cache):
layer_device_map: Optional[Dict[int, Union[str, torch.device, int]]] = None, layer_device_map: Optional[Dict[int, Union[str, torch.device, int]]] = None,
) -> None: ) -> None:
super().__init__() super().__init__()
if max_batch_size is not None: if batch_size is not None:
logger.warning_once( logger.warning_once(
f"The 'max_batch_size' argument of {self.__class__.__name__} is deprecated and will be removed in " f"The 'batch_size' argument of {self.__class__.__name__} is deprecated and will be removed in "
"v4.46. Use the more precisely named 'batch_size' argument instead." "v4.49. Use the more precisely named 'max_batch_size' argument instead."
) )
if not hasattr(config, "sliding_window") or config.sliding_window is None: if not hasattr(config, "sliding_window") or config.sliding_window is None:
raise ValueError( raise ValueError(
@@ -1638,7 +1646,7 @@ class HybridCache(Cache):
"config and it's not set to None." "config and it's not set to None."
) )
self.max_cache_len = max_cache_len self.max_cache_len = max_cache_len
self.batch_size = batch_size or max_batch_size self.max_batch_size = batch_size or max_batch_size
# Some model define a custom `head_dim` != config.hidden_size // config.num_attention_heads # Some model define a custom `head_dim` != config.hidden_size // config.num_attention_heads
self.head_dim = ( self.head_dim = (
config.head_dim if hasattr(config, "head_dim") else config.hidden_size // config.num_attention_heads config.head_dim if hasattr(config, "head_dim") else config.hidden_size // config.num_attention_heads
@@ -1758,6 +1766,14 @@ class HybridCache(Cache):
self.key_cache[layer_idx].zero_() self.key_cache[layer_idx].zero_()
self.value_cache[layer_idx].zero_() self.value_cache[layer_idx].zero_()
@property
def batch_size(self):
    """Deprecated alias for `max_batch_size`.

    Reports the deprecation through `logger.warning_once` and then returns
    the value stored in `self.max_batch_size`.
    """
    deprecation_msg = (
        f"The 'batch_size' attribute of {self.__class__.__name__} is deprecated and will be removed in "
        "v4.49. Use the more precisely named 'self.max_batch_size' attribute instead."
    )
    logger.warning_once(deprecation_msg)
    return self.max_batch_size
class MambaCache: class MambaCache:
""" """
@@ -1815,20 +1831,20 @@ class MambaCache:
device: Optional[Union[torch.device, str]] = None, device: Optional[Union[torch.device, str]] = None,
max_batch_size: Optional[int] = None, max_batch_size: Optional[int] = None,
): ):
if max_batch_size is not None: if batch_size is not None:
logger.warning_once( logger.warning_once(
f"The 'max_batch_size' argument of {self.__class__.__name__} is deprecated and will be removed in " f"The 'batch_size' argument of {self.__class__.__name__} is deprecated and will be removed in "
"v4.46. Use the more precisely named 'batch_size' argument instead." "v4.49. Use the more precisely named 'max_batch_size' argument instead."
) )
self.dtype = dtype self.dtype = dtype
self.batch_size = batch_size or max_batch_size self.max_batch_size = batch_size or max_batch_size
self.intermediate_size = config.intermediate_size self.intermediate_size = config.intermediate_size
self.ssm_state_size = config.state_size self.ssm_state_size = config.state_size
self.conv_kernel_size = config.conv_kernel self.conv_kernel_size = config.conv_kernel
self.conv_states: torch.Tensor = torch.zeros( self.conv_states: torch.Tensor = torch.zeros(
config.num_hidden_layers, config.num_hidden_layers,
self.batch_size, self.max_batch_size,
self.intermediate_size, self.intermediate_size,
self.conv_kernel_size, self.conv_kernel_size,
device=device, device=device,
@@ -1836,7 +1852,7 @@ class MambaCache:
) )
self.ssm_states: torch.Tensor = torch.zeros( self.ssm_states: torch.Tensor = torch.zeros(
config.num_hidden_layers, config.num_hidden_layers,
self.batch_size, self.max_batch_size,
self.intermediate_size, self.intermediate_size,
self.ssm_state_size, self.ssm_state_size,
device=device, device=device,
@@ -1866,6 +1882,14 @@ class MambaCache:
self.conv_states.zero_() self.conv_states.zero_()
self.ssm_states.zero_() self.ssm_states.zero_()
@property
def batch_size(self):
    """Deprecated alias for `max_batch_size`.

    Reports the deprecation through `logger.warning_once` and then returns
    the value stored in `self.max_batch_size`.
    """
    deprecation_msg = (
        f"The 'batch_size' attribute of {self.__class__.__name__} is deprecated and will be removed in "
        "v4.49. Use the more precisely named 'self.max_batch_size' attribute instead."
    )
    logger.warning_once(deprecation_msg)
    return self.max_batch_size
class OffloadedStaticCache(StaticCache): class OffloadedStaticCache(StaticCache):
""" """
@@ -1887,6 +1911,9 @@ class OffloadedStaticCache(StaticCache):
The default `dtype` to use when initializing the cache. The default `dtype` to use when initializing the cache.
offload_device (`Union[str, torch.device]`, *optional*, defaults to `cpu`): offload_device (`Union[str, torch.device]`, *optional*, defaults to `cpu`):
The device to offload to. Defaults to CPU. The device to offload to. Defaults to CPU.
layer_device_map (`Dict[int, Union[str, torch.device, int]]`, *optional*):
Mapping between the layers and their devices. This is required when you are manually initializing the cache and the model is split across different GPUs.
You can find out which layers are mapped to which device by checking the associated device_map: `model.hf_device_map`.
Attributes: Attributes:
key_cache (`List[torch.Tensor]`): key_cache (`List[torch.Tensor]`):
@@ -1933,10 +1960,11 @@ class OffloadedStaticCache(StaticCache):
device: Union[str, torch.device], device: Union[str, torch.device],
dtype: Optional[torch.dtype] = None, dtype: Optional[torch.dtype] = None,
offload_device: Union[str, torch.device] = torch.device("cpu"), offload_device: Union[str, torch.device] = torch.device("cpu"),
layer_device_map: Optional[Dict[int, Union[str, torch.device, int]]] = None,
) -> None: ) -> None:
self.max_batch_size = max_batch_size self.max_batch_size = max_batch_size
self.max_cache_len = config.max_position_embeddings if max_cache_len is None else max_cache_len self.max_cache_len = config.max_position_embeddings if max_cache_len is None else max_cache_len
self.device = torch.device(device) self.device = torch.device(device) if layer_device_map is None else layer_device_map[0]
self.offload_device = torch.device(offload_device) self.offload_device = torch.device(offload_device)
self.dtype = dtype if dtype is not None else torch.float32 self.dtype = dtype if dtype is not None else torch.float32
@@ -1944,7 +1972,9 @@ class OffloadedStaticCache(StaticCache):
head_dim = config.head_dim if hasattr(config, "head_dim") else config.hidden_size // config.num_attention_heads head_dim = config.head_dim if hasattr(config, "head_dim") else config.hidden_size // config.num_attention_heads
num_key_value_heads = ( num_key_value_heads = (
config.num_attention_heads if config.num_key_value_heads is None else config.num_key_value_heads config.num_attention_heads
if getattr(config, "num_key_value_heads", None) is None
else config.num_key_value_heads
) )
cache_shape = (max_batch_size, num_key_value_heads, self.max_cache_len, head_dim) cache_shape = (max_batch_size, num_key_value_heads, self.max_cache_len, head_dim)

View File

@@ -72,7 +72,9 @@ if is_torch_available():
"mamba": MambaCache, "mamba": MambaCache,
} }
QUANT_BACKEND_CLASSES_MAPPING = {"quanto": QuantoQuantizedCache, "HQQ": HQQQuantizedCache} QUANT_BACKEND_CLASSES_MAPPING = {"quanto": QuantoQuantizedCache, "HQQ": HQQQuantizedCache}
ALL_CACHE_IMPLEMENTATIONS = list(NEED_SETUP_CACHE_CLASSES_MAPPING.keys()) + list(NEEDS_CACHE_CONFIG.keys()) ALL_CACHE_IMPLEMENTATIONS = (
list(NEED_SETUP_CACHE_CLASSES_MAPPING.keys()) + list(NEEDS_CACHE_CONFIG.keys()) + ["offloaded"]
)
class GenerationMode(ExplicitEnum): class GenerationMode(ExplicitEnum):

View File

@@ -1610,7 +1610,7 @@ class GenerationMixin:
need_new_cache = ( need_new_cache = (
not hasattr(self, "_cache") not hasattr(self, "_cache")
or (not isinstance(cache_to_check, cache_cls)) or (not isinstance(cache_to_check, cache_cls))
or cache_to_check.batch_size != batch_size or cache_to_check.max_batch_size != batch_size
) )
if cache_implementation != "mamba": if cache_implementation != "mamba":
need_new_cache = need_new_cache or cache_to_check.max_cache_len < max_cache_len need_new_cache = need_new_cache or cache_to_check.max_cache_len < max_cache_len
@@ -1666,7 +1666,7 @@ class GenerationMixin:
cache_kwargs = { cache_kwargs = {
"config": self.config.get_text_config(), "config": self.config.get_text_config(),
"batch_size": batch_size, "max_batch_size": batch_size,
"max_cache_len": max_cache_len, "max_cache_len": max_cache_len,
"device": device, "device": device,
"dtype": cache_dtype, "dtype": cache_dtype,

View File

@@ -1880,6 +1880,32 @@ class GenerationTesterMixin:
) )
) )
@parameterized.expand([("offloaded",)])  # ("offloaded_static",) TODO: @raushan fixme in some models (eg T5)
@require_torch_gpu
@pytest.mark.generate
def test_offloaded_cache_implementation(self, cache_implementation):
    """Checks that `generate` runs with the given `cache_implementation` and is repeatable.

    For each generative model class, `generate` is invoked twice with the same
    keyword arguments (including `cache_implementation`) and the two outputs
    must be equal once converted to lists.
    """
    for model_class in self.all_generative_model_classes:
        if not model_class._supports_cache_class:
            self.skipTest(reason="This model does not support the new cache format")

        config, inputs_dict = self.prepare_config_and_inputs_for_generate()
        model = model_class(config).to(torch_device)
        model = model.eval()

        generation_kwargs = dict(
            max_new_tokens=5,
            use_cache=True,
            cache_implementation=cache_implementation,
        )

        # Both calls receive exactly the same arguments; the assertion below
        # therefore compares two runs of the same cache implementation.
        first_run = model.generate(**generation_kwargs, **inputs_dict)
        second_run = model.generate(**generation_kwargs, **inputs_dict)
        self.assertListEqual(first_run.tolist(), second_run.tolist())
@pytest.mark.generate @pytest.mark.generate
def test_generate_with_static_cache(self): def test_generate_with_static_cache(self):
""" """

View File

@@ -16,7 +16,9 @@
import unittest import unittest
import pytest
import requests import requests
from parameterized import parameterized
from transformers import ( from transformers import (
AutoProcessor, AutoProcessor,
@@ -365,6 +367,12 @@ class MllamaForConditionalGenerationModelTest(ModelTesterMixin, GenerationTester
def test_model_parallelism(self): def test_model_parallelism(self):
pass pass
@parameterized.expand([("offloaded",)])
@pytest.mark.generate
@unittest.skip(reason="Offloaded cache seems to not work with mllama's kv cache type")
def test_offloaded_cache_implementation(self, cache_implementation):
    """Placeholder with no body: the `unittest.skip` decorator marks this test as skipped."""
def test_generate_text_only_with_cache(self): def test_generate_text_only_with_cache(self):
""" """
Tests that our cached generation with text-only inputs works. When mllama was introduced, this feature Tests that our cached generation with text-only inputs works. When mllama was introduced, this feature

View File

@@ -567,6 +567,12 @@ class WhisperModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMi
def test_generate_with_head_masking(self): def test_generate_with_head_masking(self):
pass pass
@parameterized.expand([("offloaded",)])
@pytest.mark.generate
@unittest.skip(reason="Whisper doesnt work with offloaded cache implementation yet")
def test_offloaded_cache_implementation(self, cache_implementation):
    """Placeholder with no body: the `unittest.skip` decorator marks this test as skipped."""
@require_torch_fp16 @require_torch_fp16
def test_generate_fp16(self): def test_generate_fp16(self):
config, input_dict = self.model_tester.prepare_config_and_inputs() config, input_dict = self.model_tester.prepare_config_and_inputs()