Offloaded cache: fix generate (#34921)
* fix cache impl * require_torch_gpu * fix mamba * fix copies
This commit is contained in:
committed by
GitHub
parent
57ca9e6d2f
commit
5e8c1d713d
@@ -1140,13 +1140,13 @@ class StaticCache(Cache):
|
|||||||
layer_device_map: Optional[Dict[int, Union[str, torch.device, int]]] = None,
|
layer_device_map: Optional[Dict[int, Union[str, torch.device, int]]] = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
super().__init__()
|
super().__init__()
|
||||||
if max_batch_size is not None:
|
if batch_size is not None:
|
||||||
logger.warning_once(
|
logger.warning_once(
|
||||||
f"The 'max_batch_size' argument of {self.__class__.__name__} is deprecated and will be removed in "
|
f"The 'batch_size' argument of {self.__class__.__name__} is deprecated and will be removed in "
|
||||||
"v4.46. Use the more precisely named 'batch_size' argument instead."
|
"v4.49. Use the more precisely named 'max_batch_size' argument instead."
|
||||||
)
|
)
|
||||||
|
|
||||||
self.batch_size = batch_size or max_batch_size
|
self.max_batch_size = batch_size or max_batch_size
|
||||||
self.max_cache_len = config.max_position_embeddings if max_cache_len is None else max_cache_len
|
self.max_cache_len = config.max_position_embeddings if max_cache_len is None else max_cache_len
|
||||||
|
|
||||||
# Some model define a custom `head_dim` != config.hidden_size // config.num_attention_heads
|
# Some model define a custom `head_dim` != config.hidden_size // config.num_attention_heads
|
||||||
@@ -1254,6 +1254,14 @@ class StaticCache(Cache):
|
|||||||
self.key_cache[layer_idx].zero_()
|
self.key_cache[layer_idx].zero_()
|
||||||
self.value_cache[layer_idx].zero_()
|
self.value_cache[layer_idx].zero_()
|
||||||
|
|
||||||
|
@property
|
||||||
|
def batch_size(self):
|
||||||
|
logger.warning_once(
|
||||||
|
f"The 'batch_size' attribute of {self.__class__.__name__} is deprecated and will be removed in "
|
||||||
|
"v4.49. Use the more precisely named 'self.max_batch_size' attribute instead."
|
||||||
|
)
|
||||||
|
return self.max_batch_size
|
||||||
|
|
||||||
|
|
||||||
class SlidingWindowCache(StaticCache):
|
class SlidingWindowCache(StaticCache):
|
||||||
"""
|
"""
|
||||||
@@ -1626,10 +1634,10 @@ class HybridCache(Cache):
|
|||||||
layer_device_map: Optional[Dict[int, Union[str, torch.device, int]]] = None,
|
layer_device_map: Optional[Dict[int, Union[str, torch.device, int]]] = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
super().__init__()
|
super().__init__()
|
||||||
if max_batch_size is not None:
|
if batch_size is not None:
|
||||||
logger.warning_once(
|
logger.warning_once(
|
||||||
f"The 'max_batch_size' argument of {self.__class__.__name__} is deprecated and will be removed in "
|
f"The 'batch_size' argument of {self.__class__.__name__} is deprecated and will be removed in "
|
||||||
"v4.46. Use the more precisely named 'batch_size' argument instead."
|
"v4.49. Use the more precisely named 'max_batch_size' argument instead."
|
||||||
)
|
)
|
||||||
if not hasattr(config, "sliding_window") or config.sliding_window is None:
|
if not hasattr(config, "sliding_window") or config.sliding_window is None:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
@@ -1638,7 +1646,7 @@ class HybridCache(Cache):
|
|||||||
"config and it's not set to None."
|
"config and it's not set to None."
|
||||||
)
|
)
|
||||||
self.max_cache_len = max_cache_len
|
self.max_cache_len = max_cache_len
|
||||||
self.batch_size = batch_size or max_batch_size
|
self.max_batch_size = batch_size or max_batch_size
|
||||||
# Some model define a custom `head_dim` != config.hidden_size // config.num_attention_heads
|
# Some model define a custom `head_dim` != config.hidden_size // config.num_attention_heads
|
||||||
self.head_dim = (
|
self.head_dim = (
|
||||||
config.head_dim if hasattr(config, "head_dim") else config.hidden_size // config.num_attention_heads
|
config.head_dim if hasattr(config, "head_dim") else config.hidden_size // config.num_attention_heads
|
||||||
@@ -1758,6 +1766,14 @@ class HybridCache(Cache):
|
|||||||
self.key_cache[layer_idx].zero_()
|
self.key_cache[layer_idx].zero_()
|
||||||
self.value_cache[layer_idx].zero_()
|
self.value_cache[layer_idx].zero_()
|
||||||
|
|
||||||
|
@property
|
||||||
|
def batch_size(self):
|
||||||
|
logger.warning_once(
|
||||||
|
f"The 'batch_size' attribute of {self.__class__.__name__} is deprecated and will be removed in "
|
||||||
|
"v4.49. Use the more precisely named 'self.max_batch_size' attribute instead."
|
||||||
|
)
|
||||||
|
return self.max_batch_size
|
||||||
|
|
||||||
|
|
||||||
class MambaCache:
|
class MambaCache:
|
||||||
"""
|
"""
|
||||||
@@ -1815,20 +1831,20 @@ class MambaCache:
|
|||||||
device: Optional[Union[torch.device, str]] = None,
|
device: Optional[Union[torch.device, str]] = None,
|
||||||
max_batch_size: Optional[int] = None,
|
max_batch_size: Optional[int] = None,
|
||||||
):
|
):
|
||||||
if max_batch_size is not None:
|
if batch_size is not None:
|
||||||
logger.warning_once(
|
logger.warning_once(
|
||||||
f"The 'max_batch_size' argument of {self.__class__.__name__} is deprecated and will be removed in "
|
f"The 'batch_size' argument of {self.__class__.__name__} is deprecated and will be removed in "
|
||||||
"v4.46. Use the more precisely named 'batch_size' argument instead."
|
"v4.49. Use the more precisely named 'max_batch_size' argument instead."
|
||||||
)
|
)
|
||||||
self.dtype = dtype
|
self.dtype = dtype
|
||||||
self.batch_size = batch_size or max_batch_size
|
self.max_batch_size = batch_size or max_batch_size
|
||||||
self.intermediate_size = config.intermediate_size
|
self.intermediate_size = config.intermediate_size
|
||||||
self.ssm_state_size = config.state_size
|
self.ssm_state_size = config.state_size
|
||||||
self.conv_kernel_size = config.conv_kernel
|
self.conv_kernel_size = config.conv_kernel
|
||||||
|
|
||||||
self.conv_states: torch.Tensor = torch.zeros(
|
self.conv_states: torch.Tensor = torch.zeros(
|
||||||
config.num_hidden_layers,
|
config.num_hidden_layers,
|
||||||
self.batch_size,
|
self.max_batch_size,
|
||||||
self.intermediate_size,
|
self.intermediate_size,
|
||||||
self.conv_kernel_size,
|
self.conv_kernel_size,
|
||||||
device=device,
|
device=device,
|
||||||
@@ -1836,7 +1852,7 @@ class MambaCache:
|
|||||||
)
|
)
|
||||||
self.ssm_states: torch.Tensor = torch.zeros(
|
self.ssm_states: torch.Tensor = torch.zeros(
|
||||||
config.num_hidden_layers,
|
config.num_hidden_layers,
|
||||||
self.batch_size,
|
self.max_batch_size,
|
||||||
self.intermediate_size,
|
self.intermediate_size,
|
||||||
self.ssm_state_size,
|
self.ssm_state_size,
|
||||||
device=device,
|
device=device,
|
||||||
@@ -1866,6 +1882,14 @@ class MambaCache:
|
|||||||
self.conv_states.zero_()
|
self.conv_states.zero_()
|
||||||
self.ssm_states.zero_()
|
self.ssm_states.zero_()
|
||||||
|
|
||||||
|
@property
|
||||||
|
def batch_size(self):
|
||||||
|
logger.warning_once(
|
||||||
|
f"The 'batch_size' attribute of {self.__class__.__name__} is deprecated and will be removed in "
|
||||||
|
"v4.49. Use the more precisely named 'self.max_batch_size' attribute instead."
|
||||||
|
)
|
||||||
|
return self.max_batch_size
|
||||||
|
|
||||||
|
|
||||||
class OffloadedStaticCache(StaticCache):
|
class OffloadedStaticCache(StaticCache):
|
||||||
"""
|
"""
|
||||||
@@ -1887,6 +1911,9 @@ class OffloadedStaticCache(StaticCache):
|
|||||||
The default `dtype` to use when initializing the cache.
|
The default `dtype` to use when initializing the cache.
|
||||||
offload_device (`Union[str, torch.device]`, *optional*, defaults to `cpu`):
|
offload_device (`Union[str, torch.device]`, *optional*, defaults to `cpu`):
|
||||||
The device to offload to. Defaults to CPU.
|
The device to offload to. Defaults to CPU.
|
||||||
|
layer_device_map (`Dict[int, Union[str, torch.device, int]]`, *optional*):
|
||||||
|
Mapping between the layers and their devices. This is required when you are manually initializing the cache and the model is split between different GPUs.
|
||||||
|
You can find out which layers are mapped to which device by checking the associated device_map: `model.hf_device_map`.
|
||||||
|
|
||||||
Attributes:
|
Attributes:
|
||||||
key_cache (`List[torch.Tensor]`):
|
key_cache (`List[torch.Tensor]`):
|
||||||
@@ -1933,10 +1960,11 @@ class OffloadedStaticCache(StaticCache):
|
|||||||
device: Union[str, torch.device],
|
device: Union[str, torch.device],
|
||||||
dtype: Optional[torch.dtype] = None,
|
dtype: Optional[torch.dtype] = None,
|
||||||
offload_device: Union[str, torch.device] = torch.device("cpu"),
|
offload_device: Union[str, torch.device] = torch.device("cpu"),
|
||||||
|
layer_device_map: Optional[Dict[int, Union[str, torch.device, int]]] = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
self.max_batch_size = max_batch_size
|
self.max_batch_size = max_batch_size
|
||||||
self.max_cache_len = config.max_position_embeddings if max_cache_len is None else max_cache_len
|
self.max_cache_len = config.max_position_embeddings if max_cache_len is None else max_cache_len
|
||||||
self.device = torch.device(device)
|
self.device = torch.device(device) if layer_device_map is None else layer_device_map[0]
|
||||||
self.offload_device = torch.device(offload_device)
|
self.offload_device = torch.device(offload_device)
|
||||||
self.dtype = dtype if dtype is not None else torch.float32
|
self.dtype = dtype if dtype is not None else torch.float32
|
||||||
|
|
||||||
@@ -1944,7 +1972,9 @@ class OffloadedStaticCache(StaticCache):
|
|||||||
head_dim = config.head_dim if hasattr(config, "head_dim") else config.hidden_size // config.num_attention_heads
|
head_dim = config.head_dim if hasattr(config, "head_dim") else config.hidden_size // config.num_attention_heads
|
||||||
|
|
||||||
num_key_value_heads = (
|
num_key_value_heads = (
|
||||||
config.num_attention_heads if config.num_key_value_heads is None else config.num_key_value_heads
|
config.num_attention_heads
|
||||||
|
if getattr(config, "num_key_value_heads", None) is None
|
||||||
|
else config.num_key_value_heads
|
||||||
)
|
)
|
||||||
|
|
||||||
cache_shape = (max_batch_size, num_key_value_heads, self.max_cache_len, head_dim)
|
cache_shape = (max_batch_size, num_key_value_heads, self.max_cache_len, head_dim)
|
||||||
|
|||||||
@@ -72,7 +72,9 @@ if is_torch_available():
|
|||||||
"mamba": MambaCache,
|
"mamba": MambaCache,
|
||||||
}
|
}
|
||||||
QUANT_BACKEND_CLASSES_MAPPING = {"quanto": QuantoQuantizedCache, "HQQ": HQQQuantizedCache}
|
QUANT_BACKEND_CLASSES_MAPPING = {"quanto": QuantoQuantizedCache, "HQQ": HQQQuantizedCache}
|
||||||
ALL_CACHE_IMPLEMENTATIONS = list(NEED_SETUP_CACHE_CLASSES_MAPPING.keys()) + list(NEEDS_CACHE_CONFIG.keys())
|
ALL_CACHE_IMPLEMENTATIONS = (
|
||||||
|
list(NEED_SETUP_CACHE_CLASSES_MAPPING.keys()) + list(NEEDS_CACHE_CONFIG.keys()) + ["offloaded"]
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class GenerationMode(ExplicitEnum):
|
class GenerationMode(ExplicitEnum):
|
||||||
|
|||||||
@@ -1610,7 +1610,7 @@ class GenerationMixin:
|
|||||||
need_new_cache = (
|
need_new_cache = (
|
||||||
not hasattr(self, "_cache")
|
not hasattr(self, "_cache")
|
||||||
or (not isinstance(cache_to_check, cache_cls))
|
or (not isinstance(cache_to_check, cache_cls))
|
||||||
or cache_to_check.batch_size != batch_size
|
or cache_to_check.max_batch_size != batch_size
|
||||||
)
|
)
|
||||||
if cache_implementation != "mamba":
|
if cache_implementation != "mamba":
|
||||||
need_new_cache = need_new_cache or cache_to_check.max_cache_len < max_cache_len
|
need_new_cache = need_new_cache or cache_to_check.max_cache_len < max_cache_len
|
||||||
@@ -1666,7 +1666,7 @@ class GenerationMixin:
|
|||||||
|
|
||||||
cache_kwargs = {
|
cache_kwargs = {
|
||||||
"config": self.config.get_text_config(),
|
"config": self.config.get_text_config(),
|
||||||
"batch_size": batch_size,
|
"max_batch_size": batch_size,
|
||||||
"max_cache_len": max_cache_len,
|
"max_cache_len": max_cache_len,
|
||||||
"device": device,
|
"device": device,
|
||||||
"dtype": cache_dtype,
|
"dtype": cache_dtype,
|
||||||
|
|||||||
@@ -1880,6 +1880,32 @@ class GenerationTesterMixin:
|
|||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@parameterized.expand([("offloaded",)]) # ("offloaded_static",) TODO: @raushan fixme in some models (eg T5)
|
||||||
|
@require_torch_gpu
|
||||||
|
@pytest.mark.generate
|
||||||
|
def test_offloaded_cache_implementation(self, cache_implementation):
|
||||||
|
"""Tests we can generate by indicating `cache_implementation` for each possible cache class"""
|
||||||
|
for model_class in self.all_generative_model_classes:
|
||||||
|
if not model_class._supports_cache_class:
|
||||||
|
self.skipTest(reason="This model does not support the new cache format")
|
||||||
|
|
||||||
|
config, inputs_dict = self.prepare_config_and_inputs_for_generate()
|
||||||
|
|
||||||
|
model = model_class(config).to(torch_device).eval()
|
||||||
|
generation_kwargs = {
|
||||||
|
"max_new_tokens": 5,
|
||||||
|
"use_cache": True,
|
||||||
|
"cache_implementation": cache_implementation,
|
||||||
|
}
|
||||||
|
|
||||||
|
legacy_results = model.generate(**generation_kwargs, **inputs_dict)
|
||||||
|
|
||||||
|
# Most cache classes have their own tests except for some that are tested here
|
||||||
|
# The ones here do not need special treatment when passing `cache_implementation`
|
||||||
|
# and are not bound to specific models only
|
||||||
|
new_results = model.generate(**generation_kwargs, **inputs_dict)
|
||||||
|
self.assertListEqual(legacy_results.tolist(), new_results.tolist())
|
||||||
|
|
||||||
@pytest.mark.generate
|
@pytest.mark.generate
|
||||||
def test_generate_with_static_cache(self):
|
def test_generate_with_static_cache(self):
|
||||||
"""
|
"""
|
||||||
|
|||||||
@@ -16,7 +16,9 @@
|
|||||||
|
|
||||||
import unittest
|
import unittest
|
||||||
|
|
||||||
|
import pytest
|
||||||
import requests
|
import requests
|
||||||
|
from parameterized import parameterized
|
||||||
|
|
||||||
from transformers import (
|
from transformers import (
|
||||||
AutoProcessor,
|
AutoProcessor,
|
||||||
@@ -365,6 +367,12 @@ class MllamaForConditionalGenerationModelTest(ModelTesterMixin, GenerationTester
|
|||||||
def test_model_parallelism(self):
|
def test_model_parallelism(self):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
@parameterized.expand([("offloaded",)])
|
||||||
|
@pytest.mark.generate
|
||||||
|
@unittest.skip(reason="Offloaded cache seems to not work with mllama's kv cache type")
|
||||||
|
def test_offloaded_cache_implementation(self, cache_implementation):
|
||||||
|
pass
|
||||||
|
|
||||||
def test_generate_text_only_with_cache(self):
|
def test_generate_text_only_with_cache(self):
|
||||||
"""
|
"""
|
||||||
Tests that our cached generation with text-only inputs works. When mllama was introduced, this feature
|
Tests that our cached generation with text-only inputs works. When mllama was introduced, this feature
|
||||||
|
|||||||
@@ -567,6 +567,12 @@ class WhisperModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMi
|
|||||||
def test_generate_with_head_masking(self):
|
def test_generate_with_head_masking(self):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
@parameterized.expand([("offloaded",)])
|
||||||
|
@pytest.mark.generate
|
||||||
|
@unittest.skip(reason="Whisper doesnt work with offloaded cache implementation yet")
|
||||||
|
def test_offloaded_cache_implementation(self, cache_implementation):
|
||||||
|
pass
|
||||||
|
|
||||||
@require_torch_fp16
|
@require_torch_fp16
|
||||||
def test_generate_fp16(self):
|
def test_generate_fp16(self):
|
||||||
config, input_dict = self.model_tester.prepare_config_and_inputs()
|
config, input_dict = self.model_tester.prepare_config_and_inputs()
|
||||||
|
|||||||
Reference in New Issue
Block a user