Cache: Static cache as a standalone object (#30476)
This commit is contained in:
@@ -362,3 +362,4 @@ A [`Constraint`] can be used to force the generation to include specific tokens
|
|||||||
[[autodoc]] StaticCache
|
[[autodoc]] StaticCache
|
||||||
- update
|
- update
|
||||||
- get_seq_length
|
- get_seq_length
|
||||||
|
- reorder_cache
|
||||||
|
|||||||
@@ -65,13 +65,12 @@ tokenizer.batch_decode(outputs, skip_special_tokens=True)
|
|||||||
['The theory of special relativity states 1. The speed of light is constant in all inertial reference']
|
['The theory of special relativity states 1. The speed of light is constant in all inertial reference']
|
||||||
```
|
```
|
||||||
|
|
||||||
|
Under the hood, `generate` will attempt to reuse the same cache object, removing the need for re-compilation at each call. However, if the batch size or the maximum output length increase between calls, the cache will have to be reinitialized, triggering a new compilation.
|
||||||
|
|
||||||
</hfoption>
|
</hfoption>
|
||||||
<hfoption id="setup_cache">
|
<hfoption id="Static Cache">
|
||||||
|
|
||||||
> [!WARNING]
|
A [`StaticCache`] object can be passed to the model's forward pass under the `past_key_values` argument, enabling the use of this object as a static kv-cache. Using this strategy, you can write your own function to decode the next token given the current token and position and cache position of previously generated tokens. You can also pass the [`StaticCache`] object to [`~GenerationMixin.generate`] and use it across calls, like you would do with a dynamic cache.
|
||||||
> The `_setup_cache` method is an internal and private method that is still under development. This means it may not be backward compatible and the API design may change in the future.
|
|
||||||
|
|
||||||
The `_setup_cache` method doesn't support [`~GenerationMixin.generate`] yet, so this method is a bit more involved. You'll need to write your own function to decode the next token given the current token and position and cache position of previously generated tokens.
|
|
||||||
|
|
||||||
```py
|
```py
|
||||||
from transformers import LlamaTokenizer, LlamaForCausalLM, StaticCache, logging
|
from transformers import LlamaTokenizer, LlamaForCausalLM, StaticCache, logging
|
||||||
@@ -90,17 +89,22 @@ tokenizer = LlamaTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf", pad_token
|
|||||||
model = LlamaForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf", device_map="sequential")
|
model = LlamaForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf", device_map="sequential")
|
||||||
inputs = tokenizer(prompts, return_tensors="pt", padding=True).to(model.device)
|
inputs = tokenizer(prompts, return_tensors="pt", padding=True).to(model.device)
|
||||||
|
|
||||||
def decode_one_tokens(model, cur_token, input_pos, cache_position):
|
def decode_one_tokens(model, cur_token, input_pos, cache_position, past_key_values):
|
||||||
logits = model(
|
logits = model(
|
||||||
cur_token, position_ids=input_pos, cache_position=cache_position, return_dict=False, use_cache=True
|
cur_token,
|
||||||
|
position_ids=input_pos,
|
||||||
|
cache_position=cache_position,
|
||||||
|
past_key_values=past_key_values,
|
||||||
|
return_dict=False,
|
||||||
|
use_cache=True
|
||||||
)[0]
|
)[0]
|
||||||
new_token = torch.argmax(logits[:, -1], dim=-1)[:, None]
|
new_token = torch.argmax(logits[:, -1], dim=-1)[:, None]
|
||||||
return new_token
|
return new_token
|
||||||
```
|
```
|
||||||
|
|
||||||
There are a few important things you must do to enable static kv-cache and torch.compile with the `_setup_cache` method:
|
There are a few important things you must do to enable static kv-cache and torch.compile with the `StaticCache` method:
|
||||||
|
|
||||||
1. Access the model's `_setup_cache` method and pass it the [`StaticCache`] class. This is a more flexible method because it allows you to configure parameters like the maximum batch size and sequence length.
|
1. Initialize the [`StaticCache`] instance before using the model for inference. There you can configure parameters like the maximum batch size and sequence length.
|
||||||
|
|
||||||
2. Call torch.compile on the model to compile the forward pass with the static kv-cache.
|
2. Call torch.compile on the model to compile the forward pass with the static kv-cache.
|
||||||
|
|
||||||
@@ -109,24 +113,28 @@ There are a few important things you must do to enable static kv-cache and torch
|
|||||||
```py
|
```py
|
||||||
batch_size, seq_length = inputs["input_ids"].shape
|
batch_size, seq_length = inputs["input_ids"].shape
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
model._setup_cache(StaticCache, 2, max_cache_len=4096)
|
past_key_values = StaticCache(
|
||||||
cache_position = torch.arange(seq_length, device=torch_device)
|
config=model.config, max_batch_size=2, max_cache_len=4096, device=torch_device, dtype=model.dtype
|
||||||
generated_ids = torch.zeros(
|
)
|
||||||
batch_size, seq_length + NUM_TOKENS_TO_GENERATE + 1, dtype=torch.int, device=torch_device
|
cache_position = torch.arange(seq_length, device=torch_device)
|
||||||
)
|
generated_ids = torch.zeros(
|
||||||
generated_ids[:, cache_position] = inputs["input_ids"].to(torch_device).to(torch.int)
|
batch_size, seq_length + NUM_TOKENS_TO_GENERATE + 1, dtype=torch.int, device=torch_device
|
||||||
|
)
|
||||||
|
generated_ids[:, cache_position] = inputs["input_ids"].to(torch_device).to(torch.int)
|
||||||
|
|
||||||
logits = model(**inputs, cache_position=cache_position, return_dict=False, use_cache=True)[0]
|
logits = model(
|
||||||
next_token = torch.argmax(logits[:, -1], dim=-1)[:, None]
|
**inputs, cache_position=cache_position, past_key_values=past_key_values,return_dict=False, use_cache=True
|
||||||
generated_ids[:, seq_length] = next_token[:, 0]
|
)[0]
|
||||||
|
next_token = torch.argmax(logits[:, -1], dim=-1)[:, None]
|
||||||
|
generated_ids[:, seq_length] = next_token[:, 0]
|
||||||
|
|
||||||
decode_one_tokens = torch.compile(decode_one_tokens, mode="reduce-overhead", fullgraph=True)
|
decode_one_tokens = torch.compile(decode_one_tokens, mode="reduce-overhead", fullgraph=True)
|
||||||
cache_position = torch.tensor([seq_length + 1], device=torch_device)
|
cache_position = torch.tensor([seq_length + 1], device=torch_device)
|
||||||
for _ in range(1, NUM_TOKENS_TO_GENERATE):
|
for _ in range(1, NUM_TOKENS_TO_GENERATE):
|
||||||
with torch.backends.cuda.sdp_kernel(enable_flash=False, enable_mem_efficient=False, enable_math=True):
|
with torch.backends.cuda.sdp_kernel(enable_flash=False, enable_mem_efficient=False, enable_math=True):
|
||||||
next_token = decode_one_tokens(model, next_token.clone(), None, cache_position)
|
next_token = decode_one_tokens(model, next_token.clone(), None, cache_position, past_key_values)
|
||||||
generated_ids[:, cache_position] = next_token.int()
|
generated_ids[:, cache_position] = next_token.int()
|
||||||
cache_position += 1
|
cache_position += 1
|
||||||
|
|
||||||
text = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
|
text = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
|
||||||
text
|
text
|
||||||
@@ -134,6 +142,9 @@ text
|
|||||||
'My favorite all time favorite condiment is ketchup. I love it on everything. I love it on my eggs, my fries, my chicken, my burgers, my hot dogs, my sandwiches, my salads, my p']
|
'My favorite all time favorite condiment is ketchup. I love it on everything. I love it on my eggs, my fries, my chicken, my burgers, my hot dogs, my sandwiches, my salads, my p']
|
||||||
```
|
```
|
||||||
|
|
||||||
|
> [!TIP]
|
||||||
|
> If you want to reuse the [`StaticCache`] object on a new prompt, be sure to reset its contents with the `.reset()` method
|
||||||
|
|
||||||
</hfoption>
|
</hfoption>
|
||||||
</hfoptions>
|
</hfoptions>
|
||||||
|
|
||||||
|
|||||||
@@ -44,6 +44,7 @@ class Cache:
|
|||||||
|
|
||||||
def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
|
def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
|
||||||
"""Returns the sequence length of the cached states. A layer index can be optionally passed."""
|
"""Returns the sequence length of the cached states. A layer index can be optionally passed."""
|
||||||
|
# TODO: deprecate this function in favor of `cache_position`
|
||||||
raise NotImplementedError("Make sure to implement `get_seq_length` in a subclass.")
|
raise NotImplementedError("Make sure to implement `get_seq_length` in a subclass.")
|
||||||
|
|
||||||
def get_max_length(self) -> Optional[int]:
|
def get_max_length(self) -> Optional[int]:
|
||||||
@@ -61,6 +62,14 @@ class Cache:
|
|||||||
return max_length - new_seq_length
|
return max_length - new_seq_length
|
||||||
return previous_seq_length
|
return previous_seq_length
|
||||||
|
|
||||||
|
def reorder_cache(self, beam_idx: torch.LongTensor):
|
||||||
|
"""Reorders the cache for beam search, given the selected beam indices."""
|
||||||
|
for layer_idx in range(len(self.key_cache)):
|
||||||
|
device = self.key_cache[layer_idx].device
|
||||||
|
self.key_cache[layer_idx] = self.key_cache[layer_idx].index_select(0, beam_idx.to(device))
|
||||||
|
device = self.value_cache[layer_idx].device
|
||||||
|
self.value_cache[layer_idx] = self.value_cache[layer_idx].index_select(0, beam_idx.to(device))
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def seen_tokens(self):
|
def seen_tokens(self):
|
||||||
logger.warning_once(
|
logger.warning_once(
|
||||||
@@ -150,6 +159,7 @@ class DynamicCache(Cache):
|
|||||||
|
|
||||||
def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
|
def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
|
||||||
"""Returns the sequence length of the cached states. A layer index can be optionally passed."""
|
"""Returns the sequence length of the cached states. A layer index can be optionally passed."""
|
||||||
|
# TODO: deprecate this function in favor of `cache_position`
|
||||||
if len(self.key_cache) <= layer_idx:
|
if len(self.key_cache) <= layer_idx:
|
||||||
return 0
|
return 0
|
||||||
return self.key_cache[layer_idx].shape[-2]
|
return self.key_cache[layer_idx].shape[-2]
|
||||||
@@ -158,14 +168,6 @@ class DynamicCache(Cache):
|
|||||||
"""Returns the maximum sequence length of the cached states. DynamicCache does not have a maximum length."""
|
"""Returns the maximum sequence length of the cached states. DynamicCache does not have a maximum length."""
|
||||||
return None
|
return None
|
||||||
|
|
||||||
def reorder_cache(self, beam_idx: torch.LongTensor):
|
|
||||||
"""Reorders the cache for beam search, given the selected beam indices."""
|
|
||||||
for layer_idx in range(len(self.key_cache)):
|
|
||||||
device = self.key_cache[layer_idx].device
|
|
||||||
self.key_cache[layer_idx] = self.key_cache[layer_idx].index_select(0, beam_idx.to(device))
|
|
||||||
device = self.value_cache[layer_idx].device
|
|
||||||
self.value_cache[layer_idx] = self.value_cache[layer_idx].index_select(0, beam_idx.to(device))
|
|
||||||
|
|
||||||
def to_legacy_cache(self) -> Tuple[Tuple[torch.Tensor], Tuple[torch.Tensor]]:
|
def to_legacy_cache(self) -> Tuple[Tuple[torch.Tensor], Tuple[torch.Tensor]]:
|
||||||
"""Converts the `DynamicCache` instance into the its equivalent in the legacy cache format."""
|
"""Converts the `DynamicCache` instance into the its equivalent in the legacy cache format."""
|
||||||
legacy_cache = ()
|
legacy_cache = ()
|
||||||
@@ -244,6 +246,7 @@ class SinkCache(Cache):
|
|||||||
|
|
||||||
def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
|
def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
|
||||||
"""Returns the sequence length of the cached states. A layer index can be optionally passed."""
|
"""Returns the sequence length of the cached states. A layer index can be optionally passed."""
|
||||||
|
# TODO: deprecate this function in favor of `cache_position`
|
||||||
# Workaround to make 'key_states.shape[-2] + past_key_value.get_seq_length(self.layer_idx)' <= window_length
|
# Workaround to make 'key_states.shape[-2] + past_key_value.get_seq_length(self.layer_idx)' <= window_length
|
||||||
if len(self.key_cache) <= layer_idx:
|
if len(self.key_cache) <= layer_idx:
|
||||||
return 0
|
return 0
|
||||||
@@ -332,14 +335,6 @@ class SinkCache(Cache):
|
|||||||
|
|
||||||
return self.key_cache[layer_idx], self.value_cache[layer_idx]
|
return self.key_cache[layer_idx], self.value_cache[layer_idx]
|
||||||
|
|
||||||
def reorder_cache(self, beam_idx: torch.LongTensor):
|
|
||||||
"""Reorders the cache for beam search, given the selected beam indices."""
|
|
||||||
for layer_idx in range(len(self.key_cache)):
|
|
||||||
device = self.key_cache[layer_idx].device
|
|
||||||
self.key_cache[layer_idx] = self.key_cache[layer_idx].index_select(0, beam_idx.to(device))
|
|
||||||
device = self.value_cache[layer_idx].device
|
|
||||||
self.value_cache[layer_idx] = self.value_cache[layer_idx].index_select(0, beam_idx.to(device))
|
|
||||||
|
|
||||||
|
|
||||||
class StaticCache(Cache):
|
class StaticCache(Cache):
|
||||||
"""
|
"""
|
||||||
@@ -347,8 +342,7 @@ class StaticCache(Cache):
|
|||||||
|
|
||||||
Parameters:
|
Parameters:
|
||||||
config (`PretrainedConfig):
|
config (`PretrainedConfig):
|
||||||
The configuration file defining the `max_position_embeddings`, `hidden_size` and `num_attention_heads`
|
The configuration file defining the shape-related attributes required to initialize the static cache.
|
||||||
required to initialize the static cache.
|
|
||||||
max_batch_size (`int`):
|
max_batch_size (`int`):
|
||||||
The maximum batch size with which the model will be used.
|
The maximum batch size with which the model will be used.
|
||||||
max_cache_len (`int`):
|
max_cache_len (`int`):
|
||||||
@@ -373,9 +367,18 @@ class StaticCache(Cache):
|
|||||||
config.num_attention_heads if config.num_key_value_heads is None else config.num_key_value_heads
|
config.num_attention_heads if config.num_key_value_heads is None else config.num_key_value_heads
|
||||||
)
|
)
|
||||||
|
|
||||||
|
self.key_cache: List[torch.Tensor] = []
|
||||||
|
self.value_cache: List[torch.Tensor] = []
|
||||||
cache_shape = (max_batch_size, self.num_key_value_heads, self.max_cache_len, self.head_dim)
|
cache_shape = (max_batch_size, self.num_key_value_heads, self.max_cache_len, self.head_dim)
|
||||||
self.key_cache: torch.Tensor = torch.zeros(cache_shape, dtype=self.dtype, device=device)
|
for _ in range(config.num_hidden_layers):
|
||||||
self.value_cache: torch.Tensor = torch.zeros(cache_shape, dtype=self.dtype, device=device)
|
# Note: `mark_static_address` is used to tag the cache as an fixed data pointer, preventing cuda graph
|
||||||
|
# breaks when updating the cache.
|
||||||
|
new_layer_key_cache = torch.zeros(cache_shape, dtype=self.dtype, device=device)
|
||||||
|
new_layer_value_cache = torch.zeros(cache_shape, dtype=self.dtype, device=device)
|
||||||
|
torch._dynamo.mark_static_address(new_layer_key_cache)
|
||||||
|
torch._dynamo.mark_static_address(new_layer_value_cache)
|
||||||
|
self.key_cache.append(new_layer_key_cache)
|
||||||
|
self.value_cache.append(new_layer_value_cache)
|
||||||
|
|
||||||
def update(
|
def update(
|
||||||
self,
|
self,
|
||||||
@@ -394,42 +397,37 @@ class StaticCache(Cache):
|
|||||||
value_states (`torch.Tensor`):
|
value_states (`torch.Tensor`):
|
||||||
The new value states to cache.
|
The new value states to cache.
|
||||||
layer_idx (`int`):
|
layer_idx (`int`):
|
||||||
The index of the layer to cache the states for. Kept for backward compatibility
|
The index of the layer to cache the states for.
|
||||||
cache_kwargs (`Dict[str, Any]`, `optional`):
|
cache_kwargs (`Dict[str, Any]`, `optional`):
|
||||||
Additional arguments for the cache subclass. The `StaticCache` just needs the `q_len`
|
Additional arguments for the cache subclass. The `StaticCache` needs the `cache_position` input
|
||||||
to know how much of the cache it should overwrite.
|
to know how where to write in the cache.
|
||||||
|
|
||||||
Return:
|
Return:
|
||||||
A tuple containing the updated key and value states.
|
A tuple containing the updated key and value states.
|
||||||
"""
|
"""
|
||||||
new_cache_positions = cache_kwargs.get("cache_position")
|
cache_position = cache_kwargs.get("cache_position")
|
||||||
k_out = self.key_cache
|
k_out = self.key_cache[layer_idx]
|
||||||
v_out = self.value_cache
|
v_out = self.value_cache[layer_idx]
|
||||||
|
|
||||||
k_out[:, :, new_cache_positions] = key_states
|
k_out[:, :, cache_position] = key_states
|
||||||
v_out[:, :, new_cache_positions] = value_states
|
v_out[:, :, cache_position] = value_states
|
||||||
|
|
||||||
return k_out, v_out
|
return k_out, v_out
|
||||||
|
|
||||||
def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
|
def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
|
||||||
"""Returns the sequence length of the cached states that were seen by the model. `layer_idx` kept for BC"""
|
"""Returns the sequence length of the cached states that were seen by the model."""
|
||||||
# Occupied cache == any slot in the 3rd dim (sequence length) holds a non-zero value. To save on compute, let's
|
# Occupied cache == any slot in the 3rd dim (sequence length) holds a non-zero value. To save on compute, let's
|
||||||
# limit the check to the first batch member and head dimension.
|
# limit the check to the first batch member and head dimension.
|
||||||
# TODO: This is error prone, a filled cache may be `0.0`. Let's use a stateless integer instead, after
|
# TODO: deprecate this function in favor of `cache_position`
|
||||||
# https://github.com/pytorch/pytorch/issues/120248 is fixed
|
return (self.key_cache[layer_idx][0, 0].any(dim=-1)).sum()
|
||||||
return (self.key_cache[0, 0].any(dim=-1)).sum()
|
|
||||||
|
|
||||||
def get_max_length(self) -> Optional[int]:
|
def get_max_length(self) -> Optional[int]:
|
||||||
"""Returns the maximum sequence length of the cached states. DynamicCache does not have a maximum length."""
|
"""Returns the maximum sequence length of the cached states."""
|
||||||
return self.max_cache_len
|
return self.max_cache_len
|
||||||
|
|
||||||
def reorder_cache(self, beam_idx: torch.LongTensor):
|
def reset(self):
|
||||||
"""Reorders the cache for beam search, given the selected beam indices."""
|
"""Resets the cache values while preserving the objects"""
|
||||||
device = self.key_cache.device
|
for layer_idx in range(len(self.key_cache)):
|
||||||
self.key_cache = self.key_cache.index_select(0, beam_idx.to(device))
|
# In-place ops prevent breaking the static address
|
||||||
device = self.value_cache.device
|
self.key_cache[layer_idx].zero_()
|
||||||
self.value_cache = self.value_cache.index_select(0, beam_idx.to(device))
|
self.value_cache[layer_idx].zero_()
|
||||||
|
|
||||||
def to_legacy_cache(self):
|
|
||||||
"""Dummy function for BC. We have to keep it because otherwise the call in the forward of models will break it"""
|
|
||||||
return None
|
|
||||||
|
|||||||
@@ -1310,6 +1310,34 @@ class GenerationMixin:
|
|||||||
model_kwargs["cache_position"] = torch.arange(past_length, cur_len, device=input_ids.device)
|
model_kwargs["cache_position"] = torch.arange(past_length, cur_len, device=input_ids.device)
|
||||||
return model_kwargs
|
return model_kwargs
|
||||||
|
|
||||||
|
def _get_static_cache(self, max_batch_size: int, max_cache_len: int) -> StaticCache:
|
||||||
|
"""
|
||||||
|
Sets a static cache for `generate`, that will persist across calls. A new cache will only be initialized a
|
||||||
|
new `generate` call requires a larger cache.
|
||||||
|
|
||||||
|
Returns the resulting static cache object.
|
||||||
|
"""
|
||||||
|
needs_new_cache = (
|
||||||
|
not hasattr(self, "_static_cache")
|
||||||
|
or self._static_cache.max_batch_size < max_batch_size
|
||||||
|
or self._static_cache.max_cache_len < max_cache_len
|
||||||
|
)
|
||||||
|
if needs_new_cache:
|
||||||
|
if hasattr(self.config, "_pre_quantization_dtype"):
|
||||||
|
cache_dtype = self.config._pre_quantization_dtype
|
||||||
|
else:
|
||||||
|
cache_dtype = self.dtype
|
||||||
|
self._static_cache = StaticCache(
|
||||||
|
config=self.config,
|
||||||
|
max_batch_size=max_batch_size,
|
||||||
|
max_cache_len=max_cache_len,
|
||||||
|
device=self.device,
|
||||||
|
dtype=cache_dtype,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
self._static_cache.reset() # reset the cache for a new generation
|
||||||
|
return self._static_cache
|
||||||
|
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
def generate(
|
def generate(
|
||||||
self,
|
self,
|
||||||
@@ -1514,19 +1542,19 @@ class GenerationMixin:
|
|||||||
input_ids_length=input_ids_length,
|
input_ids_length=input_ids_length,
|
||||||
)
|
)
|
||||||
|
|
||||||
if generation_config.cache_implementation in NEED_SETUP_CACHE_CLASSES_MAPPING:
|
if generation_config.cache_implementation is not None and model_kwargs.get("past_key_values") is not None:
|
||||||
|
raise ValueError(
|
||||||
|
"Passing both `cache_implementation` (used to initialize certain caches) and `past_key_values` (a "
|
||||||
|
"Cache object) is unsupported. Please use only one of the two."
|
||||||
|
)
|
||||||
|
elif generation_config.cache_implementation in NEED_SETUP_CACHE_CLASSES_MAPPING:
|
||||||
|
if not self._supports_cache_class:
|
||||||
|
raise ValueError(
|
||||||
|
"This model does not support the `cache_implementation` argument. Please check the following "
|
||||||
|
"issue: https://github.com/huggingface/transformers/issues/28981."
|
||||||
|
)
|
||||||
if generation_config.cache_implementation == "static":
|
if generation_config.cache_implementation == "static":
|
||||||
if model_kwargs.get("past_key_values", False) is not False:
|
model_kwargs["past_key_values"] = self._get_static_cache(batch_size, generation_config.max_length)
|
||||||
raise ValueError(
|
|
||||||
"Using `past_key_values` argument with `generate()` when using a static KV cache is not supported. Please open an issue in Transformers GitHub repository."
|
|
||||||
)
|
|
||||||
cache_cls = NEED_SETUP_CACHE_CLASSES_MAPPING["static"]
|
|
||||||
if not callable(getattr(self, "_setup_cache", None)):
|
|
||||||
raise ValueError(
|
|
||||||
"The `generation_config` defines a `cache_implementation` that is not compatible with this model."
|
|
||||||
" Make sure it has a `_setup_cache` function."
|
|
||||||
)
|
|
||||||
self._setup_cache(cache_cls, max_batch_size=batch_size, max_cache_len=generation_config.max_length)
|
|
||||||
|
|
||||||
self._validate_generated_length(generation_config, input_ids_length, has_default_max_length)
|
self._validate_generated_length(generation_config, input_ids_length, has_default_max_length)
|
||||||
|
|
||||||
@@ -1844,14 +1872,6 @@ class GenerationMixin:
|
|||||||
**model_kwargs,
|
**model_kwargs,
|
||||||
)
|
)
|
||||||
|
|
||||||
if generation_config.cache_implementation in NEED_SETUP_CACHE_CLASSES_MAPPING:
|
|
||||||
if not callable(getattr(self, "_reset_cache", None)):
|
|
||||||
raise ValueError(
|
|
||||||
"A `static_cache` was used to generate but there was a failure when trying to release the cache. "
|
|
||||||
" Make sure this model implements a `_reset_cache` function."
|
|
||||||
)
|
|
||||||
self._reset_cache()
|
|
||||||
|
|
||||||
return result
|
return result
|
||||||
|
|
||||||
def _has_unfinished_sequences(self, this_peer_finished: bool, synced_gpus: bool, device: torch.device) -> bool:
|
def _has_unfinished_sequences(self, this_peer_finished: bool, synced_gpus: bool, device: torch.device) -> bool:
|
||||||
|
|||||||
@@ -340,6 +340,11 @@ class CohereFlashAttention2(CohereAttention):
|
|||||||
cache_position: Optional[torch.LongTensor] = None,
|
cache_position: Optional[torch.LongTensor] = None,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
||||||
|
if isinstance(past_key_value, StaticCache):
|
||||||
|
raise ValueError(
|
||||||
|
"`static` cache implementation is not compatible with `attn_implementation==flash_attention_2` "
|
||||||
|
"make sure to use `sdpa` in the mean time, and open an issue at https://github.com/huggingface/transformers"
|
||||||
|
)
|
||||||
output_attentions = False
|
output_attentions = False
|
||||||
|
|
||||||
bsz, q_len, _ = hidden_states.size()
|
bsz, q_len, _ = hidden_states.size()
|
||||||
@@ -734,27 +739,6 @@ class CoherePreTrainedModel(PreTrainedModel):
|
|||||||
if module.padding_idx is not None:
|
if module.padding_idx is not None:
|
||||||
module.weight.data[module.padding_idx].zero_()
|
module.weight.data[module.padding_idx].zero_()
|
||||||
|
|
||||||
def _setup_cache(self, cache_cls, max_batch_size, max_cache_len: Optional[int] = None):
|
|
||||||
if self.config._attn_implementation == "flash_attention_2" and cache_cls == StaticCache:
|
|
||||||
raise ValueError(
|
|
||||||
"`static` cache implementation is not compatible with `attn_implementation==flash_attention_2` "
|
|
||||||
"make sure to use `sdpa` in the mean time, and open an issue at https://github.com/huggingface/transformers"
|
|
||||||
)
|
|
||||||
|
|
||||||
for layer in self.model.layers:
|
|
||||||
device = layer.input_layernorm.weight.device
|
|
||||||
if hasattr(self.config, "_pre_quantization_dtype"):
|
|
||||||
dtype = self.config._pre_quantization_dtype
|
|
||||||
else:
|
|
||||||
dtype = layer.self_attn.o_proj.weight.dtype
|
|
||||||
layer.self_attn.past_key_value = cache_cls(
|
|
||||||
self.config, max_batch_size, max_cache_len, device=device, dtype=dtype
|
|
||||||
)
|
|
||||||
|
|
||||||
def _reset_cache(self):
|
|
||||||
for layer in self.model.layers:
|
|
||||||
layer.self_attn.past_key_value = None
|
|
||||||
|
|
||||||
|
|
||||||
COHERE_INPUTS_DOCSTRING = r"""
|
COHERE_INPUTS_DOCSTRING = r"""
|
||||||
Args:
|
Args:
|
||||||
@@ -898,14 +882,11 @@ class CohereModel(CoherePreTrainedModel):
|
|||||||
inputs_embeds = self.embed_tokens(input_ids)
|
inputs_embeds = self.embed_tokens(input_ids)
|
||||||
|
|
||||||
past_seen_tokens = 0
|
past_seen_tokens = 0
|
||||||
if use_cache: # kept for BC (cache positions)
|
if use_cache and not isinstance(past_key_values, Cache): # kept for BC (non `Cache` `past_key_values` inputs)
|
||||||
if not isinstance(past_key_values, StaticCache):
|
past_key_values = DynamicCache.from_legacy_cache(past_key_values)
|
||||||
past_key_values = DynamicCache.from_legacy_cache(past_key_values)
|
|
||||||
past_seen_tokens = past_key_values.get_seq_length()
|
|
||||||
|
|
||||||
if cache_position is None:
|
if cache_position is None:
|
||||||
if isinstance(past_key_values, StaticCache):
|
past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
|
||||||
raise ValueError("cache_position is a required argument when using StaticCache.")
|
|
||||||
cache_position = torch.arange(
|
cache_position = torch.arange(
|
||||||
past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
|
past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
|
||||||
)
|
)
|
||||||
@@ -913,7 +894,7 @@ class CohereModel(CoherePreTrainedModel):
|
|||||||
if position_ids is None:
|
if position_ids is None:
|
||||||
position_ids = cache_position.unsqueeze(0)
|
position_ids = cache_position.unsqueeze(0)
|
||||||
|
|
||||||
causal_mask = self._update_causal_mask(attention_mask, inputs_embeds, cache_position, past_seen_tokens)
|
causal_mask = self._update_causal_mask(attention_mask, inputs_embeds, cache_position, past_key_values)
|
||||||
|
|
||||||
# embed positions
|
# embed positions
|
||||||
hidden_states = inputs_embeds
|
hidden_states = inputs_embeds
|
||||||
@@ -982,7 +963,7 @@ class CohereModel(CoherePreTrainedModel):
|
|||||||
attention_mask: torch.Tensor,
|
attention_mask: torch.Tensor,
|
||||||
input_tensor: torch.Tensor,
|
input_tensor: torch.Tensor,
|
||||||
cache_position: torch.Tensor,
|
cache_position: torch.Tensor,
|
||||||
past_seen_tokens: int,
|
past_key_values: Cache,
|
||||||
):
|
):
|
||||||
# TODO: As of torch==2.2.0, the `attention_mask` passed to the model in `generate` is 2D and of dynamic length even when the static
|
# TODO: As of torch==2.2.0, the `attention_mask` passed to the model in `generate` is 2D and of dynamic length even when the static
|
||||||
# KV cache is used. This is an issue for torch.compile which then recaptures cudagraphs at each decode steps due to the dynamic shapes.
|
# KV cache is used. This is an issue for torch.compile which then recaptures cudagraphs at each decode steps due to the dynamic shapes.
|
||||||
@@ -994,9 +975,12 @@ class CohereModel(CoherePreTrainedModel):
|
|||||||
return attention_mask
|
return attention_mask
|
||||||
return None
|
return None
|
||||||
|
|
||||||
if self.config._attn_implementation == "sdpa":
|
# For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in
|
||||||
# For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument,
|
# order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail
|
||||||
# in order to dispatch on Flash Attention 2.
|
# to infer the attention mask.
|
||||||
|
past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
|
||||||
|
using_static_cache = isinstance(past_key_values, StaticCache)
|
||||||
|
if self.config._attn_implementation == "sdpa" and not using_static_cache:
|
||||||
if AttentionMaskConverter._ignore_causal_mask_sdpa(
|
if AttentionMaskConverter._ignore_causal_mask_sdpa(
|
||||||
attention_mask,
|
attention_mask,
|
||||||
inputs_embeds=input_tensor,
|
inputs_embeds=input_tensor,
|
||||||
@@ -1008,9 +992,9 @@ class CohereModel(CoherePreTrainedModel):
|
|||||||
dtype, device = input_tensor.dtype, input_tensor.device
|
dtype, device = input_tensor.dtype, input_tensor.device
|
||||||
min_dtype = torch.finfo(dtype).min
|
min_dtype = torch.finfo(dtype).min
|
||||||
sequence_length = input_tensor.shape[1]
|
sequence_length = input_tensor.shape[1]
|
||||||
if hasattr(getattr(self.layers[0], "self_attn", {}), "past_key_value"): # static cache
|
if using_static_cache:
|
||||||
target_length = self.config.max_position_embeddings
|
target_length = past_key_values.get_max_length()
|
||||||
else: # dynamic cache
|
else:
|
||||||
target_length = (
|
target_length = (
|
||||||
attention_mask.shape[-1]
|
attention_mask.shape[-1]
|
||||||
if isinstance(attention_mask, torch.Tensor)
|
if isinstance(attention_mask, torch.Tensor)
|
||||||
@@ -1032,6 +1016,10 @@ class CohereModel(CoherePreTrainedModel):
|
|||||||
# backwards compatibility: we allow passing a 4D attention mask shorter than the input length with
|
# backwards compatibility: we allow passing a 4D attention mask shorter than the input length with
|
||||||
# cache. In that case, the 4D attention mask attends to the newest tokens only.
|
# cache. In that case, the 4D attention mask attends to the newest tokens only.
|
||||||
if attention_mask.shape[-2] < cache_position[0] + sequence_length:
|
if attention_mask.shape[-2] < cache_position[0] + sequence_length:
|
||||||
|
logger.warning_once(
|
||||||
|
"Passing a 4d mask shorter than the input length is deprecated and will be removed in "
|
||||||
|
"transformers v4.42.0"
|
||||||
|
)
|
||||||
offset = cache_position[0]
|
offset = cache_position[0]
|
||||||
else:
|
else:
|
||||||
offset = 0
|
offset = 0
|
||||||
@@ -1189,13 +1177,6 @@ class CohereForCausalLM(CoherePreTrainedModel):
|
|||||||
use_cache=True,
|
use_cache=True,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
):
|
):
|
||||||
# With static cache, the `past_key_values` is None
|
|
||||||
# TODO joao: standardize interface for the different Cache classes and remove of this if
|
|
||||||
has_static_cache = False
|
|
||||||
if past_key_values is None:
|
|
||||||
past_key_values = getattr(getattr(self.model.layers[0], "self_attn", {}), "past_key_value", None)
|
|
||||||
has_static_cache = past_key_values is not None
|
|
||||||
|
|
||||||
past_length = 0
|
past_length = 0
|
||||||
if past_key_values is not None:
|
if past_key_values is not None:
|
||||||
if isinstance(past_key_values, Cache):
|
if isinstance(past_key_values, Cache):
|
||||||
@@ -1213,8 +1194,7 @@ class CohereForCausalLM(CoherePreTrainedModel):
|
|||||||
|
|
||||||
# Keep only the unprocessed tokens:
|
# Keep only the unprocessed tokens:
|
||||||
# 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where
|
# 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where
|
||||||
# some of the inputs are exclusively passed as part of the cache (e.g. when passing input_embeds as
|
# some of the inputs are exclusively passed as part of the cache (e.g. when passing input_embeds as input)
|
||||||
# input)
|
|
||||||
if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]:
|
if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]:
|
||||||
input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :]
|
input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :]
|
||||||
# 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard
|
# 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard
|
||||||
@@ -1254,9 +1234,6 @@ class CohereForCausalLM(CoherePreTrainedModel):
|
|||||||
elif use_cache:
|
elif use_cache:
|
||||||
cache_position = cache_position[-input_length:]
|
cache_position = cache_position[-input_length:]
|
||||||
|
|
||||||
if has_static_cache:
|
|
||||||
past_key_values = None
|
|
||||||
|
|
||||||
model_inputs.update(
|
model_inputs.update(
|
||||||
{
|
{
|
||||||
"position_ids": position_ids,
|
"position_ids": position_ids,
|
||||||
|
|||||||
@@ -15,7 +15,7 @@
|
|||||||
""" PyTorch DBRX model. """
|
""" PyTorch DBRX model. """
|
||||||
|
|
||||||
import math
|
import math
|
||||||
from typing import Any, Dict, Optional, Tuple, Union
|
from typing import Any, Optional, Tuple, Union
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
@@ -354,6 +354,11 @@ class DbrxFlashAttention2(DbrxAttention):
|
|||||||
cache_position: Optional[torch.LongTensor] = None,
|
cache_position: Optional[torch.LongTensor] = None,
|
||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
||||||
|
if isinstance(past_key_value, StaticCache):
|
||||||
|
raise ValueError(
|
||||||
|
"`static` cache implementation is not compatible with `attn_implementation==flash_attention_2` "
|
||||||
|
"make sure to use `sdpa` in the mean time, and open an issue at https://github.com/huggingface/transformers"
|
||||||
|
)
|
||||||
logger.info("Implicitly setting `output_attentions` to False as it is not supported in Flash Attention.")
|
logger.info("Implicitly setting `output_attentions` to False as it is not supported in Flash Attention.")
|
||||||
output_attentions = False
|
output_attentions = False
|
||||||
|
|
||||||
@@ -622,6 +627,7 @@ class DbrxSdpaAttention(DbrxAttention):
|
|||||||
value_states,
|
value_states,
|
||||||
attn_mask=causal_mask,
|
attn_mask=causal_mask,
|
||||||
dropout_p=self.attn_pdrop if self.training else 0.0,
|
dropout_p=self.attn_pdrop if self.training else 0.0,
|
||||||
|
is_causal=causal_mask is None and q_len > 1,
|
||||||
)
|
)
|
||||||
|
|
||||||
attn_output = attn_output.transpose(1, 2).contiguous()
|
attn_output = attn_output.transpose(1, 2).contiguous()
|
||||||
@@ -957,28 +963,6 @@ class DbrxPreTrainedModel(PreTrainedModel):
|
|||||||
module.v1.data.normal_(mean=0.0, std=std)
|
module.v1.data.normal_(mean=0.0, std=std)
|
||||||
module.w2.data.normal_(mean=0.0, std=std)
|
module.w2.data.normal_(mean=0.0, std=std)
|
||||||
|
|
||||||
def _setup_cache(self, cache_cls: Any, max_batch_size: int, max_cache_len: int):
|
|
||||||
if self.config._attn_implementation == "flash_attention_2" and cache_cls == StaticCache:
|
|
||||||
raise ValueError(
|
|
||||||
"`static` cache implementation is not compatible with "
|
|
||||||
+ "`attn_implementation==flash_attention_2`. Make sure to use "
|
|
||||||
+ "`spda` in the mean time and open an issue at https://github.com/huggingface/transformers."
|
|
||||||
)
|
|
||||||
|
|
||||||
for block in self.transformer.blocks:
|
|
||||||
device = block.norm_attn_norm.norm_1.weight.device
|
|
||||||
if hasattr(self.config, "_pre_quantization_dtype"):
|
|
||||||
dtype = self.config._pre_quantization_dtype
|
|
||||||
else:
|
|
||||||
dtype = block.norm_attn_norm.attn.out_proj.weight.dtype
|
|
||||||
block.norm_attn_norm.attn.past_key_value = cache_cls(
|
|
||||||
self.config, max_batch_size, max_cache_len, device=device, dtype=dtype
|
|
||||||
)
|
|
||||||
|
|
||||||
def _reset_cache(self):
|
|
||||||
for block in self.transformer.blocks:
|
|
||||||
block.norm_attn_norm.attn.past_key_value = None
|
|
||||||
|
|
||||||
|
|
||||||
DBRX_INPUTS_DOCSTRING = r"""
|
DBRX_INPUTS_DOCSTRING = r"""
|
||||||
Args:
|
Args:
|
||||||
@@ -1131,22 +1115,18 @@ class DbrxModel(DbrxPreTrainedModel):
|
|||||||
|
|
||||||
inputs_embeds = nn.functional.dropout(inputs_embeds, p=self.emb_pdrop, training=self.training)
|
inputs_embeds = nn.functional.dropout(inputs_embeds, p=self.emb_pdrop, training=self.training)
|
||||||
|
|
||||||
past_seen_tokens = 0
|
if use_cache and not isinstance(past_key_values, Cache): # kept for BC (non `Cache` `past_key_values` inputs)
|
||||||
if use_cache: # kept for BC (cache positions)
|
past_key_values = DynamicCache.from_legacy_cache(past_key_values)
|
||||||
if not isinstance(past_key_values, StaticCache):
|
|
||||||
past_key_values = DynamicCache.from_legacy_cache(past_key_values)
|
|
||||||
past_seen_tokens = past_key_values.get_seq_length()
|
|
||||||
|
|
||||||
if cache_position is None:
|
if cache_position is None:
|
||||||
if isinstance(past_key_values, StaticCache):
|
past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
|
||||||
raise ValueError("cache_position is a required argument when using StaticCache.")
|
|
||||||
cache_position = torch.arange(
|
cache_position = torch.arange(
|
||||||
past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
|
past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
|
||||||
)
|
)
|
||||||
|
|
||||||
if position_ids is None:
|
if position_ids is None:
|
||||||
position_ids = cache_position.unsqueeze(0)
|
position_ids = cache_position.unsqueeze(0)
|
||||||
causal_mask = self._update_causal_mask(attention_mask, inputs_embeds, cache_position)
|
causal_mask = self._update_causal_mask(attention_mask, inputs_embeds, cache_position, past_key_values)
|
||||||
|
|
||||||
# embed positions
|
# embed positions
|
||||||
hidden_states = inputs_embeds
|
hidden_states = inputs_embeds
|
||||||
@@ -1205,7 +1185,9 @@ class DbrxModel(DbrxPreTrainedModel):
|
|||||||
next_cache = None
|
next_cache = None
|
||||||
if use_cache:
|
if use_cache:
|
||||||
next_cache = (
|
next_cache = (
|
||||||
next_decoder_cache.to_legacy_cache() if isinstance(next_decoder_cache, Cache) else next_decoder_cache
|
next_decoder_cache.to_legacy_cache()
|
||||||
|
if isinstance(next_decoder_cache, DynamicCache)
|
||||||
|
else next_decoder_cache
|
||||||
)
|
)
|
||||||
if not return_dict:
|
if not return_dict:
|
||||||
return tuple(
|
return tuple(
|
||||||
@@ -1221,28 +1203,45 @@ class DbrxModel(DbrxPreTrainedModel):
|
|||||||
router_logits=all_router_logits,
|
router_logits=all_router_logits,
|
||||||
)
|
)
|
||||||
|
|
||||||
# TODO: As of torch==2.2.0, the `attention_mask` passed to the model in `generate` is 2D and of dynamic length even when the static
|
# Copied from transformers.models.llama.modeling_llama.LlamaModel._update_causal_mask
|
||||||
# KV cache is used. This is an issue for torch.compile which then recaptures cudagraphs at each decode steps due to the dynamic shapes.
|
|
||||||
# (`recording cudagraph tree for symint key 13`, etc.), which is VERY slow. A workaround is `@torch.compiler.disable`, but this prevents using
|
|
||||||
# `fullgraph=True`. See more context in https://github.com/huggingface/transformers/pull/29114
|
|
||||||
def _update_causal_mask(
|
def _update_causal_mask(
|
||||||
self, attention_mask: Optional[torch.Tensor], input_tensor: torch.Tensor, cache_position: torch.Tensor
|
self,
|
||||||
) -> Optional[torch.Tensor]:
|
attention_mask: torch.Tensor,
|
||||||
|
input_tensor: torch.Tensor,
|
||||||
|
cache_position: torch.Tensor,
|
||||||
|
past_key_values: Cache,
|
||||||
|
):
|
||||||
|
# TODO: As of torch==2.2.0, the `attention_mask` passed to the model in `generate` is 2D and of dynamic length even when the static
|
||||||
|
# KV cache is used. This is an issue for torch.compile which then recaptures cudagraphs at each decode steps due to the dynamic shapes.
|
||||||
|
# (`recording cudagraph tree for symint key 13`, etc.), which is VERY slow. A workaround is `@torch.compiler.disable`, but this prevents using
|
||||||
|
# `fullgraph=True`. See more context in https://github.com/huggingface/transformers/pull/29114
|
||||||
if self.config._attn_implementation == "flash_attention_2":
|
if self.config._attn_implementation == "flash_attention_2":
|
||||||
if attention_mask is not None and 0.0 in attention_mask:
|
if attention_mask is not None and 0.0 in attention_mask:
|
||||||
return attention_mask
|
return attention_mask
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
# For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in
|
||||||
|
# order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail
|
||||||
|
# to infer the attention mask.
|
||||||
|
past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
|
||||||
|
using_static_cache = isinstance(past_key_values, StaticCache)
|
||||||
|
if self.config._attn_implementation == "sdpa" and not using_static_cache:
|
||||||
|
if AttentionMaskConverter._ignore_causal_mask_sdpa(
|
||||||
|
attention_mask, inputs_embeds=input_tensor, past_key_values_length=past_seen_tokens
|
||||||
|
):
|
||||||
|
return None
|
||||||
|
|
||||||
dtype, device = input_tensor.dtype, input_tensor.device
|
dtype, device = input_tensor.dtype, input_tensor.device
|
||||||
min_dtype = torch.finfo(dtype).min
|
min_dtype = torch.finfo(dtype).min
|
||||||
sequence_length = input_tensor.shape[1]
|
sequence_length = input_tensor.shape[1]
|
||||||
if hasattr(self.blocks[0].norm_attn_norm.attn, "past_key_value"): # static cache
|
if using_static_cache:
|
||||||
target_length = self.config.max_position_embeddings
|
target_length = past_key_values.get_max_length()
|
||||||
else: # dynamic cache
|
else:
|
||||||
target_length = (
|
target_length = (
|
||||||
attention_mask.shape[-1] if isinstance(attention_mask, torch.Tensor) else cache_position[-1] + 1
|
attention_mask.shape[-1]
|
||||||
|
if isinstance(attention_mask, torch.Tensor)
|
||||||
|
else past_seen_tokens + sequence_length + 1
|
||||||
)
|
)
|
||||||
target_length = int(target_length)
|
|
||||||
|
|
||||||
causal_mask = torch.full((sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device)
|
causal_mask = torch.full((sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device)
|
||||||
if sequence_length != 1:
|
if sequence_length != 1:
|
||||||
@@ -1259,6 +1258,10 @@ class DbrxModel(DbrxPreTrainedModel):
|
|||||||
# backwards compatibility: we allow passing a 4D attention mask shorter than the input length with
|
# backwards compatibility: we allow passing a 4D attention mask shorter than the input length with
|
||||||
# cache. In that case, the 4D attention mask attends to the newest tokens only.
|
# cache. In that case, the 4D attention mask attends to the newest tokens only.
|
||||||
if attention_mask.shape[-2] < cache_position[0] + sequence_length:
|
if attention_mask.shape[-2] < cache_position[0] + sequence_length:
|
||||||
|
logger.warning_once(
|
||||||
|
"Passing a 4d mask shorter than the input length is deprecated and will be removed in "
|
||||||
|
"transformers v4.42.0"
|
||||||
|
)
|
||||||
offset = cache_position[0]
|
offset = cache_position[0]
|
||||||
else:
|
else:
|
||||||
offset = 0
|
offset = 0
|
||||||
@@ -1273,17 +1276,10 @@ class DbrxModel(DbrxPreTrainedModel):
|
|||||||
and attention_mask is not None
|
and attention_mask is not None
|
||||||
and attention_mask.device.type == "cuda"
|
and attention_mask.device.type == "cuda"
|
||||||
):
|
):
|
||||||
# TODO: For dynamo, rather use a check on fullgraph=True once this is possible (https://github.com/pytorch/pytorch/pull/120400).
|
# Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
|
||||||
is_tracing = (
|
# using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
|
||||||
torch.jit.is_tracing()
|
# Details: https://github.com/pytorch/pytorch/issues/110213
|
||||||
or isinstance(input_tensor, torch.fx.Proxy)
|
causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype)
|
||||||
or (hasattr(torch, "_dynamo") and torch._dynamo.is_compiling())
|
|
||||||
)
|
|
||||||
if not is_tracing and torch.any(attention_mask != 1):
|
|
||||||
# Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
|
|
||||||
# using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
|
|
||||||
# Details: https://github.com/pytorch/pytorch/issues/110213
|
|
||||||
causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype)
|
|
||||||
|
|
||||||
return causal_mask
|
return causal_mask
|
||||||
|
|
||||||
@@ -1431,28 +1427,35 @@ class DbrxForCausalLM(DbrxPreTrainedModel):
|
|||||||
router_logits=outputs.router_logits,
|
router_logits=outputs.router_logits,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM.prepare_inputs_for_generation
|
||||||
def prepare_inputs_for_generation(
|
def prepare_inputs_for_generation(
|
||||||
self,
|
self,
|
||||||
input_ids: torch.Tensor,
|
input_ids,
|
||||||
past_key_values: Optional[Cache] = None,
|
past_key_values=None,
|
||||||
attention_mask: Optional[torch.Tensor] = None,
|
attention_mask=None,
|
||||||
inputs_embeds: Optional[torch.Tensor] = None,
|
inputs_embeds=None,
|
||||||
**kwargs: Any,
|
cache_position=None,
|
||||||
) -> Dict[str, Any]:
|
use_cache=True,
|
||||||
|
**kwargs,
|
||||||
|
):
|
||||||
past_length = 0
|
past_length = 0
|
||||||
if past_key_values is not None:
|
if past_key_values is not None:
|
||||||
if isinstance(past_key_values, Cache):
|
if isinstance(past_key_values, Cache):
|
||||||
cache_length = past_key_values.get_seq_length()
|
past_length = cache_position[0] if cache_position is not None else past_key_values.get_seq_length()
|
||||||
past_length = past_key_values.seen_tokens
|
max_cache_length = (
|
||||||
max_cache_length = past_key_values.get_max_length()
|
torch.tensor(past_key_values.get_max_length(), device=input_ids.device)
|
||||||
|
if past_key_values.get_max_length() is not None
|
||||||
|
else None
|
||||||
|
)
|
||||||
|
cache_length = past_length if max_cache_length is None else torch.min(max_cache_length, past_length)
|
||||||
|
# TODO joao: remove this `else` after `generate` prioritizes `Cache` objects
|
||||||
else:
|
else:
|
||||||
cache_length = past_length = past_key_values[0][0].shape[2]
|
cache_length = past_length = past_key_values[0][0].shape[2]
|
||||||
max_cache_length = None
|
max_cache_length = None
|
||||||
|
|
||||||
# Keep only the unprocessed tokens:
|
# Keep only the unprocessed tokens:
|
||||||
# 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where
|
# 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where
|
||||||
# some of the inputs are exclusively passed as part of the cache (e.g. when passing input_embeds as
|
# some of the inputs are exclusively passed as part of the cache (e.g. when passing input_embeds as input)
|
||||||
# input)
|
|
||||||
if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]:
|
if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]:
|
||||||
input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :]
|
input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :]
|
||||||
# 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard
|
# 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard
|
||||||
@@ -1477,22 +1480,6 @@ class DbrxForCausalLM(DbrxPreTrainedModel):
|
|||||||
if past_key_values:
|
if past_key_values:
|
||||||
position_ids = position_ids[:, -input_ids.shape[1] :]
|
position_ids = position_ids[:, -input_ids.shape[1] :]
|
||||||
|
|
||||||
if self.generation_config.cache_implementation == "static":
|
|
||||||
# generation with static cache
|
|
||||||
cache_position = kwargs.get("cache_position", None)
|
|
||||||
if cache_position is None:
|
|
||||||
past_length = 0
|
|
||||||
else:
|
|
||||||
past_length = cache_position[-1] + 1
|
|
||||||
input_ids = input_ids[:, past_length:]
|
|
||||||
position_ids = position_ids[:, past_length:] if position_ids is not None else None
|
|
||||||
|
|
||||||
# TODO @gante we should only keep a `cache_position` in generate, and do +=1.
|
|
||||||
# same goes for position ids. Could also help with continued generation.
|
|
||||||
input_length = position_ids.shape[-1] if position_ids is not None else input_ids.shape[-1]
|
|
||||||
cache_position = torch.arange(past_length, past_length + input_length, device=input_ids.device)
|
|
||||||
position_ids = position_ids.contiguous() if position_ids is not None else None
|
|
||||||
|
|
||||||
# if `inputs_embeds` are passed, we only want to use them in the 1st generation step
|
# if `inputs_embeds` are passed, we only want to use them in the 1st generation step
|
||||||
if inputs_embeds is not None and past_key_values is None:
|
if inputs_embeds is not None and past_key_values is None:
|
||||||
model_inputs = {"inputs_embeds": inputs_embeds}
|
model_inputs = {"inputs_embeds": inputs_embeds}
|
||||||
@@ -1502,12 +1489,18 @@ class DbrxForCausalLM(DbrxPreTrainedModel):
|
|||||||
# TODO: use `next_tokens` directly instead.
|
# TODO: use `next_tokens` directly instead.
|
||||||
model_inputs = {"input_ids": input_ids.contiguous()}
|
model_inputs = {"input_ids": input_ids.contiguous()}
|
||||||
|
|
||||||
|
input_length = position_ids.shape[-1] if position_ids is not None else input_ids.shape[-1]
|
||||||
|
if cache_position is None:
|
||||||
|
cache_position = torch.arange(past_length, past_length + input_length, device=input_ids.device)
|
||||||
|
elif use_cache:
|
||||||
|
cache_position = cache_position[-input_length:]
|
||||||
|
|
||||||
model_inputs.update(
|
model_inputs.update(
|
||||||
{
|
{
|
||||||
"position_ids": position_ids,
|
"position_ids": position_ids,
|
||||||
"cache_position": cache_position,
|
"cache_position": cache_position,
|
||||||
"past_key_values": past_key_values,
|
"past_key_values": past_key_values,
|
||||||
"use_cache": kwargs.get("use_cache"),
|
"use_cache": use_cache,
|
||||||
"attention_mask": attention_mask,
|
"attention_mask": attention_mask,
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -332,6 +332,11 @@ class GemmaFlashAttention2(GemmaAttention):
|
|||||||
cache_position: Optional[torch.LongTensor] = None,
|
cache_position: Optional[torch.LongTensor] = None,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
||||||
|
if isinstance(past_key_value, StaticCache):
|
||||||
|
raise ValueError(
|
||||||
|
"`static` cache implementation is not compatible with `attn_implementation==flash_attention_2` "
|
||||||
|
"make sure to use `sdpa` in the mean time, and open an issue at https://github.com/huggingface/transformers"
|
||||||
|
)
|
||||||
output_attentions = False
|
output_attentions = False
|
||||||
|
|
||||||
bsz, q_len, _ = hidden_states.size()
|
bsz, q_len, _ = hidden_states.size()
|
||||||
@@ -615,7 +620,7 @@ class GemmaDecoderLayer(nn.Module):
|
|||||||
hidden_states: torch.Tensor,
|
hidden_states: torch.Tensor,
|
||||||
attention_mask: Optional[torch.Tensor] = None,
|
attention_mask: Optional[torch.Tensor] = None,
|
||||||
position_ids: Optional[torch.LongTensor] = None,
|
position_ids: Optional[torch.LongTensor] = None,
|
||||||
past_key_value: Optional[Tuple[torch.Tensor]] = None,
|
past_key_value: Optional[Cache] = None,
|
||||||
output_attentions: Optional[bool] = False,
|
output_attentions: Optional[bool] = False,
|
||||||
use_cache: Optional[bool] = False,
|
use_cache: Optional[bool] = False,
|
||||||
cache_position: Optional[torch.LongTensor] = None,
|
cache_position: Optional[torch.LongTensor] = None,
|
||||||
@@ -717,23 +722,6 @@ class GemmaPreTrainedModel(PreTrainedModel):
|
|||||||
if module.padding_idx is not None:
|
if module.padding_idx is not None:
|
||||||
module.weight.data[module.padding_idx].zero_()
|
module.weight.data[module.padding_idx].zero_()
|
||||||
|
|
||||||
def _setup_cache(self, cache_cls, max_batch_size, max_cache_len: Optional[int] = None):
|
|
||||||
if self.config._attn_implementation == "flash_attention_2" and cache_cls == StaticCache:
|
|
||||||
raise ValueError(
|
|
||||||
"`static` cache implementation is not compatible with `attn_implementation==flash_attention_2` "
|
|
||||||
"make sure to use `sdpa` in the mean time, and open an issue at https://github.com/huggingface/transformers"
|
|
||||||
)
|
|
||||||
|
|
||||||
for layer in self.model.layers:
|
|
||||||
weights = layer.self_attn.o_proj.weight
|
|
||||||
layer.self_attn.past_key_value = cache_cls(
|
|
||||||
self.config, max_batch_size, max_cache_len, device=weights.device, dtype=weights.dtype
|
|
||||||
)
|
|
||||||
|
|
||||||
def _reset_cache(self):
|
|
||||||
for layer in self.model.layers:
|
|
||||||
layer.self_attn.past_key_value = None
|
|
||||||
|
|
||||||
|
|
||||||
GEMMA_INPUTS_DOCSTRING = r"""
|
GEMMA_INPUTS_DOCSTRING = r"""
|
||||||
Args:
|
Args:
|
||||||
@@ -850,7 +838,7 @@ class GemmaModel(GemmaPreTrainedModel):
|
|||||||
input_ids: torch.LongTensor = None,
|
input_ids: torch.LongTensor = None,
|
||||||
attention_mask: Optional[torch.Tensor] = None,
|
attention_mask: Optional[torch.Tensor] = None,
|
||||||
position_ids: Optional[torch.LongTensor] = None,
|
position_ids: Optional[torch.LongTensor] = None,
|
||||||
past_key_values: Optional[List[torch.FloatTensor]] = None,
|
past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
|
||||||
inputs_embeds: Optional[torch.FloatTensor] = None,
|
inputs_embeds: Optional[torch.FloatTensor] = None,
|
||||||
use_cache: Optional[bool] = None,
|
use_cache: Optional[bool] = None,
|
||||||
output_attentions: Optional[bool] = None,
|
output_attentions: Optional[bool] = None,
|
||||||
@@ -879,13 +867,11 @@ class GemmaModel(GemmaPreTrainedModel):
|
|||||||
if inputs_embeds is None:
|
if inputs_embeds is None:
|
||||||
inputs_embeds = self.embed_tokens(input_ids)
|
inputs_embeds = self.embed_tokens(input_ids)
|
||||||
|
|
||||||
past_seen_tokens = 0
|
if use_cache and not isinstance(past_key_values, Cache): # kept for BC (non `Cache` `past_key_values` inputs)
|
||||||
if use_cache: # kept for BC (cache positions)
|
past_key_values = DynamicCache.from_legacy_cache(past_key_values)
|
||||||
if not isinstance(past_key_values, StaticCache):
|
|
||||||
past_key_values = DynamicCache.from_legacy_cache(past_key_values)
|
|
||||||
past_seen_tokens = past_key_values.get_seq_length()
|
|
||||||
|
|
||||||
if cache_position is None:
|
if cache_position is None:
|
||||||
|
past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
|
||||||
cache_position = torch.arange(
|
cache_position = torch.arange(
|
||||||
past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
|
past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
|
||||||
)
|
)
|
||||||
@@ -893,7 +879,7 @@ class GemmaModel(GemmaPreTrainedModel):
|
|||||||
if position_ids is None:
|
if position_ids is None:
|
||||||
position_ids = cache_position.unsqueeze(0)
|
position_ids = cache_position.unsqueeze(0)
|
||||||
|
|
||||||
causal_mask = self._update_causal_mask(attention_mask, inputs_embeds, cache_position, past_seen_tokens)
|
causal_mask = self._update_causal_mask(attention_mask, inputs_embeds, cache_position, past_key_values)
|
||||||
|
|
||||||
# embed positions
|
# embed positions
|
||||||
hidden_states = inputs_embeds
|
hidden_states = inputs_embeds
|
||||||
@@ -952,7 +938,9 @@ class GemmaModel(GemmaPreTrainedModel):
|
|||||||
next_cache = None
|
next_cache = None
|
||||||
if use_cache:
|
if use_cache:
|
||||||
next_cache = (
|
next_cache = (
|
||||||
next_decoder_cache.to_legacy_cache() if isinstance(next_decoder_cache, Cache) else next_decoder_cache
|
next_decoder_cache.to_legacy_cache()
|
||||||
|
if isinstance(next_decoder_cache, DynamicCache)
|
||||||
|
else next_decoder_cache
|
||||||
)
|
)
|
||||||
if not return_dict:
|
if not return_dict:
|
||||||
return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
|
return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
|
||||||
@@ -968,7 +956,7 @@ class GemmaModel(GemmaPreTrainedModel):
|
|||||||
attention_mask: torch.Tensor,
|
attention_mask: torch.Tensor,
|
||||||
input_tensor: torch.Tensor,
|
input_tensor: torch.Tensor,
|
||||||
cache_position: torch.Tensor,
|
cache_position: torch.Tensor,
|
||||||
past_seen_tokens: int,
|
past_key_values: Cache,
|
||||||
):
|
):
|
||||||
# TODO: As of torch==2.2.0, the `attention_mask` passed to the model in `generate` is 2D and of dynamic length even when the static
|
# TODO: As of torch==2.2.0, the `attention_mask` passed to the model in `generate` is 2D and of dynamic length even when the static
|
||||||
# KV cache is used. This is an issue for torch.compile which then recaptures cudagraphs at each decode steps due to the dynamic shapes.
|
# KV cache is used. This is an issue for torch.compile which then recaptures cudagraphs at each decode steps due to the dynamic shapes.
|
||||||
@@ -980,9 +968,12 @@ class GemmaModel(GemmaPreTrainedModel):
|
|||||||
return attention_mask
|
return attention_mask
|
||||||
return None
|
return None
|
||||||
|
|
||||||
if self.config._attn_implementation == "sdpa":
|
# For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in
|
||||||
# For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument,
|
# order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail
|
||||||
# in order to dispatch on Flash Attention 2.
|
# to infer the attention mask.
|
||||||
|
past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
|
||||||
|
using_static_cache = isinstance(past_key_values, StaticCache)
|
||||||
|
if self.config._attn_implementation == "sdpa" and not using_static_cache:
|
||||||
if AttentionMaskConverter._ignore_causal_mask_sdpa(
|
if AttentionMaskConverter._ignore_causal_mask_sdpa(
|
||||||
attention_mask,
|
attention_mask,
|
||||||
inputs_embeds=input_tensor,
|
inputs_embeds=input_tensor,
|
||||||
@@ -994,9 +985,9 @@ class GemmaModel(GemmaPreTrainedModel):
|
|||||||
dtype, device = input_tensor.dtype, input_tensor.device
|
dtype, device = input_tensor.dtype, input_tensor.device
|
||||||
min_dtype = torch.finfo(dtype).min
|
min_dtype = torch.finfo(dtype).min
|
||||||
sequence_length = input_tensor.shape[1]
|
sequence_length = input_tensor.shape[1]
|
||||||
if hasattr(getattr(self.layers[0], "self_attn", {}), "past_key_value"): # static cache
|
if using_static_cache:
|
||||||
target_length = self.config.max_position_embeddings
|
target_length = past_key_values.get_max_length()
|
||||||
else: # dynamic cache
|
else:
|
||||||
target_length = (
|
target_length = (
|
||||||
attention_mask.shape[-1]
|
attention_mask.shape[-1]
|
||||||
if isinstance(attention_mask, torch.Tensor)
|
if isinstance(attention_mask, torch.Tensor)
|
||||||
@@ -1018,6 +1009,10 @@ class GemmaModel(GemmaPreTrainedModel):
|
|||||||
# backwards compatibility: we allow passing a 4D attention mask shorter than the input length with
|
# backwards compatibility: we allow passing a 4D attention mask shorter than the input length with
|
||||||
# cache. In that case, the 4D attention mask attends to the newest tokens only.
|
# cache. In that case, the 4D attention mask attends to the newest tokens only.
|
||||||
if attention_mask.shape[-2] < cache_position[0] + sequence_length:
|
if attention_mask.shape[-2] < cache_position[0] + sequence_length:
|
||||||
|
logger.warning_once(
|
||||||
|
"Passing a 4d mask shorter than the input length is deprecated and will be removed in "
|
||||||
|
"transformers v4.42.0"
|
||||||
|
)
|
||||||
offset = cache_position[0]
|
offset = cache_position[0]
|
||||||
else:
|
else:
|
||||||
offset = 0
|
offset = 0
|
||||||
@@ -1079,7 +1074,7 @@ class GemmaForCausalLM(GemmaPreTrainedModel):
|
|||||||
input_ids: torch.LongTensor = None,
|
input_ids: torch.LongTensor = None,
|
||||||
attention_mask: Optional[torch.Tensor] = None,
|
attention_mask: Optional[torch.Tensor] = None,
|
||||||
position_ids: Optional[torch.LongTensor] = None,
|
position_ids: Optional[torch.LongTensor] = None,
|
||||||
past_key_values: Optional[List[torch.FloatTensor]] = None,
|
past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
|
||||||
inputs_embeds: Optional[torch.FloatTensor] = None,
|
inputs_embeds: Optional[torch.FloatTensor] = None,
|
||||||
labels: Optional[torch.LongTensor] = None,
|
labels: Optional[torch.LongTensor] = None,
|
||||||
use_cache: Optional[bool] = None,
|
use_cache: Optional[bool] = None,
|
||||||
@@ -1171,13 +1166,6 @@ class GemmaForCausalLM(GemmaPreTrainedModel):
|
|||||||
use_cache=True,
|
use_cache=True,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
):
|
):
|
||||||
# With static cache, the `past_key_values` is None
|
|
||||||
# TODO joao: standardize interface for the different Cache classes and remove of this if
|
|
||||||
has_static_cache = False
|
|
||||||
if past_key_values is None:
|
|
||||||
past_key_values = getattr(getattr(self.model.layers[0], "self_attn", {}), "past_key_value", None)
|
|
||||||
has_static_cache = past_key_values is not None
|
|
||||||
|
|
||||||
past_length = 0
|
past_length = 0
|
||||||
if past_key_values is not None:
|
if past_key_values is not None:
|
||||||
if isinstance(past_key_values, Cache):
|
if isinstance(past_key_values, Cache):
|
||||||
@@ -1195,8 +1183,7 @@ class GemmaForCausalLM(GemmaPreTrainedModel):
|
|||||||
|
|
||||||
# Keep only the unprocessed tokens:
|
# Keep only the unprocessed tokens:
|
||||||
# 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where
|
# 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where
|
||||||
# some of the inputs are exclusively passed as part of the cache (e.g. when passing input_embeds as
|
# some of the inputs are exclusively passed as part of the cache (e.g. when passing input_embeds as input)
|
||||||
# input)
|
|
||||||
if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]:
|
if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]:
|
||||||
input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :]
|
input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :]
|
||||||
# 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard
|
# 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard
|
||||||
@@ -1236,9 +1223,6 @@ class GemmaForCausalLM(GemmaPreTrainedModel):
|
|||||||
elif use_cache:
|
elif use_cache:
|
||||||
cache_position = cache_position[-input_length:]
|
cache_position = cache_position[-input_length:]
|
||||||
|
|
||||||
if has_static_cache:
|
|
||||||
past_key_values = None
|
|
||||||
|
|
||||||
model_inputs.update(
|
model_inputs.update(
|
||||||
{
|
{
|
||||||
"position_ids": position_ids,
|
"position_ids": position_ids,
|
||||||
@@ -1298,7 +1282,7 @@ class GemmaForSequenceClassification(GemmaPreTrainedModel):
|
|||||||
input_ids: torch.LongTensor = None,
|
input_ids: torch.LongTensor = None,
|
||||||
attention_mask: Optional[torch.Tensor] = None,
|
attention_mask: Optional[torch.Tensor] = None,
|
||||||
position_ids: Optional[torch.LongTensor] = None,
|
position_ids: Optional[torch.LongTensor] = None,
|
||||||
past_key_values: Optional[List[torch.FloatTensor]] = None,
|
past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
|
||||||
inputs_embeds: Optional[torch.FloatTensor] = None,
|
inputs_embeds: Optional[torch.FloatTensor] = None,
|
||||||
labels: Optional[torch.LongTensor] = None,
|
labels: Optional[torch.LongTensor] = None,
|
||||||
use_cache: Optional[bool] = None,
|
use_cache: Optional[bool] = None,
|
||||||
|
|||||||
@@ -29,7 +29,7 @@ from torch import nn
|
|||||||
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
|
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
|
||||||
|
|
||||||
from ...activations import ACT2FN
|
from ...activations import ACT2FN
|
||||||
from ...cache_utils import DynamicCache # we need __iter__ and __len__ of pkv
|
from ...cache_utils import Cache, DynamicCache # we need __iter__ and __len__ of pkv
|
||||||
from ...modeling_attn_mask_utils import (
|
from ...modeling_attn_mask_utils import (
|
||||||
AttentionMaskConverter,
|
AttentionMaskConverter,
|
||||||
)
|
)
|
||||||
@@ -1807,7 +1807,7 @@ class JambaForSequenceClassification(JambaPreTrainedModel):
|
|||||||
input_ids: torch.LongTensor = None,
|
input_ids: torch.LongTensor = None,
|
||||||
attention_mask: Optional[torch.Tensor] = None,
|
attention_mask: Optional[torch.Tensor] = None,
|
||||||
position_ids: Optional[torch.LongTensor] = None,
|
position_ids: Optional[torch.LongTensor] = None,
|
||||||
past_key_values: Optional[List[torch.FloatTensor]] = None,
|
past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
|
||||||
inputs_embeds: Optional[torch.FloatTensor] = None,
|
inputs_embeds: Optional[torch.FloatTensor] = None,
|
||||||
labels: Optional[torch.LongTensor] = None,
|
labels: Optional[torch.LongTensor] = None,
|
||||||
use_cache: Optional[bool] = None,
|
use_cache: Optional[bool] = None,
|
||||||
|
|||||||
@@ -428,6 +428,12 @@ class LlamaFlashAttention2(LlamaAttention):
|
|||||||
cache_position: Optional[torch.LongTensor] = None,
|
cache_position: Optional[torch.LongTensor] = None,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
||||||
|
if isinstance(past_key_value, StaticCache):
|
||||||
|
raise ValueError(
|
||||||
|
"`static` cache implementation is not compatible with `attn_implementation==flash_attention_2` "
|
||||||
|
"make sure to use `sdpa` in the mean time, and open an issue at https://github.com/huggingface/transformers"
|
||||||
|
)
|
||||||
|
|
||||||
output_attentions = False
|
output_attentions = False
|
||||||
|
|
||||||
bsz, q_len, _ = hidden_states.size()
|
bsz, q_len, _ = hidden_states.size()
|
||||||
@@ -710,7 +716,7 @@ class LlamaDecoderLayer(nn.Module):
|
|||||||
hidden_states: torch.Tensor,
|
hidden_states: torch.Tensor,
|
||||||
attention_mask: Optional[torch.Tensor] = None,
|
attention_mask: Optional[torch.Tensor] = None,
|
||||||
position_ids: Optional[torch.LongTensor] = None,
|
position_ids: Optional[torch.LongTensor] = None,
|
||||||
past_key_value: Optional[Tuple[torch.Tensor]] = None,
|
past_key_value: Optional[Cache] = None,
|
||||||
output_attentions: Optional[bool] = False,
|
output_attentions: Optional[bool] = False,
|
||||||
use_cache: Optional[bool] = False,
|
use_cache: Optional[bool] = False,
|
||||||
cache_position: Optional[torch.LongTensor] = None,
|
cache_position: Optional[torch.LongTensor] = None,
|
||||||
@@ -811,27 +817,6 @@ class LlamaPreTrainedModel(PreTrainedModel):
|
|||||||
if module.padding_idx is not None:
|
if module.padding_idx is not None:
|
||||||
module.weight.data[module.padding_idx].zero_()
|
module.weight.data[module.padding_idx].zero_()
|
||||||
|
|
||||||
def _setup_cache(self, cache_cls, max_batch_size, max_cache_len: Optional[int] = None):
|
|
||||||
if self.config._attn_implementation == "flash_attention_2" and cache_cls == StaticCache:
|
|
||||||
raise ValueError(
|
|
||||||
"`static` cache implementation is not compatible with `attn_implementation==flash_attention_2` "
|
|
||||||
"make sure to use `sdpa` in the mean time, and open an issue at https://github.com/huggingface/transformers"
|
|
||||||
)
|
|
||||||
|
|
||||||
for layer in self.model.layers:
|
|
||||||
device = layer.input_layernorm.weight.device
|
|
||||||
if hasattr(self.config, "_pre_quantization_dtype"):
|
|
||||||
dtype = self.config._pre_quantization_dtype
|
|
||||||
else:
|
|
||||||
dtype = layer.self_attn.o_proj.weight.dtype
|
|
||||||
layer.self_attn.past_key_value = cache_cls(
|
|
||||||
self.config, max_batch_size, max_cache_len, device=device, dtype=dtype
|
|
||||||
)
|
|
||||||
|
|
||||||
def _reset_cache(self):
|
|
||||||
for layer in self.model.layers:
|
|
||||||
layer.self_attn.past_key_value = None
|
|
||||||
|
|
||||||
|
|
||||||
LLAMA_INPUTS_DOCSTRING = r"""
|
LLAMA_INPUTS_DOCSTRING = r"""
|
||||||
Args:
|
Args:
|
||||||
@@ -946,7 +931,7 @@ class LlamaModel(LlamaPreTrainedModel):
|
|||||||
input_ids: torch.LongTensor = None,
|
input_ids: torch.LongTensor = None,
|
||||||
attention_mask: Optional[torch.Tensor] = None,
|
attention_mask: Optional[torch.Tensor] = None,
|
||||||
position_ids: Optional[torch.LongTensor] = None,
|
position_ids: Optional[torch.LongTensor] = None,
|
||||||
past_key_values: Optional[List[torch.FloatTensor]] = None,
|
past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
|
||||||
inputs_embeds: Optional[torch.FloatTensor] = None,
|
inputs_embeds: Optional[torch.FloatTensor] = None,
|
||||||
use_cache: Optional[bool] = None,
|
use_cache: Optional[bool] = None,
|
||||||
output_attentions: Optional[bool] = None,
|
output_attentions: Optional[bool] = None,
|
||||||
@@ -975,23 +960,18 @@ class LlamaModel(LlamaPreTrainedModel):
|
|||||||
if inputs_embeds is None:
|
if inputs_embeds is None:
|
||||||
inputs_embeds = self.embed_tokens(input_ids)
|
inputs_embeds = self.embed_tokens(input_ids)
|
||||||
|
|
||||||
past_seen_tokens = 0
|
if use_cache and not isinstance(past_key_values, Cache): # kept for BC (non `Cache` `past_key_values` inputs)
|
||||||
if use_cache: # kept for BC (cache positions)
|
past_key_values = DynamicCache.from_legacy_cache(past_key_values)
|
||||||
if not isinstance(past_key_values, StaticCache):
|
|
||||||
past_key_values = DynamicCache.from_legacy_cache(past_key_values)
|
|
||||||
past_seen_tokens = past_key_values.get_seq_length()
|
|
||||||
|
|
||||||
if cache_position is None:
|
if cache_position is None:
|
||||||
if isinstance(past_key_values, StaticCache):
|
past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
|
||||||
raise ValueError("cache_position is a required argument when using StaticCache.")
|
|
||||||
cache_position = torch.arange(
|
cache_position = torch.arange(
|
||||||
past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
|
past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
|
||||||
)
|
)
|
||||||
|
|
||||||
if position_ids is None:
|
if position_ids is None:
|
||||||
position_ids = cache_position.unsqueeze(0)
|
position_ids = cache_position.unsqueeze(0)
|
||||||
|
|
||||||
causal_mask = self._update_causal_mask(attention_mask, inputs_embeds, cache_position, past_seen_tokens)
|
causal_mask = self._update_causal_mask(attention_mask, inputs_embeds, cache_position, past_key_values)
|
||||||
|
|
||||||
# embed positions
|
# embed positions
|
||||||
hidden_states = inputs_embeds
|
hidden_states = inputs_embeds
|
||||||
@@ -1044,7 +1024,9 @@ class LlamaModel(LlamaPreTrainedModel):
|
|||||||
next_cache = None
|
next_cache = None
|
||||||
if use_cache:
|
if use_cache:
|
||||||
next_cache = (
|
next_cache = (
|
||||||
next_decoder_cache.to_legacy_cache() if isinstance(next_decoder_cache, Cache) else next_decoder_cache
|
next_decoder_cache.to_legacy_cache()
|
||||||
|
if isinstance(next_decoder_cache, DynamicCache)
|
||||||
|
else next_decoder_cache
|
||||||
)
|
)
|
||||||
if not return_dict:
|
if not return_dict:
|
||||||
return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
|
return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
|
||||||
@@ -1060,7 +1042,7 @@ class LlamaModel(LlamaPreTrainedModel):
|
|||||||
attention_mask: torch.Tensor,
|
attention_mask: torch.Tensor,
|
||||||
input_tensor: torch.Tensor,
|
input_tensor: torch.Tensor,
|
||||||
cache_position: torch.Tensor,
|
cache_position: torch.Tensor,
|
||||||
past_seen_tokens: int,
|
past_key_values: Cache,
|
||||||
):
|
):
|
||||||
# TODO: As of torch==2.2.0, the `attention_mask` passed to the model in `generate` is 2D and of dynamic length even when the static
|
# TODO: As of torch==2.2.0, the `attention_mask` passed to the model in `generate` is 2D and of dynamic length even when the static
|
||||||
# KV cache is used. This is an issue for torch.compile which then recaptures cudagraphs at each decode steps due to the dynamic shapes.
|
# KV cache is used. This is an issue for torch.compile which then recaptures cudagraphs at each decode steps due to the dynamic shapes.
|
||||||
@@ -1072,9 +1054,12 @@ class LlamaModel(LlamaPreTrainedModel):
|
|||||||
return attention_mask
|
return attention_mask
|
||||||
return None
|
return None
|
||||||
|
|
||||||
if self.config._attn_implementation == "sdpa":
|
# For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in
|
||||||
# For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument,
|
# order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail
|
||||||
# in order to dispatch on Flash Attention 2.
|
# to infer the attention mask.
|
||||||
|
past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
|
||||||
|
using_static_cache = isinstance(past_key_values, StaticCache)
|
||||||
|
if self.config._attn_implementation == "sdpa" and not using_static_cache:
|
||||||
if AttentionMaskConverter._ignore_causal_mask_sdpa(
|
if AttentionMaskConverter._ignore_causal_mask_sdpa(
|
||||||
attention_mask,
|
attention_mask,
|
||||||
inputs_embeds=input_tensor,
|
inputs_embeds=input_tensor,
|
||||||
@@ -1086,9 +1071,9 @@ class LlamaModel(LlamaPreTrainedModel):
|
|||||||
dtype, device = input_tensor.dtype, input_tensor.device
|
dtype, device = input_tensor.dtype, input_tensor.device
|
||||||
min_dtype = torch.finfo(dtype).min
|
min_dtype = torch.finfo(dtype).min
|
||||||
sequence_length = input_tensor.shape[1]
|
sequence_length = input_tensor.shape[1]
|
||||||
if hasattr(getattr(self.layers[0], "self_attn", {}), "past_key_value"): # static cache
|
if using_static_cache:
|
||||||
target_length = self.config.max_position_embeddings
|
target_length = past_key_values.get_max_length()
|
||||||
else: # dynamic cache
|
else:
|
||||||
target_length = (
|
target_length = (
|
||||||
attention_mask.shape[-1]
|
attention_mask.shape[-1]
|
||||||
if isinstance(attention_mask, torch.Tensor)
|
if isinstance(attention_mask, torch.Tensor)
|
||||||
@@ -1110,6 +1095,10 @@ class LlamaModel(LlamaPreTrainedModel):
|
|||||||
# backwards compatibility: we allow passing a 4D attention mask shorter than the input length with
|
# backwards compatibility: we allow passing a 4D attention mask shorter than the input length with
|
||||||
# cache. In that case, the 4D attention mask attends to the newest tokens only.
|
# cache. In that case, the 4D attention mask attends to the newest tokens only.
|
||||||
if attention_mask.shape[-2] < cache_position[0] + sequence_length:
|
if attention_mask.shape[-2] < cache_position[0] + sequence_length:
|
||||||
|
logger.warning_once(
|
||||||
|
"Passing a 4d mask shorter than the input length is deprecated and will be removed in "
|
||||||
|
"transformers v4.42.0"
|
||||||
|
)
|
||||||
offset = cache_position[0]
|
offset = cache_position[0]
|
||||||
else:
|
else:
|
||||||
offset = 0
|
offset = 0
|
||||||
@@ -1169,7 +1158,7 @@ class LlamaForCausalLM(LlamaPreTrainedModel):
|
|||||||
input_ids: torch.LongTensor = None,
|
input_ids: torch.LongTensor = None,
|
||||||
attention_mask: Optional[torch.Tensor] = None,
|
attention_mask: Optional[torch.Tensor] = None,
|
||||||
position_ids: Optional[torch.LongTensor] = None,
|
position_ids: Optional[torch.LongTensor] = None,
|
||||||
past_key_values: Optional[List[torch.FloatTensor]] = None,
|
past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
|
||||||
inputs_embeds: Optional[torch.FloatTensor] = None,
|
inputs_embeds: Optional[torch.FloatTensor] = None,
|
||||||
labels: Optional[torch.LongTensor] = None,
|
labels: Optional[torch.LongTensor] = None,
|
||||||
use_cache: Optional[bool] = None,
|
use_cache: Optional[bool] = None,
|
||||||
@@ -1267,13 +1256,6 @@ class LlamaForCausalLM(LlamaPreTrainedModel):
|
|||||||
use_cache=True,
|
use_cache=True,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
):
|
):
|
||||||
# With static cache, the `past_key_values` is None
|
|
||||||
# TODO joao: standardize interface for the different Cache classes and remove of this if
|
|
||||||
has_static_cache = False
|
|
||||||
if past_key_values is None:
|
|
||||||
past_key_values = getattr(getattr(self.model.layers[0], "self_attn", {}), "past_key_value", None)
|
|
||||||
has_static_cache = past_key_values is not None
|
|
||||||
|
|
||||||
past_length = 0
|
past_length = 0
|
||||||
if past_key_values is not None:
|
if past_key_values is not None:
|
||||||
if isinstance(past_key_values, Cache):
|
if isinstance(past_key_values, Cache):
|
||||||
@@ -1291,8 +1273,7 @@ class LlamaForCausalLM(LlamaPreTrainedModel):
|
|||||||
|
|
||||||
# Keep only the unprocessed tokens:
|
# Keep only the unprocessed tokens:
|
||||||
# 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where
|
# 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where
|
||||||
# some of the inputs are exclusively passed as part of the cache (e.g. when passing input_embeds as
|
# some of the inputs are exclusively passed as part of the cache (e.g. when passing input_embeds as input)
|
||||||
# input)
|
|
||||||
if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]:
|
if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]:
|
||||||
input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :]
|
input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :]
|
||||||
# 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard
|
# 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard
|
||||||
@@ -1332,9 +1313,6 @@ class LlamaForCausalLM(LlamaPreTrainedModel):
|
|||||||
elif use_cache:
|
elif use_cache:
|
||||||
cache_position = cache_position[-input_length:]
|
cache_position = cache_position[-input_length:]
|
||||||
|
|
||||||
if has_static_cache:
|
|
||||||
past_key_values = None
|
|
||||||
|
|
||||||
model_inputs.update(
|
model_inputs.update(
|
||||||
{
|
{
|
||||||
"position_ids": position_ids,
|
"position_ids": position_ids,
|
||||||
@@ -1393,7 +1371,7 @@ class LlamaForSequenceClassification(LlamaPreTrainedModel):
|
|||||||
input_ids: torch.LongTensor = None,
|
input_ids: torch.LongTensor = None,
|
||||||
attention_mask: Optional[torch.Tensor] = None,
|
attention_mask: Optional[torch.Tensor] = None,
|
||||||
position_ids: Optional[torch.LongTensor] = None,
|
position_ids: Optional[torch.LongTensor] = None,
|
||||||
past_key_values: Optional[List[torch.FloatTensor]] = None,
|
past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
|
||||||
inputs_embeds: Optional[torch.FloatTensor] = None,
|
inputs_embeds: Optional[torch.FloatTensor] = None,
|
||||||
labels: Optional[torch.LongTensor] = None,
|
labels: Optional[torch.LongTensor] = None,
|
||||||
use_cache: Optional[bool] = None,
|
use_cache: Optional[bool] = None,
|
||||||
@@ -1510,7 +1488,7 @@ class LlamaForQuestionAnswering(LlamaPreTrainedModel):
|
|||||||
input_ids: Optional[torch.LongTensor] = None,
|
input_ids: Optional[torch.LongTensor] = None,
|
||||||
attention_mask: Optional[torch.FloatTensor] = None,
|
attention_mask: Optional[torch.FloatTensor] = None,
|
||||||
position_ids: Optional[torch.LongTensor] = None,
|
position_ids: Optional[torch.LongTensor] = None,
|
||||||
past_key_values: Optional[List[torch.FloatTensor]] = None,
|
past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
|
||||||
inputs_embeds: Optional[torch.FloatTensor] = None,
|
inputs_embeds: Optional[torch.FloatTensor] = None,
|
||||||
start_positions: Optional[torch.LongTensor] = None,
|
start_positions: Optional[torch.LongTensor] = None,
|
||||||
end_positions: Optional[torch.LongTensor] = None,
|
end_positions: Optional[torch.LongTensor] = None,
|
||||||
|
|||||||
@@ -1301,7 +1301,7 @@ class MistralForSequenceClassification(MistralPreTrainedModel):
|
|||||||
input_ids: torch.LongTensor = None,
|
input_ids: torch.LongTensor = None,
|
||||||
attention_mask: Optional[torch.Tensor] = None,
|
attention_mask: Optional[torch.Tensor] = None,
|
||||||
position_ids: Optional[torch.LongTensor] = None,
|
position_ids: Optional[torch.LongTensor] = None,
|
||||||
past_key_values: Optional[List[torch.FloatTensor]] = None,
|
past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
|
||||||
inputs_embeds: Optional[torch.FloatTensor] = None,
|
inputs_embeds: Optional[torch.FloatTensor] = None,
|
||||||
labels: Optional[torch.LongTensor] = None,
|
labels: Optional[torch.LongTensor] = None,
|
||||||
use_cache: Optional[bool] = None,
|
use_cache: Optional[bool] = None,
|
||||||
|
|||||||
@@ -1525,7 +1525,7 @@ class MixtralForSequenceClassification(MixtralPreTrainedModel):
|
|||||||
input_ids: torch.LongTensor = None,
|
input_ids: torch.LongTensor = None,
|
||||||
attention_mask: Optional[torch.Tensor] = None,
|
attention_mask: Optional[torch.Tensor] = None,
|
||||||
position_ids: Optional[torch.LongTensor] = None,
|
position_ids: Optional[torch.LongTensor] = None,
|
||||||
past_key_values: Optional[List[torch.FloatTensor]] = None,
|
past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
|
||||||
inputs_embeds: Optional[torch.FloatTensor] = None,
|
inputs_embeds: Optional[torch.FloatTensor] = None,
|
||||||
labels: Optional[torch.LongTensor] = None,
|
labels: Optional[torch.LongTensor] = None,
|
||||||
use_cache: Optional[bool] = None,
|
use_cache: Optional[bool] = None,
|
||||||
|
|||||||
@@ -692,7 +692,7 @@ class OlmoDecoderLayer(nn.Module):
|
|||||||
hidden_states: torch.Tensor,
|
hidden_states: torch.Tensor,
|
||||||
attention_mask: Optional[torch.Tensor] = None,
|
attention_mask: Optional[torch.Tensor] = None,
|
||||||
position_ids: Optional[torch.LongTensor] = None,
|
position_ids: Optional[torch.LongTensor] = None,
|
||||||
past_key_value: Optional[Tuple[torch.Tensor]] = None,
|
past_key_value: Optional[Cache] = None,
|
||||||
output_attentions: Optional[bool] = False,
|
output_attentions: Optional[bool] = False,
|
||||||
use_cache: Optional[bool] = False,
|
use_cache: Optional[bool] = False,
|
||||||
cache_position: Optional[torch.LongTensor] = None,
|
cache_position: Optional[torch.LongTensor] = None,
|
||||||
@@ -794,27 +794,6 @@ class OlmoPreTrainedModel(PreTrainedModel):
|
|||||||
if module.padding_idx is not None:
|
if module.padding_idx is not None:
|
||||||
module.weight.data[module.padding_idx].zero_()
|
module.weight.data[module.padding_idx].zero_()
|
||||||
|
|
||||||
def _setup_cache(self, cache_cls, max_batch_size, max_cache_len: Optional[int] = None):
|
|
||||||
if self.config._attn_implementation == "flash_attention_2" and cache_cls == StaticCache:
|
|
||||||
raise ValueError(
|
|
||||||
"`static` cache implementation is not compatible with `attn_implementation==flash_attention_2` "
|
|
||||||
"make sure to use `sdpa` in the mean time, and open an issue at https://github.com/huggingface/transformers"
|
|
||||||
)
|
|
||||||
|
|
||||||
for layer in self.model.layers:
|
|
||||||
device = layer.input_layernorm.weight.device
|
|
||||||
if hasattr(self.config, "_pre_quantization_dtype"):
|
|
||||||
dtype = self.config._pre_quantization_dtype
|
|
||||||
else:
|
|
||||||
dtype = layer.self_attn.o_proj.weight.dtype
|
|
||||||
layer.self_attn.past_key_value = cache_cls(
|
|
||||||
self.config, max_batch_size, max_cache_len, device=device, dtype=dtype
|
|
||||||
)
|
|
||||||
|
|
||||||
def _reset_cache(self):
|
|
||||||
for layer in self.model.layers:
|
|
||||||
layer.self_attn.past_key_value = None
|
|
||||||
|
|
||||||
|
|
||||||
OLMO_INPUTS_DOCSTRING = r"""
|
OLMO_INPUTS_DOCSTRING = r"""
|
||||||
Args:
|
Args:
|
||||||
@@ -930,7 +909,7 @@ class OlmoModel(OlmoPreTrainedModel):
|
|||||||
input_ids: torch.LongTensor = None,
|
input_ids: torch.LongTensor = None,
|
||||||
attention_mask: Optional[torch.Tensor] = None,
|
attention_mask: Optional[torch.Tensor] = None,
|
||||||
position_ids: Optional[torch.LongTensor] = None,
|
position_ids: Optional[torch.LongTensor] = None,
|
||||||
past_key_values: Optional[List[torch.FloatTensor]] = None,
|
past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
|
||||||
inputs_embeds: Optional[torch.FloatTensor] = None,
|
inputs_embeds: Optional[torch.FloatTensor] = None,
|
||||||
use_cache: Optional[bool] = None,
|
use_cache: Optional[bool] = None,
|
||||||
output_attentions: Optional[bool] = None,
|
output_attentions: Optional[bool] = None,
|
||||||
@@ -959,23 +938,18 @@ class OlmoModel(OlmoPreTrainedModel):
|
|||||||
if inputs_embeds is None:
|
if inputs_embeds is None:
|
||||||
inputs_embeds = self.embed_tokens(input_ids)
|
inputs_embeds = self.embed_tokens(input_ids)
|
||||||
|
|
||||||
past_seen_tokens = 0
|
if use_cache and not isinstance(past_key_values, Cache): # kept for BC (non `Cache` `past_key_values` inputs)
|
||||||
if use_cache: # kept for BC (cache positions)
|
past_key_values = DynamicCache.from_legacy_cache(past_key_values)
|
||||||
if not isinstance(past_key_values, StaticCache):
|
|
||||||
past_key_values = DynamicCache.from_legacy_cache(past_key_values)
|
|
||||||
past_seen_tokens = past_key_values.get_seq_length()
|
|
||||||
|
|
||||||
if cache_position is None:
|
if cache_position is None:
|
||||||
if isinstance(past_key_values, StaticCache):
|
past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
|
||||||
raise ValueError("cache_position is a required argument when using StaticCache.")
|
|
||||||
cache_position = torch.arange(
|
cache_position = torch.arange(
|
||||||
past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
|
past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
|
||||||
)
|
)
|
||||||
|
|
||||||
if position_ids is None:
|
if position_ids is None:
|
||||||
position_ids = cache_position.unsqueeze(0)
|
position_ids = cache_position.unsqueeze(0)
|
||||||
|
|
||||||
causal_mask = self._update_causal_mask(attention_mask, inputs_embeds, cache_position, past_seen_tokens)
|
causal_mask = self._update_causal_mask(attention_mask, inputs_embeds, cache_position, past_key_values)
|
||||||
|
|
||||||
# embed positions
|
# embed positions
|
||||||
hidden_states = inputs_embeds
|
hidden_states = inputs_embeds
|
||||||
@@ -1028,7 +1002,9 @@ class OlmoModel(OlmoPreTrainedModel):
|
|||||||
next_cache = None
|
next_cache = None
|
||||||
if use_cache:
|
if use_cache:
|
||||||
next_cache = (
|
next_cache = (
|
||||||
next_decoder_cache.to_legacy_cache() if isinstance(next_decoder_cache, Cache) else next_decoder_cache
|
next_decoder_cache.to_legacy_cache()
|
||||||
|
if isinstance(next_decoder_cache, DynamicCache)
|
||||||
|
else next_decoder_cache
|
||||||
)
|
)
|
||||||
if not return_dict:
|
if not return_dict:
|
||||||
return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
|
return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
|
||||||
@@ -1045,7 +1021,7 @@ class OlmoModel(OlmoPreTrainedModel):
|
|||||||
attention_mask: torch.Tensor,
|
attention_mask: torch.Tensor,
|
||||||
input_tensor: torch.Tensor,
|
input_tensor: torch.Tensor,
|
||||||
cache_position: torch.Tensor,
|
cache_position: torch.Tensor,
|
||||||
past_seen_tokens: int,
|
past_key_values: Cache,
|
||||||
):
|
):
|
||||||
# TODO: As of torch==2.2.0, the `attention_mask` passed to the model in `generate` is 2D and of dynamic length even when the static
|
# TODO: As of torch==2.2.0, the `attention_mask` passed to the model in `generate` is 2D and of dynamic length even when the static
|
||||||
# KV cache is used. This is an issue for torch.compile which then recaptures cudagraphs at each decode steps due to the dynamic shapes.
|
# KV cache is used. This is an issue for torch.compile which then recaptures cudagraphs at each decode steps due to the dynamic shapes.
|
||||||
@@ -1057,9 +1033,12 @@ class OlmoModel(OlmoPreTrainedModel):
|
|||||||
return attention_mask
|
return attention_mask
|
||||||
return None
|
return None
|
||||||
|
|
||||||
if self.config._attn_implementation == "sdpa":
|
# For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in
|
||||||
# For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument,
|
# order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail
|
||||||
# in order to dispatch on Flash Attention 2.
|
# to infer the attention mask.
|
||||||
|
past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
|
||||||
|
using_static_cache = isinstance(past_key_values, StaticCache)
|
||||||
|
if self.config._attn_implementation == "sdpa" and not using_static_cache:
|
||||||
if AttentionMaskConverter._ignore_causal_mask_sdpa(
|
if AttentionMaskConverter._ignore_causal_mask_sdpa(
|
||||||
attention_mask,
|
attention_mask,
|
||||||
inputs_embeds=input_tensor,
|
inputs_embeds=input_tensor,
|
||||||
@@ -1071,9 +1050,9 @@ class OlmoModel(OlmoPreTrainedModel):
|
|||||||
dtype, device = input_tensor.dtype, input_tensor.device
|
dtype, device = input_tensor.dtype, input_tensor.device
|
||||||
min_dtype = torch.finfo(dtype).min
|
min_dtype = torch.finfo(dtype).min
|
||||||
sequence_length = input_tensor.shape[1]
|
sequence_length = input_tensor.shape[1]
|
||||||
if hasattr(getattr(self.layers[0], "self_attn", {}), "past_key_value"): # static cache
|
if using_static_cache:
|
||||||
target_length = self.config.max_position_embeddings
|
target_length = past_key_values.get_max_length()
|
||||||
else: # dynamic cache
|
else:
|
||||||
target_length = (
|
target_length = (
|
||||||
attention_mask.shape[-1]
|
attention_mask.shape[-1]
|
||||||
if isinstance(attention_mask, torch.Tensor)
|
if isinstance(attention_mask, torch.Tensor)
|
||||||
@@ -1095,6 +1074,10 @@ class OlmoModel(OlmoPreTrainedModel):
|
|||||||
# backwards compatibility: we allow passing a 4D attention mask shorter than the input length with
|
# backwards compatibility: we allow passing a 4D attention mask shorter than the input length with
|
||||||
# cache. In that case, the 4D attention mask attends to the newest tokens only.
|
# cache. In that case, the 4D attention mask attends to the newest tokens only.
|
||||||
if attention_mask.shape[-2] < cache_position[0] + sequence_length:
|
if attention_mask.shape[-2] < cache_position[0] + sequence_length:
|
||||||
|
logger.warning_once(
|
||||||
|
"Passing a 4d mask shorter than the input length is deprecated and will be removed in "
|
||||||
|
"transformers v4.42.0"
|
||||||
|
)
|
||||||
offset = cache_position[0]
|
offset = cache_position[0]
|
||||||
else:
|
else:
|
||||||
offset = 0
|
offset = 0
|
||||||
@@ -1250,13 +1233,6 @@ class OlmoForCausalLM(OlmoPreTrainedModel):
|
|||||||
use_cache=True,
|
use_cache=True,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
):
|
):
|
||||||
# With static cache, the `past_key_values` is None
|
|
||||||
# TODO joao: standardize interface for the different Cache classes and remove of this if
|
|
||||||
has_static_cache = False
|
|
||||||
if past_key_values is None:
|
|
||||||
past_key_values = getattr(getattr(self.model.layers[0], "self_attn", {}), "past_key_value", None)
|
|
||||||
has_static_cache = past_key_values is not None
|
|
||||||
|
|
||||||
past_length = 0
|
past_length = 0
|
||||||
if past_key_values is not None:
|
if past_key_values is not None:
|
||||||
if isinstance(past_key_values, Cache):
|
if isinstance(past_key_values, Cache):
|
||||||
@@ -1274,8 +1250,7 @@ class OlmoForCausalLM(OlmoPreTrainedModel):
|
|||||||
|
|
||||||
# Keep only the unprocessed tokens:
|
# Keep only the unprocessed tokens:
|
||||||
# 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where
|
# 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where
|
||||||
# some of the inputs are exclusively passed as part of the cache (e.g. when passing input_embeds as
|
# some of the inputs are exclusively passed as part of the cache (e.g. when passing input_embeds as input)
|
||||||
# input)
|
|
||||||
if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]:
|
if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]:
|
||||||
input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :]
|
input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :]
|
||||||
# 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard
|
# 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard
|
||||||
@@ -1315,9 +1290,6 @@ class OlmoForCausalLM(OlmoPreTrainedModel):
|
|||||||
elif use_cache:
|
elif use_cache:
|
||||||
cache_position = cache_position[-input_length:]
|
cache_position = cache_position[-input_length:]
|
||||||
|
|
||||||
if has_static_cache:
|
|
||||||
past_key_values = None
|
|
||||||
|
|
||||||
model_inputs.update(
|
model_inputs.update(
|
||||||
{
|
{
|
||||||
"position_ids": position_ids,
|
"position_ids": position_ids,
|
||||||
|
|||||||
@@ -927,7 +927,7 @@ class PersimmonForSequenceClassification(PersimmonPreTrainedModel):
|
|||||||
input_ids: torch.LongTensor = None,
|
input_ids: torch.LongTensor = None,
|
||||||
attention_mask: Optional[torch.Tensor] = None,
|
attention_mask: Optional[torch.Tensor] = None,
|
||||||
position_ids: Optional[torch.LongTensor] = None,
|
position_ids: Optional[torch.LongTensor] = None,
|
||||||
past_key_values: Optional[List[torch.FloatTensor]] = None,
|
past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
|
||||||
inputs_embeds: Optional[torch.FloatTensor] = None,
|
inputs_embeds: Optional[torch.FloatTensor] = None,
|
||||||
labels: Optional[torch.LongTensor] = None,
|
labels: Optional[torch.LongTensor] = None,
|
||||||
use_cache: Optional[bool] = None,
|
use_cache: Optional[bool] = None,
|
||||||
|
|||||||
@@ -1313,7 +1313,7 @@ class PhiForSequenceClassification(PhiPreTrainedModel):
|
|||||||
input_ids: torch.LongTensor = None,
|
input_ids: torch.LongTensor = None,
|
||||||
attention_mask: Optional[torch.Tensor] = None,
|
attention_mask: Optional[torch.Tensor] = None,
|
||||||
position_ids: Optional[torch.LongTensor] = None,
|
position_ids: Optional[torch.LongTensor] = None,
|
||||||
past_key_values: Optional[List[torch.FloatTensor]] = None,
|
past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
|
||||||
inputs_embeds: Optional[torch.FloatTensor] = None,
|
inputs_embeds: Optional[torch.FloatTensor] = None,
|
||||||
labels: Optional[torch.LongTensor] = None,
|
labels: Optional[torch.LongTensor] = None,
|
||||||
use_cache: Optional[bool] = None,
|
use_cache: Optional[bool] = None,
|
||||||
|
|||||||
@@ -1419,7 +1419,7 @@ class Phi3ForSequenceClassification(Phi3PreTrainedModel):
|
|||||||
input_ids: torch.LongTensor = None,
|
input_ids: torch.LongTensor = None,
|
||||||
attention_mask: Optional[torch.Tensor] = None,
|
attention_mask: Optional[torch.Tensor] = None,
|
||||||
position_ids: Optional[torch.LongTensor] = None,
|
position_ids: Optional[torch.LongTensor] = None,
|
||||||
past_key_values: Optional[List[torch.FloatTensor]] = None,
|
past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
|
||||||
inputs_embeds: Optional[torch.FloatTensor] = None,
|
inputs_embeds: Optional[torch.FloatTensor] = None,
|
||||||
labels: Optional[torch.LongTensor] = None,
|
labels: Optional[torch.LongTensor] = None,
|
||||||
use_cache: Optional[bool] = None,
|
use_cache: Optional[bool] = None,
|
||||||
|
|||||||
@@ -1509,7 +1509,7 @@ class Qwen2MoeForSequenceClassification(Qwen2MoePreTrainedModel):
|
|||||||
input_ids: torch.LongTensor = None,
|
input_ids: torch.LongTensor = None,
|
||||||
attention_mask: Optional[torch.Tensor] = None,
|
attention_mask: Optional[torch.Tensor] = None,
|
||||||
position_ids: Optional[torch.LongTensor] = None,
|
position_ids: Optional[torch.LongTensor] = None,
|
||||||
past_key_values: Optional[List[torch.FloatTensor]] = None,
|
past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
|
||||||
inputs_embeds: Optional[torch.FloatTensor] = None,
|
inputs_embeds: Optional[torch.FloatTensor] = None,
|
||||||
labels: Optional[torch.LongTensor] = None,
|
labels: Optional[torch.LongTensor] = None,
|
||||||
use_cache: Optional[bool] = None,
|
use_cache: Optional[bool] = None,
|
||||||
|
|||||||
@@ -1299,7 +1299,7 @@ class StableLmForSequenceClassification(StableLmPreTrainedModel):
|
|||||||
input_ids: torch.LongTensor = None,
|
input_ids: torch.LongTensor = None,
|
||||||
attention_mask: Optional[torch.Tensor] = None,
|
attention_mask: Optional[torch.Tensor] = None,
|
||||||
position_ids: Optional[torch.LongTensor] = None,
|
position_ids: Optional[torch.LongTensor] = None,
|
||||||
past_key_values: Optional[List[torch.FloatTensor]] = None,
|
past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
|
||||||
inputs_embeds: Optional[torch.FloatTensor] = None,
|
inputs_embeds: Optional[torch.FloatTensor] = None,
|
||||||
labels: Optional[torch.LongTensor] = None,
|
labels: Optional[torch.LongTensor] = None,
|
||||||
use_cache: Optional[bool] = None,
|
use_cache: Optional[bool] = None,
|
||||||
|
|||||||
@@ -1292,7 +1292,7 @@ class Starcoder2ForSequenceClassification(Starcoder2PreTrainedModel):
|
|||||||
input_ids: torch.LongTensor = None,
|
input_ids: torch.LongTensor = None,
|
||||||
attention_mask: Optional[torch.Tensor] = None,
|
attention_mask: Optional[torch.Tensor] = None,
|
||||||
position_ids: Optional[torch.LongTensor] = None,
|
position_ids: Optional[torch.LongTensor] = None,
|
||||||
past_key_values: Optional[List[torch.FloatTensor]] = None,
|
past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
|
||||||
inputs_embeds: Optional[torch.FloatTensor] = None,
|
inputs_embeds: Optional[torch.FloatTensor] = None,
|
||||||
labels: Optional[torch.LongTensor] = None,
|
labels: Optional[torch.LongTensor] = None,
|
||||||
use_cache: Optional[bool] = None,
|
use_cache: Optional[bool] = None,
|
||||||
|
|||||||
@@ -18,11 +18,11 @@ import tempfile
|
|||||||
import unittest
|
import unittest
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
from packaging import version
|
||||||
from parameterized import parameterized
|
from parameterized import parameterized
|
||||||
|
|
||||||
from transformers import LlamaConfig, StaticCache, is_torch_available, logging, set_seed
|
from transformers import LlamaConfig, is_torch_available, set_seed
|
||||||
from transformers.testing_utils import (
|
from transformers.testing_utils import (
|
||||||
CaptureLogger,
|
|
||||||
require_bitsandbytes,
|
require_bitsandbytes,
|
||||||
require_flash_attn,
|
require_flash_attn,
|
||||||
require_read_token,
|
require_read_token,
|
||||||
@@ -684,15 +684,28 @@ class LlamaIntegrationTest(unittest.TestCase):
|
|||||||
@require_torch_gpu
|
@require_torch_gpu
|
||||||
@require_read_token
|
@require_read_token
|
||||||
def test_compile_static_cache(self):
|
def test_compile_static_cache(self):
|
||||||
|
# `torch==2.2` will throw an error on this test (as in other compilation tests), but torch==2.1.2 and torch>2.2
|
||||||
|
# work as intended. See https://github.com/pytorch/pytorch/issues/121943
|
||||||
|
if version.parse(torch.__version__) < version.parse("2.3.0"):
|
||||||
|
self.skipTest("This test requires torch >= 2.3 to run.")
|
||||||
|
|
||||||
NUM_TOKENS_TO_GENERATE = 40
|
NUM_TOKENS_TO_GENERATE = 40
|
||||||
|
# Note on `EXPECTED_TEXT_COMPLETION`'s diff: the current value matches the original test if the original test
|
||||||
|
# was changed to have a cache of 53 tokens (as opposed to 4096), on Ampere GPUs.
|
||||||
EXPECTED_TEXT_COMPLETION = {
|
EXPECTED_TEXT_COMPLETION = {
|
||||||
7: [
|
|
||||||
"Simply put, the theory of relativity states that 1) the speed of light is constant, 2) the speed of light is the same for all observers, and 3) the laws of physics are the same for all observers.",
|
|
||||||
"My favorite all time favorite condiment is ketchup. I love it on everything. I love it on my eggs, my fries, my chicken, my burgers, my hot dogs, my sandwiches, my salads, my p",
|
|
||||||
],
|
|
||||||
8: [
|
8: [
|
||||||
"Simply put, the theory of relativity states that 1) the speed of light is the same for all observers, and 2) the laws of physics are the same for all observers.\nThe first part of the theory of relativity",
|
"Simply put, the theory of relativity states that 1) the speed of light is constant in all inertial "
|
||||||
"My favorite all time favorite condiment is ketchup. I love it on everything. I love it on my eggs, my fries, my chicken, my burgers, my hot dogs, my sandwiches, my salads, my p",
|
"reference frames, and 2) the laws of physics are the same for all inertial reference frames.\nThe "
|
||||||
|
"theory of relativ",
|
||||||
|
"My favorite all time favorite condiment is ketchup. I love it on everything. I love it on my eggs, "
|
||||||
|
"my fries, my chicken, my burgers, my hot dogs, my sandwiches, my salads, my p",
|
||||||
|
],
|
||||||
|
7: [
|
||||||
|
"Simply put, the theory of relativity states that 1. surely nothing is faster than light.\nThe theory "
|
||||||
|
"goes that nothing travels faster than light, but the faster you go, the slower everything else will "
|
||||||
|
"be.\nThe theory of relativity",
|
||||||
|
"My favorite all time favorite condiment is ketchup. I love it on hamburgers, hot dogs, fries, eggs, "
|
||||||
|
"and even on a good old fashioned cheeseburger. I love it on everything. I love it so",
|
||||||
],
|
],
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -706,38 +719,25 @@ class LlamaIntegrationTest(unittest.TestCase):
|
|||||||
)
|
)
|
||||||
inputs = tokenizer(prompts, return_tensors="pt", padding=True).to(model.device)
|
inputs = tokenizer(prompts, return_tensors="pt", padding=True).to(model.device)
|
||||||
|
|
||||||
def decode_one_tokens(model, cur_token, input_pos, cache_position):
|
# Dynamic Cache
|
||||||
logits = model(
|
generated_ids = model.generate(**inputs, max_new_tokens=NUM_TOKENS_TO_GENERATE, do_sample=False)
|
||||||
cur_token, position_ids=input_pos, cache_position=cache_position, return_dict=False, use_cache=True
|
dynamic_text = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
|
||||||
)[0]
|
self.assertEqual(EXPECTED_TEXT_COMPLETION[8], dynamic_text) # Both GPU architectures have the same output
|
||||||
new_token = torch.argmax(logits[:, -1], dim=-1)[:, None]
|
|
||||||
return new_token
|
|
||||||
|
|
||||||
batch_size, seq_length = inputs["input_ids"].shape
|
# Static Cache
|
||||||
with torch.no_grad():
|
generated_ids = model.generate(
|
||||||
model._setup_cache(StaticCache, 2, max_cache_len=4096)
|
**inputs, max_new_tokens=NUM_TOKENS_TO_GENERATE, do_sample=False, cache_implementation="static"
|
||||||
cache_position = torch.arange(seq_length, device=torch_device)
|
)
|
||||||
generated_ids = torch.zeros(
|
static_text = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
|
||||||
batch_size, seq_length + NUM_TOKENS_TO_GENERATE + 1, dtype=torch.int, device=torch_device
|
self.assertEqual(EXPECTED_TEXT_COMPLETION[self.cuda_compute_capability_major_version], static_text)
|
||||||
)
|
|
||||||
generated_ids[:, cache_position] = inputs["input_ids"].to(torch_device).to(torch.int)
|
|
||||||
|
|
||||||
logits = model(**inputs, cache_position=cache_position, return_dict=False, use_cache=True)[0]
|
# Static Cache + compile
|
||||||
next_token = torch.argmax(logits[:, -1], dim=-1)[:, None]
|
model.forward = torch.compile(model.forward, mode="reduce-overhead", fullgraph=True)
|
||||||
generated_ids[:, seq_length] = next_token[:, 0]
|
generated_ids = model.generate(
|
||||||
|
**inputs, max_new_tokens=NUM_TOKENS_TO_GENERATE, do_sample=False, cache_implementation="static"
|
||||||
decode_one_tokens = torch.compile(decode_one_tokens, mode="reduce-overhead", fullgraph=True)
|
)
|
||||||
cache_position = torch.tensor([seq_length + 1], device=torch_device)
|
static_compiled_text = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
|
||||||
for _ in range(1, NUM_TOKENS_TO_GENERATE):
|
self.assertEqual(EXPECTED_TEXT_COMPLETION[self.cuda_compute_capability_major_version], static_compiled_text)
|
||||||
with torch.backends.cuda.sdp_kernel(enable_flash=False, enable_mem_efficient=False, enable_math=True):
|
|
||||||
with CaptureLogger(logging.get_logger(__name__)) as cl:
|
|
||||||
next_token = decode_one_tokens(model, next_token.clone(), None, cache_position)
|
|
||||||
self.assertNotIn("skipping cudagraphs due to", cl.out)
|
|
||||||
generated_ids[:, cache_position] = next_token.int()
|
|
||||||
cache_position += 1
|
|
||||||
|
|
||||||
text = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
|
|
||||||
self.assertEqual(EXPECTED_TEXT_COMPLETION[self.cuda_compute_capability_major_version], text)
|
|
||||||
|
|
||||||
|
|
||||||
@require_torch
|
@require_torch
|
||||||
|
|||||||
@@ -196,9 +196,14 @@ class AqlmTest(unittest.TestCase):
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
# Sample tokens greedily
|
# Sample tokens greedily
|
||||||
def decode_one_tokens(model, cur_token, input_pos, cache_position):
|
def decode_one_tokens(model, cur_token, input_pos, cache_position, past_key_values):
|
||||||
logits = model(
|
logits = model(
|
||||||
cur_token, position_ids=input_pos, cache_position=cache_position, return_dict=False, use_cache=True
|
cur_token,
|
||||||
|
position_ids=input_pos,
|
||||||
|
cache_position=cache_position,
|
||||||
|
past_key_values=past_key_values,
|
||||||
|
return_dict=False,
|
||||||
|
use_cache=True,
|
||||||
)[0]
|
)[0]
|
||||||
new_token = torch.argmax(logits[:, [-1]], dim=-1).to(torch.int)
|
new_token = torch.argmax(logits[:, [-1]], dim=-1).to(torch.int)
|
||||||
|
|
||||||
@@ -209,7 +214,13 @@ class AqlmTest(unittest.TestCase):
|
|||||||
seq_length = input_ids.shape[1]
|
seq_length = input_ids.shape[1]
|
||||||
|
|
||||||
# Setup static KV cache for generation
|
# Setup static KV cache for generation
|
||||||
self.quantized_model._setup_cache(StaticCache, 1, max_cache_len=seq_length + self.max_new_tokens + 1)
|
past_key_values = StaticCache(
|
||||||
|
config=self.quantized_model.config,
|
||||||
|
max_batch_size=1,
|
||||||
|
max_cache_len=seq_length + self.max_new_tokens + 1,
|
||||||
|
device=torch_device,
|
||||||
|
dtype=self.quantized_model.config._pre_quantization_dtype,
|
||||||
|
)
|
||||||
|
|
||||||
# Allocate token ids to be generated and copy prefix ids
|
# Allocate token ids to be generated and copy prefix ids
|
||||||
cache_position = torch.arange(seq_length, device=torch_device)
|
cache_position = torch.arange(seq_length, device=torch_device)
|
||||||
@@ -217,7 +228,13 @@ class AqlmTest(unittest.TestCase):
|
|||||||
generated_ids[:, cache_position] = input_ids.to(torch_device).to(torch.int)
|
generated_ids[:, cache_position] = input_ids.to(torch_device).to(torch.int)
|
||||||
|
|
||||||
# Do a forward pass to fill the prefix cache and compile the kernels if necessary
|
# Do a forward pass to fill the prefix cache and compile the kernels if necessary
|
||||||
logits = self.quantized_model(input_ids, cache_position=cache_position, return_dict=False, use_cache=True)[0]
|
logits = self.quantized_model(
|
||||||
|
input_ids,
|
||||||
|
cache_position=cache_position,
|
||||||
|
past_key_values=past_key_values,
|
||||||
|
return_dict=False,
|
||||||
|
use_cache=True,
|
||||||
|
)[0]
|
||||||
next_token = torch.argmax(logits[:, [-1]], dim=-1).to(torch.int)
|
next_token = torch.argmax(logits[:, [-1]], dim=-1).to(torch.int)
|
||||||
generated_ids[:, [seq_length]] = next_token
|
generated_ids[:, [seq_length]] = next_token
|
||||||
|
|
||||||
@@ -229,7 +246,9 @@ class AqlmTest(unittest.TestCase):
|
|||||||
cache_position = torch.tensor([seq_length + 1], device=torch_device)
|
cache_position = torch.tensor([seq_length + 1], device=torch_device)
|
||||||
for _ in range(1, self.max_new_tokens):
|
for _ in range(1, self.max_new_tokens):
|
||||||
with torch.backends.cuda.sdp_kernel(enable_flash=False, enable_mem_efficient=False, enable_math=True):
|
with torch.backends.cuda.sdp_kernel(enable_flash=False, enable_mem_efficient=False, enable_math=True):
|
||||||
next_token = decode_one_tokens(self.quantized_model, next_token.clone(), None, cache_position)
|
next_token = decode_one_tokens(
|
||||||
|
self.quantized_model, next_token.clone(), None, cache_position, past_key_values
|
||||||
|
)
|
||||||
generated_ids.index_copy_(1, cache_position, next_token)
|
generated_ids.index_copy_(1, cache_position, next_token)
|
||||||
cache_position += 1
|
cache_position += 1
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user