Generate: New Cache abstraction and Attention Sinks support (#26681)
* Draft version of new KV Caching This should allow Attention Sinks (https://github.com/tomaarsen/attention_sinks) / StreamingLLM (https://arxiv.org/abs/2309.17453) to be easily implemented in a third-party or in transformers directly * Address numerous PR suggestions 1. Move layer_idx from cache to ...Attention. Removes confusing set_layer_idx magic. 2. Always convert past_key_values to Cache instance at the start of ...Attention, removes all other isinstance calls. 3. Remove __bool__ and __getitem__ magic as they're confusing. 4. past_key_values.update(key, value, idx) now returns key, value. 5. Add use_legacy_cache flag, defaults to None, i.e. Falsey. This breaks generate for now, until 1) the cache is used is generate() or 2) use_legacy_cache is defaulted to True in generate() until we change it in another PR. 6. Separate key_cache and value_cache. Some work is still needed to see if the SinkCache can conveniently be implemented with just one update method. * Implement the SinkCache through backward+forward rotations * Integrate (Sink)Cache with Llama FA2 * Set use_legacy_cache=True as default, allows for test passes * Move from/to_legacy_cache to ...Model class * Undo unnecessary newline change * Remove copy utility from deprecated OpenLlama * Match import style * manual rebase with main * Cache class working with generate (#1) * Draft version of new KV Caching This should allow Attention Sinks (https://github.com/tomaarsen/attention_sinks) / StreamingLLM (https://arxiv.org/abs/2309.17453) to be easily implemented in a third-party or in transformers directly * Address numerous PR suggestions 1. Move layer_idx from cache to ...Attention. Removes confusing set_layer_idx magic. 2. Always convert past_key_values to Cache instance at the start of ...Attention, removes all other isinstance calls. 3. Remove __bool__ and __getitem__ magic as they're confusing. 4. past_key_values.update(key, value, idx) now returns key, value. 5. 
Add use_legacy_cache flag, defaults to None, i.e. Falsey. This breaks generate for now, until 1) the cache is used is generate() or 2) use_legacy_cache is defaulted to True in generate() until we change it in another PR. 6. Separate key_cache and value_cache. Some work is still needed to see if the SinkCache can conveniently be implemented with just one update method. * Integrate (Sink)Cache with Llama FA2 * Move from/to_legacy_cache to ...Model class * Undo unnecessary newline change * Match import style * working generate * Add tests; Simplify code; Apply changes to Mistral and Persimmon * fix rebase mess * a few more manual fixes * last manual fix * propagate changes to phi * upgrade test * add use_legacy_cache docstring; beef up tests * reintroduce unwanted deletes --------- Co-authored-by: Tom Aarsen <Cubiegamedev@gmail.com> * move import * add default to model_kwargs.get('use_legacy_cache') * correct failing test * Apply suggestions from code review Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com> * apply PR suggestions * fix failing test * Apply suggestions from code review Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com> Co-authored-by: Tom Aarsen <37621491+tomaarsen@users.noreply.github.com> * PR comments * tmp commit * add docstrings * more tests, more docstrings, add to docs * derp * tmp commit * tmp dbg * more dbg * fix beam search bug * cache can be a list of tuples in some models * fix group beam search * all but sinkcache integration tests * fix sink cache and add hard integration test * now also compatible with input_embeds input * PR comments * add Cache support to Phi+FA2 * make fixup --------- Co-authored-by: Joao Gante <joao@huggingface.co> Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com> Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
This commit is contained in:
@@ -368,3 +368,20 @@ A [`Constraint`] can be used to force the generation to include specific tokens
|
|||||||
[[autodoc]] TextStreamer
|
[[autodoc]] TextStreamer
|
||||||
|
|
||||||
[[autodoc]] TextIteratorStreamer
|
[[autodoc]] TextIteratorStreamer
|
||||||
|
|
||||||
|
## Caches
|
||||||
|
|
||||||
|
[[autodoc]] Cache
|
||||||
|
- update
|
||||||
|
|
||||||
|
[[autodoc]] DynamicCache
|
||||||
|
- update
|
||||||
|
- get_seq_length
|
||||||
|
- reorder_cache
|
||||||
|
- to_legacy_cache
|
||||||
|
- from_legacy_cache
|
||||||
|
|
||||||
|
[[autodoc]] SinkCache
|
||||||
|
- update
|
||||||
|
- get_seq_length
|
||||||
|
- reorder_cache
|
||||||
|
|||||||
@@ -1303,6 +1303,7 @@ else:
|
|||||||
_import_structure["activations"] = []
|
_import_structure["activations"] = []
|
||||||
_import_structure["benchmark.benchmark"] = ["PyTorchBenchmark"]
|
_import_structure["benchmark.benchmark"] = ["PyTorchBenchmark"]
|
||||||
_import_structure["benchmark.benchmark_args"] = ["PyTorchBenchmarkArguments"]
|
_import_structure["benchmark.benchmark_args"] = ["PyTorchBenchmarkArguments"]
|
||||||
|
_import_structure["cache_utils"] = ["Cache", "DynamicCache", "SinkCache"]
|
||||||
_import_structure["data.datasets"] = [
|
_import_structure["data.datasets"] = [
|
||||||
"GlueDataset",
|
"GlueDataset",
|
||||||
"GlueDataTrainingArguments",
|
"GlueDataTrainingArguments",
|
||||||
@@ -5945,6 +5946,7 @@ if TYPE_CHECKING:
|
|||||||
# Benchmarks
|
# Benchmarks
|
||||||
from .benchmark.benchmark import PyTorchBenchmark
|
from .benchmark.benchmark import PyTorchBenchmark
|
||||||
from .benchmark.benchmark_args import PyTorchBenchmarkArguments
|
from .benchmark.benchmark_args import PyTorchBenchmarkArguments
|
||||||
|
from .cache_utils import Cache, DynamicCache, SinkCache
|
||||||
from .data.datasets import (
|
from .data.datasets import (
|
||||||
GlueDataset,
|
GlueDataset,
|
||||||
GlueDataTrainingArguments,
|
GlueDataTrainingArguments,
|
||||||
|
|||||||
298
src/transformers/cache_utils.py
Normal file
298
src/transformers/cache_utils.py
Normal file
@@ -0,0 +1,298 @@
|
|||||||
|
from typing import Any, Dict, List, Optional, Tuple
|
||||||
|
|
||||||
|
import torch
|
||||||
|
|
||||||
|
|
||||||
|
class Cache:
    """
    Abstract parent of every cache implementation. How the cached data is laid out is left to each subclass.
    """

    def update(
        self,
        key_states: torch.Tensor,
        value_states: torch.Tensor,
        layer_idx: int,
        cache_kwargs: Optional[Dict[str, Any]] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Stores the new `key_states` and `value_states` of layer `layer_idx` in the cache.

        Parameters:
            key_states (`torch.Tensor`):
                The new key states to cache.
            value_states (`torch.Tensor`):
                The new value states to cache.
            layer_idx (`int`):
                The index of the layer to cache the states for.
            cache_kwargs (`Dict[str, Any]`, `optional`):
                Extra, subclass-specific arguments. They exist so that new kinds of cache can be implemented without
                changing this signature.

        Return:
            A tuple containing the updated key and value states.
        """
        # The base class holds no data: every concrete cache has to provide its own implementation.
        message = "Make sure to implement `update` in a subclass."
        raise NotImplementedError(message)

    def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
        """Returns the sequence length of the cached states. A layer index can be optionally passed."""
        message = "Make sure to implement `get_seq_length` in a subclass."
        raise NotImplementedError(message)
|
||||||
|
|
||||||
|
|
||||||
|
class DynamicCache(Cache):
    """
    A cache that grows dynamically as more tokens are generated. This is the default for generative models.

    It stores the Key and Value states as a list of tensors, one for each layer. The expected shape for each tensor is
    `[batch_size, num_heads, seq_len, head_dim]`.
    """

    def __init__(self) -> None:
        self.key_cache: List[torch.Tensor] = []
        self.value_cache: List[torch.Tensor] = []
        self.seen_tokens = 0  # Used in `generate` to keep tally of how many tokens the cache has seen

    def __getitem__(self, layer_idx: int) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Support for backwards-compatible `past_key_value` indexing, e.g. `past_key_value[0][0].shape[2]` to get the
        sequence length.

        Returns the `(key_states, value_states)` pair of layer `layer_idx`, and raises a `KeyError` when the cache
        holds fewer than `layer_idx + 1` layers.
        """
        if layer_idx < len(self):
            return (self.key_cache[layer_idx], self.value_cache[layer_idx])
        raise KeyError(f"Cache only has {len(self)} layers, attempted to access layer with index {layer_idx}")

    def __iter__(self):
        """
        Support for backwards-compatible `past_key_value` iteration, e.g. `for x in past_key_value:` to iterate over
        keys and values. Each item is the `(key_states, value_states)` pair of one layer.
        """
        yield from zip(self.key_cache, self.value_cache)

    def __len__(self) -> int:
        """
        Support for backwards-compatible `past_key_value` length, e.g. `len(past_key_value)`. This value corresponds
        to the number of layers in the model.
        """
        return len(self.key_cache)

    def update(
        self,
        key_states: torch.Tensor,
        value_states: torch.Tensor,
        layer_idx: int,
        cache_kwargs: Optional[Dict[str, Any]] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Updates the cache with the new `key_states` and `value_states` for the layer `layer_idx`.

        Parameters:
            key_states (`torch.Tensor`):
                The new key states to cache.
            value_states (`torch.Tensor`):
                The new value states to cache.
            layer_idx (`int`):
                The index of the layer to cache the states for.
            cache_kwargs (`Dict[str, Any]`, `optional`):
                Additional arguments for the cache subclass. No additional arguments are used in `DynamicCache`.

        Return:
            A tuple containing the updated key and value states.
        """
        # Update the number of seen tokens. Only the first layer is counted, so that a full forward pass through all
        # layers increments the tally once.
        if layer_idx == 0:
            self.seen_tokens += key_states.shape[-2]

        # Update the cache
        if len(self.key_cache) <= layer_idx:
            # First time this layer is seen: its states become the initial cache entry
            self.key_cache.append(key_states)
            self.value_cache.append(value_states)
        else:
            # Layer already cached: extend it along the sequence dimension
            self.key_cache[layer_idx] = torch.cat([self.key_cache[layer_idx], key_states], dim=-2)
            self.value_cache[layer_idx] = torch.cat([self.value_cache[layer_idx], value_states], dim=-2)

        return self.key_cache[layer_idx], self.value_cache[layer_idx]

    def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
        """Returns the sequence length of the cached states. A layer index can be optionally passed."""
        if len(self.key_cache) <= layer_idx:
            # Nothing has been cached for this layer yet
            return 0
        return self.key_cache[layer_idx].shape[-2]

    def reorder_cache(self, beam_idx: torch.LongTensor) -> None:
        """Reorders the cache for beam search, given the selected beam indices."""
        for layer_idx in range(len(self.key_cache)):
            # Keys and values are handled separately, moving `beam_idx` to the device of each tensor
            device = self.key_cache[layer_idx].device
            self.key_cache[layer_idx] = self.key_cache[layer_idx].index_select(0, beam_idx.to(device))
            device = self.value_cache[layer_idx].device
            self.value_cache[layer_idx] = self.value_cache[layer_idx].index_select(0, beam_idx.to(device))

    def to_legacy_cache(self) -> Tuple[Tuple[torch.Tensor, torch.Tensor], ...]:
        """
        Converts the `DynamicCache` instance into its equivalent in the legacy cache format: a tuple holding one
        `(key_states, value_states)` pair per layer.
        """
        return tuple(zip(self.key_cache, self.value_cache))

    @classmethod
    def from_legacy_cache(cls, past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None) -> "DynamicCache":
        """
        Converts a cache in the legacy cache format into an equivalent `DynamicCache`. Passing `None` returns an empty
        cache.
        """
        cache = cls()
        if past_key_values is not None:
            for layer_idx, (key_states, value_states) in enumerate(past_key_values):
                cache.update(key_states, value_states, layer_idx)
        return cache
|
||||||
|
|
||||||
|
|
||||||
|
class SinkCache(Cache):
    """
    A cache as described in the [Attention Sinks paper](https://arxiv.org/abs/2309.17453). It allows the model to
    generate beyond the length of its context window, without losing fluency in the conversation. As it discards past
    tokens, the model will lose the ability to generate tokens that depend on the context that was discarded.

    It stores the Key and Value states as a list of tensors, one for each layer. The expected shape for each tensor is
    `[batch_size, num_heads, seq_len, head_dim]`.

    Parameters:
        window_length (`int`):
            The length of the context window.
        num_sink_tokens (`int`):
            The number of sink tokens. See the original paper for more information.
    """

    def __init__(self, window_length: int, num_sink_tokens: int) -> None:
        self.key_cache: List[torch.Tensor] = []
        self.value_cache: List[torch.Tensor] = []
        self.window_length = window_length
        self.num_sink_tokens = num_sink_tokens
        # Re-rotation `(cos, sin)` pairs, keyed by the number of new tokens they were computed for
        self.cos_sin_cache = {}
        self.seen_tokens = 0  # Used in `generate` to keep tally of how many tokens the cache has seen

    @staticmethod
    def _rotate_half(x):
        """Splits the last dimension of `x` in two halves `(x1, x2)` and returns `(-x2, x1)` concatenated."""
        x1 = x[..., : x.shape[-1] // 2]
        x2 = x[..., x.shape[-1] // 2 :]
        return torch.cat((-x2, x1), dim=-1)

    def _apply_key_rotary_pos_emb(
        self, key_states: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor
    ) -> torch.Tensor:
        """Applies the rotation described by `cos` and `sin` to `key_states` and returns the rotated keys."""
        rotated_key_states = (key_states * cos) + (self._rotate_half(key_states) * sin)
        return rotated_key_states

    def _get_rerotation_cos_sin(
        self, key_states: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Returns the `(cos, sin)` pair used to re-rotate the kept keys when the window shifts by the number of new
        tokens in `key_states`.

        The pair is computed once per number of new tokens and then served from `self.cos_sin_cache`; on a cache hit
        the `cos` and `sin` arguments are not read.
        """
        if key_states.shape[-2] not in self.cos_sin_cache:
            # Upcast to float32 temporarily for better accuracy
            cos = cos.to(torch.float32)
            sin = sin.to(torch.float32)

            # Compute the cos and sin required for back- and forward-rotating to one position earlier in the sequence
            original_cos = cos[self.num_sink_tokens + key_states.shape[-2] :]
            shifted_cos = cos[self.num_sink_tokens : -key_states.shape[-2]]
            original_sin = sin[self.num_sink_tokens + key_states.shape[-2] :]
            shifted_sin = sin[self.num_sink_tokens : -key_states.shape[-2]]
            rerotation_cos = original_cos * shifted_cos + original_sin * shifted_sin
            rerotation_sin = -original_sin * shifted_cos + original_cos * shifted_sin

            # Store the result in the dtype of the keys, with a leading dimension added for broadcasting
            self.cos_sin_cache[key_states.shape[-2]] = (
                rerotation_cos.to(key_states.dtype).unsqueeze(0),
                rerotation_sin.to(key_states.dtype).unsqueeze(0),
            )
        return self.cos_sin_cache[key_states.shape[-2]]

    def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
        """Returns the sequence length of the cached states. A layer index can be optionally passed."""
        # Workaround to make 'key_states.shape[-2] + past_key_value.get_seq_length(self.layer_idx)' <= window_length
        if len(self.key_cache) <= layer_idx:
            return 0
        cache_length = self.key_cache[layer_idx].shape[-2]
        return min(cache_length, self.window_length - 1)

    def update(
        self,
        key_states: torch.Tensor,
        value_states: torch.Tensor,
        layer_idx: int,
        cache_kwargs: Optional[Dict[str, Any]] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Updates the cache with the new `key_states` and `value_states` for the layer `layer_idx`.

        Parameters:
            key_states (`torch.Tensor`):
                The new key states to cache.
            value_states (`torch.Tensor`):
                The new value states to cache.
            layer_idx (`int`):
                The index of the layer to cache the states for.
            cache_kwargs (`Dict[str, Any]`, `optional`):
                Additional arguments for the cache subclass. The following arguments can be used in `SinkCache`: `sin`,
                `cos` and `partial_rotation_size`. These arguments are used with models using RoPE, to recompute the
                rotation as the tokens are shifted. When omitted (or `None`), no re-rotation is applied.

        Return:
            A tuple containing the updated key and value states.
        """
        # `cache_kwargs` defaults to `None`: treat that as "no optional arguments" rather than failing on `.get`
        if cache_kwargs is None:
            cache_kwargs = {}

        # Optional kwargs for `SinkCache` -- needed on models using RoPE. `partial_rotation_size` is used on models
        # with partially rotated position embeddings, like Phi or Persimmon.
        sin = cache_kwargs.get("sin")
        cos = cache_kwargs.get("cos")
        partial_rotation_size = cache_kwargs.get("partial_rotation_size")
        using_rope = cos is not None and sin is not None

        # Update the number of seen tokens
        if layer_idx == 0:
            self.seen_tokens += key_states.shape[-2]

        # [bsz, num_heads, seq_len, head_dim]
        if len(self.key_cache) <= layer_idx:
            # Empty cache
            self.key_cache.append(key_states)
            self.value_cache.append(value_states)

        elif key_states.shape[-2] + self.get_seq_length(layer_idx) < self.window_length:
            # Growing cache
            self.key_cache[layer_idx] = torch.cat([self.key_cache[layer_idx], key_states], dim=-2)
            self.value_cache[layer_idx] = torch.cat([self.value_cache[layer_idx], value_states], dim=-2)

        else:
            # Shifting cache: keep the most recent tokens that still fit next to the sink tokens and the new tokens
            keys_to_keep = self.key_cache[layer_idx][
                :, :, -self.window_length + self.num_sink_tokens + key_states.shape[-2] :
            ]

            # On RoPE models, we need to recompute the Key rotation as the tokens are shifted
            if using_rope:
                rerotation_cos, rerotation_sin = self._get_rerotation_cos_sin(key_states, cos, sin)
                if partial_rotation_size is not None:
                    # Only the first `partial_rotation_size` features are rotated; the rest pass through untouched
                    keys_to_keep, keys_pass = (
                        keys_to_keep[..., :partial_rotation_size],
                        keys_to_keep[..., partial_rotation_size:],
                    )
                keys_to_keep = self._apply_key_rotary_pos_emb(keys_to_keep, rerotation_cos, rerotation_sin)
                if partial_rotation_size is not None:
                    keys_to_keep = torch.cat((keys_to_keep, keys_pass), dim=-1)

            # Concatenate sink tokens, shifted & rotated tokens (if needed), and new tokens
            sink_keys = self.key_cache[layer_idx][:, :, : self.num_sink_tokens]
            self.key_cache[layer_idx] = torch.cat([sink_keys, keys_to_keep, key_states], dim=-2)

            # Values carry no rotation, so they are only sliced and concatenated
            sink_values = self.value_cache[layer_idx][:, :, : self.num_sink_tokens]
            values_to_keep = self.value_cache[layer_idx][
                :, :, -self.window_length + self.num_sink_tokens + value_states.shape[-2] :
            ]
            self.value_cache[layer_idx] = torch.cat([sink_values, values_to_keep, value_states], dim=-2)

        return self.key_cache[layer_idx], self.value_cache[layer_idx]

    def reorder_cache(self, beam_idx: torch.LongTensor) -> None:
        """Reorders the cache for beam search, given the selected beam indices."""
        for layer_idx in range(len(self.key_cache)):
            # Keys and values are handled separately, moving `beam_idx` to the device of each tensor
            device = self.key_cache[layer_idx].device
            self.key_cache[layer_idx] = self.key_cache[layer_idx].index_select(0, beam_idx.to(device))
            device = self.value_cache[layer_idx].device
            self.value_cache[layer_idx] = self.value_cache[layer_idx].index_select(0, beam_idx.to(device))
|
||||||
@@ -24,6 +24,7 @@ import torch
|
|||||||
import torch.distributed as dist
|
import torch.distributed as dist
|
||||||
from torch import nn
|
from torch import nn
|
||||||
|
|
||||||
|
from ..cache_utils import Cache, DynamicCache
|
||||||
from ..integrations.deepspeed import is_deepspeed_zero3_enabled
|
from ..integrations.deepspeed import is_deepspeed_zero3_enabled
|
||||||
from ..modeling_outputs import CausalLMOutputWithPast, Seq2SeqLMOutput
|
from ..modeling_outputs import CausalLMOutputWithPast, Seq2SeqLMOutput
|
||||||
from ..models.auto import (
|
from ..models.auto import (
|
||||||
@@ -1287,6 +1288,13 @@ class GenerationMixin:
|
|||||||
|
|
||||||
def _validate_model_kwargs(self, model_kwargs: Dict[str, Any]):
|
def _validate_model_kwargs(self, model_kwargs: Dict[str, Any]):
|
||||||
"""Validates model kwargs for generation. Generate argument typos will also be caught here."""
|
"""Validates model kwargs for generation. Generate argument typos will also be caught here."""
|
||||||
|
# If a `Cache` instance is passed, checks whether the model is compatible with it
|
||||||
|
if isinstance(model_kwargs.get("past_key_values", None), Cache) and not self._supports_cache_class:
|
||||||
|
raise ValueError(
|
||||||
|
f"{self.__class__.__name__} does not support an instance of `Cache` as `past_key_values`. Please "
|
||||||
|
"check the model documentation for supported cache formats."
|
||||||
|
)
|
||||||
|
|
||||||
# Excludes arguments that are handled before calling any model function
|
# Excludes arguments that are handled before calling any model function
|
||||||
if self.config.is_encoder_decoder:
|
if self.config.is_encoder_decoder:
|
||||||
for key in ["decoder_input_ids"]:
|
for key in ["decoder_input_ids"]:
|
||||||
@@ -2945,6 +2953,32 @@ class GenerationMixin:
|
|||||||
else:
|
else:
|
||||||
return input_ids
|
return input_ids
|
||||||
|
|
||||||
|
def _temporary_reorder_cache(self, past_key_values, beam_idx):
    """
    Stop-gap dispatcher that picks the right cache reordering routine while `Cache` is being rolled out.

    Legacy tuple/list caches are reordered through `self._reorder_cache`; every other cache is reordered in place
    through its own `reorder_cache` method, except on the model classes singled out below.

    TODO: standardize cache formats and make all models compatible with `Cache`. It would remove the need
    for this function, with `Cache.reorder_cache` being the sole remaining code path
    """
    # Exception 1: code path for models using the legacy cache format
    if isinstance(past_key_values, (tuple, list)):
        return self._reorder_cache(past_key_values, beam_idx)

    model_class = self.__class__.__name__.lower()
    has_custom_cache_format = any(tag in model_class for tag in ("bloom", "gptbigcode"))

    # Standard code path: use the `Cache.reorder_cache`, which works in place
    if not has_custom_cache_format:
        past_key_values.reorder_cache(beam_idx)
        return past_key_values

    # Exception 2: models with different cache formats. These are limited to `DynamicCache` until their
    # cache format is standardized, to avoid adding complexity to the codebase.
    if not isinstance(past_key_values, DynamicCache):
        raise ValueError(
            f"Using an unsupported cache format with {model_class}. Currently, it only supports the "
            "legacy tuple format or `DynamicCache`"
        )
    # The model-specific `_reorder_cache` output is wrapped back into a `DynamicCache`
    reordered = self._reorder_cache(past_key_values, beam_idx)
    return DynamicCache.from_legacy_cache(reordered)
|
||||||
|
|
||||||
def beam_search(
|
def beam_search(
|
||||||
self,
|
self,
|
||||||
input_ids: torch.LongTensor,
|
input_ids: torch.LongTensor,
|
||||||
@@ -3218,7 +3252,9 @@ class GenerationMixin:
|
|||||||
outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder
|
outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder
|
||||||
)
|
)
|
||||||
if model_kwargs["past_key_values"] is not None:
|
if model_kwargs["past_key_values"] is not None:
|
||||||
model_kwargs["past_key_values"] = self._reorder_cache(model_kwargs["past_key_values"], beam_idx)
|
model_kwargs["past_key_values"] = self._temporary_reorder_cache(
|
||||||
|
model_kwargs["past_key_values"], beam_idx
|
||||||
|
)
|
||||||
|
|
||||||
if return_dict_in_generate and output_scores:
|
if return_dict_in_generate and output_scores:
|
||||||
beam_indices = tuple((beam_indices[beam_idx[i]] + (beam_idx[i],) for i in range(len(beam_indices))))
|
beam_indices = tuple((beam_indices[beam_idx[i]] + (beam_idx[i],) for i in range(len(beam_indices))))
|
||||||
@@ -3553,7 +3589,9 @@ class GenerationMixin:
|
|||||||
outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder
|
outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder
|
||||||
)
|
)
|
||||||
if model_kwargs["past_key_values"] is not None:
|
if model_kwargs["past_key_values"] is not None:
|
||||||
model_kwargs["past_key_values"] = self._reorder_cache(model_kwargs["past_key_values"], beam_idx)
|
model_kwargs["past_key_values"] = self._temporary_reorder_cache(
|
||||||
|
model_kwargs["past_key_values"], beam_idx
|
||||||
|
)
|
||||||
|
|
||||||
if return_dict_in_generate and output_scores:
|
if return_dict_in_generate and output_scores:
|
||||||
beam_indices = tuple((beam_indices[beam_idx[i]] + (beam_idx[i],) for i in range(len(beam_indices))))
|
beam_indices = tuple((beam_indices[beam_idx[i]] + (beam_idx[i],) for i in range(len(beam_indices))))
|
||||||
@@ -3938,7 +3976,7 @@ class GenerationMixin:
|
|||||||
outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder
|
outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder
|
||||||
)
|
)
|
||||||
if model_kwargs["past_key_values"] is not None:
|
if model_kwargs["past_key_values"] is not None:
|
||||||
model_kwargs["past_key_values"] = self._reorder_cache(
|
model_kwargs["past_key_values"] = self._temporary_reorder_cache(
|
||||||
model_kwargs["past_key_values"], reordering_indices
|
model_kwargs["past_key_values"], reordering_indices
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -4280,7 +4318,9 @@ class GenerationMixin:
|
|||||||
outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder
|
outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder
|
||||||
)
|
)
|
||||||
if model_kwargs["past_key_values"] is not None:
|
if model_kwargs["past_key_values"] is not None:
|
||||||
model_kwargs["past_key_values"] = self._reorder_cache(model_kwargs["past_key_values"], beam_idx)
|
model_kwargs["past_key_values"] = self._temporary_reorder_cache(
|
||||||
|
model_kwargs["past_key_values"], beam_idx
|
||||||
|
)
|
||||||
|
|
||||||
if return_dict_in_generate and output_scores:
|
if return_dict_in_generate and output_scores:
|
||||||
beam_indices = tuple((beam_indices[beam_idx[i]] + (beam_idx[i],) for i in range(len(beam_indices))))
|
beam_indices = tuple((beam_indices[beam_idx[i]] + (beam_idx[i],) for i in range(len(beam_indices))))
|
||||||
|
|||||||
@@ -1128,6 +1128,9 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
|||||||
# Flash Attention 2 support
|
# Flash Attention 2 support
|
||||||
_supports_flash_attn_2 = False
|
_supports_flash_attn_2 = False
|
||||||
|
|
||||||
|
# Has support for a `Cache` instance as `past_key_values`
|
||||||
|
_supports_cache_class = False
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def dummy_inputs(self) -> Dict[str, torch.Tensor]:
|
def dummy_inputs(self) -> Dict[str, torch.Tensor]:
|
||||||
"""
|
"""
|
||||||
|
|||||||
@@ -860,7 +860,6 @@ class OpenLlamaForCausalLM(OpenLlamaPreTrainedModel):
|
|||||||
""",
|
""",
|
||||||
OPEN_LLAMA_START_DOCSTRING,
|
OPEN_LLAMA_START_DOCSTRING,
|
||||||
)
|
)
|
||||||
# Copied from transformers.models.llama.modeling_llama.LlamaForSequenceClassification with LLAMA->OPEN_LLAMA,Llama->OpenLlama
|
|
||||||
class OpenLlamaForSequenceClassification(OpenLlamaPreTrainedModel):
|
class OpenLlamaForSequenceClassification(OpenLlamaPreTrainedModel):
|
||||||
def __init__(self, config):
|
def __init__(self, config):
|
||||||
super().__init__(config)
|
super().__init__(config)
|
||||||
|
|||||||
@@ -29,6 +29,7 @@ from torch import nn
|
|||||||
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
|
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
|
||||||
|
|
||||||
from ...activations import ACT2FN
|
from ...activations import ACT2FN
|
||||||
|
from ...cache_utils import Cache, DynamicCache
|
||||||
from ...modeling_attn_mask_utils import (
|
from ...modeling_attn_mask_utils import (
|
||||||
AttentionMaskConverter,
|
AttentionMaskConverter,
|
||||||
_prepare_4d_attention_mask,
|
_prepare_4d_attention_mask,
|
||||||
@@ -283,9 +284,17 @@ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
|
|||||||
class LlamaAttention(nn.Module):
|
class LlamaAttention(nn.Module):
|
||||||
"""Multi-headed attention from 'Attention Is All You Need' paper"""
|
"""Multi-headed attention from 'Attention Is All You Need' paper"""
|
||||||
|
|
||||||
def __init__(self, config: LlamaConfig):
|
def __init__(self, config: LlamaConfig, layer_idx: Optional[int] = None):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.config = config
|
self.config = config
|
||||||
|
self.layer_idx = layer_idx
|
||||||
|
if layer_idx is None:
|
||||||
|
logger.warning_once(
|
||||||
|
f"Instantiating {self.__class__.__name__} without passing `layer_idx` is not recommended and will "
|
||||||
|
"to errors during the forward call, if caching is used. Please make sure to provide a `layer_idx` "
|
||||||
|
"when creating this class."
|
||||||
|
)
|
||||||
|
|
||||||
self.attention_dropout = config.attention_dropout
|
self.attention_dropout = config.attention_dropout
|
||||||
self.hidden_size = config.hidden_size
|
self.hidden_size = config.hidden_size
|
||||||
self.num_heads = config.num_attention_heads
|
self.num_heads = config.num_attention_heads
|
||||||
@@ -343,7 +352,7 @@ class LlamaAttention(nn.Module):
|
|||||||
hidden_states: torch.Tensor,
|
hidden_states: torch.Tensor,
|
||||||
attention_mask: Optional[torch.Tensor] = None,
|
attention_mask: Optional[torch.Tensor] = None,
|
||||||
position_ids: Optional[torch.LongTensor] = None,
|
position_ids: Optional[torch.LongTensor] = None,
|
||||||
past_key_value: Optional[Tuple[torch.Tensor]] = None,
|
past_key_value: Optional[Cache] = None,
|
||||||
output_attentions: bool = False,
|
output_attentions: bool = False,
|
||||||
use_cache: bool = False,
|
use_cache: bool = False,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
@@ -383,16 +392,19 @@ class LlamaAttention(nn.Module):
|
|||||||
|
|
||||||
kv_seq_len = key_states.shape[-2]
|
kv_seq_len = key_states.shape[-2]
|
||||||
if past_key_value is not None:
|
if past_key_value is not None:
|
||||||
kv_seq_len += past_key_value[0].shape[-2]
|
if self.layer_idx is None:
|
||||||
|
raise ValueError(
|
||||||
|
f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} "
|
||||||
|
"for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
|
||||||
|
"with a layer index."
|
||||||
|
)
|
||||||
|
kv_seq_len += past_key_value.get_seq_length(self.layer_idx)
|
||||||
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
|
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
|
||||||
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
|
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
|
||||||
|
|
||||||
if past_key_value is not None:
|
if past_key_value is not None:
|
||||||
# reuse k, v, self_attention
|
cache_kwargs = {"sin": sin, "cos": cos} # Specific to RoPE models
|
||||||
key_states = torch.cat([past_key_value[0], key_states], dim=2)
|
key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
|
||||||
value_states = torch.cat([past_key_value[1], value_states], dim=2)
|
|
||||||
|
|
||||||
past_key_value = (key_states, value_states) if use_cache else None
|
|
||||||
|
|
||||||
key_states = repeat_kv(key_states, self.num_key_value_groups)
|
key_states = repeat_kv(key_states, self.num_key_value_groups)
|
||||||
value_states = repeat_kv(value_states, self.num_key_value_groups)
|
value_states = repeat_kv(value_states, self.num_key_value_groups)
|
||||||
@@ -460,7 +472,7 @@ class LlamaFlashAttention2(LlamaAttention):
|
|||||||
hidden_states: torch.Tensor,
|
hidden_states: torch.Tensor,
|
||||||
attention_mask: Optional[torch.LongTensor] = None,
|
attention_mask: Optional[torch.LongTensor] = None,
|
||||||
position_ids: Optional[torch.LongTensor] = None,
|
position_ids: Optional[torch.LongTensor] = None,
|
||||||
past_key_value: Optional[Tuple[torch.Tensor]] = None,
|
past_key_value: Optional[Cache] = None,
|
||||||
output_attentions: bool = False,
|
output_attentions: bool = False,
|
||||||
use_cache: bool = False,
|
use_cache: bool = False,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
@@ -491,18 +503,14 @@ class LlamaFlashAttention2(LlamaAttention):
|
|||||||
|
|
||||||
kv_seq_len = key_states.shape[-2]
|
kv_seq_len = key_states.shape[-2]
|
||||||
if past_key_value is not None:
|
if past_key_value is not None:
|
||||||
kv_seq_len += past_key_value[0].shape[-2]
|
kv_seq_len += past_key_value.get_seq_length(self.layer_idx)
|
||||||
|
|
||||||
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
|
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
|
||||||
|
|
||||||
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
|
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
|
||||||
|
|
||||||
if past_key_value is not None:
|
if past_key_value is not None:
|
||||||
# reuse k, v, self_attention
|
cache_kwargs = {"sin": sin, "cos": cos} # Specific to RoPE models
|
||||||
key_states = torch.cat([past_key_value[0], key_states], dim=2)
|
key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
|
||||||
value_states = torch.cat([past_key_value[1], value_states], dim=2)
|
|
||||||
|
|
||||||
past_key_value = (key_states, value_states) if use_cache else None
|
|
||||||
|
|
||||||
# TODO: These transpose are quite inefficient but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache
|
# TODO: These transpose are quite inefficient but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache
|
||||||
# to be able to avoid many of these transpose/reshape/view.
|
# to be able to avoid many of these transpose/reshape/view.
|
||||||
@@ -647,13 +655,13 @@ class LlamaFlashAttention2(LlamaAttention):
|
|||||||
|
|
||||||
|
|
||||||
class LlamaDecoderLayer(nn.Module):
|
class LlamaDecoderLayer(nn.Module):
|
||||||
def __init__(self, config: LlamaConfig):
|
def __init__(self, config: LlamaConfig, layer_idx: int):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.hidden_size = config.hidden_size
|
self.hidden_size = config.hidden_size
|
||||||
self.self_attn = (
|
self.self_attn = (
|
||||||
LlamaAttention(config=config)
|
LlamaAttention(config=config, layer_idx=layer_idx)
|
||||||
if not getattr(config, "_flash_attn_2_enabled", False)
|
if not getattr(config, "_flash_attn_2_enabled", False)
|
||||||
else LlamaFlashAttention2(config=config)
|
else LlamaFlashAttention2(config=config, layer_idx=layer_idx)
|
||||||
)
|
)
|
||||||
self.mlp = LlamaMLP(config)
|
self.mlp = LlamaMLP(config)
|
||||||
self.input_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
self.input_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||||
@@ -749,6 +757,7 @@ class LlamaPreTrainedModel(PreTrainedModel):
|
|||||||
_no_split_modules = ["LlamaDecoderLayer"]
|
_no_split_modules = ["LlamaDecoderLayer"]
|
||||||
_skip_keys_device_placement = "past_key_values"
|
_skip_keys_device_placement = "past_key_values"
|
||||||
_supports_flash_attn_2 = True
|
_supports_flash_attn_2 = True
|
||||||
|
_supports_cache_class = True
|
||||||
|
|
||||||
def _init_weights(self, module):
|
def _init_weights(self, module):
|
||||||
std = self.config.initializer_range
|
std = self.config.initializer_range
|
||||||
@@ -797,13 +806,19 @@ LLAMA_INPUTS_DOCSTRING = r"""
|
|||||||
config.n_positions - 1]`.
|
config.n_positions - 1]`.
|
||||||
|
|
||||||
[What are position IDs?](../glossary#position-ids)
|
[What are position IDs?](../glossary#position-ids)
|
||||||
past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
|
past_key_values (`Cache` or `tuple(tuple(torch.FloatTensor))`, *optional*):
|
||||||
Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
|
Pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
|
||||||
`(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of shape
|
blocks) that can be used to speed up sequential decoding. This typically consists in the `past_key_values`
|
||||||
`(batch_size, num_heads, decoder_sequence_length, embed_size_per_head)`.
|
returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`.
|
||||||
|
|
||||||
Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
|
Two formats are allowed:
|
||||||
blocks) that can be used (see `past_key_values` input) to speed up sequential decoding.
|
- a [`~cache_utils.Cache`] instance;
|
||||||
|
- Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of
|
||||||
|
shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`. This is also known as the legacy
|
||||||
|
cache format.
|
||||||
|
|
||||||
|
The model will output the same cache format that is fed as input. If no `past_key_values` are passed, the
|
||||||
|
legacy cache format will be returned.
|
||||||
|
|
||||||
If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that don't
|
If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that don't
|
||||||
have their past key value states given to this model) of shape `(batch_size, 1)` instead of all `input_ids`
|
have their past key value states given to this model) of shape `(batch_size, 1)` instead of all `input_ids`
|
||||||
@@ -844,7 +859,9 @@ class LlamaModel(LlamaPreTrainedModel):
|
|||||||
self.vocab_size = config.vocab_size
|
self.vocab_size = config.vocab_size
|
||||||
|
|
||||||
self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
|
self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
|
||||||
self.layers = nn.ModuleList([LlamaDecoderLayer(config) for _ in range(config.num_hidden_layers)])
|
self.layers = nn.ModuleList(
|
||||||
|
[LlamaDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
|
||||||
|
)
|
||||||
self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||||
|
|
||||||
self.gradient_checkpointing = False
|
self.gradient_checkpointing = False
|
||||||
@@ -889,8 +906,11 @@ class LlamaModel(LlamaPreTrainedModel):
|
|||||||
raise ValueError("You have to specify either input_ids or inputs_embeds")
|
raise ValueError("You have to specify either input_ids or inputs_embeds")
|
||||||
|
|
||||||
past_key_values_length = 0
|
past_key_values_length = 0
|
||||||
if past_key_values is not None:
|
if use_cache:
|
||||||
past_key_values_length = past_key_values[0][0].shape[2]
|
use_legacy_cache = not isinstance(past_key_values, Cache)
|
||||||
|
if use_legacy_cache:
|
||||||
|
past_key_values = DynamicCache.from_legacy_cache(past_key_values)
|
||||||
|
past_key_values_length = past_key_values.get_seq_length()
|
||||||
|
|
||||||
if position_ids is None:
|
if position_ids is None:
|
||||||
device = input_ids.device if input_ids is not None else inputs_embeds.device
|
device = input_ids.device if input_ids is not None else inputs_embeds.device
|
||||||
@@ -924,21 +944,19 @@ class LlamaModel(LlamaPreTrainedModel):
|
|||||||
# decoder layers
|
# decoder layers
|
||||||
all_hidden_states = () if output_hidden_states else None
|
all_hidden_states = () if output_hidden_states else None
|
||||||
all_self_attns = () if output_attentions else None
|
all_self_attns = () if output_attentions else None
|
||||||
next_decoder_cache = () if use_cache else None
|
next_decoder_cache = None
|
||||||
|
|
||||||
for idx, decoder_layer in enumerate(self.layers):
|
for decoder_layer in self.layers:
|
||||||
if output_hidden_states:
|
if output_hidden_states:
|
||||||
all_hidden_states += (hidden_states,)
|
all_hidden_states += (hidden_states,)
|
||||||
|
|
||||||
past_key_value = past_key_values[idx] if past_key_values is not None else None
|
|
||||||
|
|
||||||
if self.gradient_checkpointing and self.training:
|
if self.gradient_checkpointing and self.training:
|
||||||
layer_outputs = self._gradient_checkpointing_func(
|
layer_outputs = self._gradient_checkpointing_func(
|
||||||
decoder_layer.__call__,
|
decoder_layer.__call__,
|
||||||
hidden_states,
|
hidden_states,
|
||||||
attention_mask,
|
attention_mask,
|
||||||
position_ids,
|
position_ids,
|
||||||
past_key_value,
|
past_key_values,
|
||||||
output_attentions,
|
output_attentions,
|
||||||
use_cache,
|
use_cache,
|
||||||
)
|
)
|
||||||
@@ -947,7 +965,7 @@ class LlamaModel(LlamaPreTrainedModel):
|
|||||||
hidden_states,
|
hidden_states,
|
||||||
attention_mask=attention_mask,
|
attention_mask=attention_mask,
|
||||||
position_ids=position_ids,
|
position_ids=position_ids,
|
||||||
past_key_value=past_key_value,
|
past_key_value=past_key_values,
|
||||||
output_attentions=output_attentions,
|
output_attentions=output_attentions,
|
||||||
use_cache=use_cache,
|
use_cache=use_cache,
|
||||||
)
|
)
|
||||||
@@ -955,7 +973,7 @@ class LlamaModel(LlamaPreTrainedModel):
|
|||||||
hidden_states = layer_outputs[0]
|
hidden_states = layer_outputs[0]
|
||||||
|
|
||||||
if use_cache:
|
if use_cache:
|
||||||
next_decoder_cache += (layer_outputs[2 if output_attentions else 1],)
|
next_decoder_cache = layer_outputs[2 if output_attentions else 1]
|
||||||
|
|
||||||
if output_attentions:
|
if output_attentions:
|
||||||
all_self_attns += (layer_outputs[1],)
|
all_self_attns += (layer_outputs[1],)
|
||||||
@@ -966,7 +984,9 @@ class LlamaModel(LlamaPreTrainedModel):
|
|||||||
if output_hidden_states:
|
if output_hidden_states:
|
||||||
all_hidden_states += (hidden_states,)
|
all_hidden_states += (hidden_states,)
|
||||||
|
|
||||||
next_cache = next_decoder_cache if use_cache else None
|
next_cache = None
|
||||||
|
if use_cache:
|
||||||
|
next_cache = next_decoder_cache.to_legacy_cache() if use_legacy_cache else next_decoder_cache
|
||||||
if not return_dict:
|
if not return_dict:
|
||||||
return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
|
return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
|
||||||
return BaseModelOutputWithPast(
|
return BaseModelOutputWithPast(
|
||||||
@@ -1047,7 +1067,6 @@ class LlamaForCausalLM(LlamaPreTrainedModel):
|
|||||||
>>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
|
>>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
|
||||||
"Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
|
"Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
|
||||||
```"""
|
```"""
|
||||||
|
|
||||||
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
||||||
output_hidden_states = (
|
output_hidden_states = (
|
||||||
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
||||||
@@ -1105,16 +1124,28 @@ class LlamaForCausalLM(LlamaPreTrainedModel):
|
|||||||
self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs
|
self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs
|
||||||
):
|
):
|
||||||
if past_key_values is not None:
|
if past_key_values is not None:
|
||||||
past_length = past_key_values[0][0].shape[2]
|
if isinstance(past_key_values, Cache):
|
||||||
|
cache_length = past_key_values.get_seq_length()
|
||||||
# Some generation methods already pass only the last input ID
|
past_length = past_key_values.seen_tokens
|
||||||
if input_ids.shape[1] > past_length:
|
|
||||||
remove_prefix_length = past_length
|
|
||||||
else:
|
else:
|
||||||
# Default to old behavior: keep only final ID
|
cache_length = past_length = past_key_values[0][0].shape[2]
|
||||||
remove_prefix_length = input_ids.shape[1] - 1
|
|
||||||
|
|
||||||
input_ids = input_ids[:, remove_prefix_length:]
|
# Keep only the unprocessed tokens:
|
||||||
|
# 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where
|
||||||
|
# some of the inputs are exclusively passed as part of the cache (e.g. when passing inputs_embeds as
|
||||||
|
# input)
|
||||||
|
if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]:
|
||||||
|
input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :]
|
||||||
|
# 2 - If past_length is smaller than the length of input_ids, then input_ids holds all input tokens. We can discard
|
||||||
|
# input_ids based on the past_length.
|
||||||
|
elif past_length < input_ids.shape[1]:
|
||||||
|
input_ids = input_ids[:, past_length:]
|
||||||
|
# 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens.
|
||||||
|
|
||||||
|
# If the cache has seen more tokens than it can hold, then the cache has a size limit. Let's discard the
|
||||||
|
# older attention values, as their corresponding values are not part of the input.
|
||||||
|
if cache_length < past_length and attention_mask is not None:
|
||||||
|
attention_mask = attention_mask[:, -(cache_length + input_ids.shape[1]) :]
|
||||||
|
|
||||||
position_ids = kwargs.get("position_ids", None)
|
position_ids = kwargs.get("position_ids", None)
|
||||||
if attention_mask is not None and position_ids is None:
|
if attention_mask is not None and position_ids is None:
|
||||||
|
|||||||
@@ -30,6 +30,7 @@ from torch import nn
|
|||||||
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
|
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
|
||||||
|
|
||||||
from ...activations import ACT2FN
|
from ...activations import ACT2FN
|
||||||
|
from ...cache_utils import Cache, DynamicCache
|
||||||
from ...modeling_attn_mask_utils import _prepare_4d_causal_attention_mask
|
from ...modeling_attn_mask_utils import _prepare_4d_causal_attention_mask
|
||||||
from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast, SequenceClassifierOutputWithPast
|
from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast, SequenceClassifierOutputWithPast
|
||||||
from ...modeling_utils import PreTrainedModel
|
from ...modeling_utils import PreTrainedModel
|
||||||
@@ -195,9 +196,17 @@ class MistralAttention(nn.Module):
|
|||||||
and "Generating Long Sequences with Sparse Transformers".
|
and "Generating Long Sequences with Sparse Transformers".
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, config: MistralConfig):
|
def __init__(self, config: MistralConfig, layer_idx: Optional[int] = None):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.config = config
|
self.config = config
|
||||||
|
self.layer_idx = layer_idx
|
||||||
|
if layer_idx is None:
|
||||||
|
logger.warning_once(
|
||||||
|
f"Instantiating {self.__class__.__name__} without passing `layer_idx` is not recommended and will "
|
||||||
|
"to errors during the forward call, if caching is used. Please make sure to provide a `layer_idx` "
|
||||||
|
"when creating this class."
|
||||||
|
)
|
||||||
|
|
||||||
self.hidden_size = config.hidden_size
|
self.hidden_size = config.hidden_size
|
||||||
self.num_heads = config.num_attention_heads
|
self.num_heads = config.num_attention_heads
|
||||||
self.head_dim = self.hidden_size // self.num_heads
|
self.head_dim = self.hidden_size // self.num_heads
|
||||||
@@ -232,7 +241,7 @@ class MistralAttention(nn.Module):
|
|||||||
hidden_states: torch.Tensor,
|
hidden_states: torch.Tensor,
|
||||||
attention_mask: Optional[torch.Tensor] = None,
|
attention_mask: Optional[torch.Tensor] = None,
|
||||||
position_ids: Optional[torch.LongTensor] = None,
|
position_ids: Optional[torch.LongTensor] = None,
|
||||||
past_key_value: Optional[Tuple[torch.Tensor]] = None,
|
past_key_value: Optional[Cache] = None,
|
||||||
output_attentions: bool = False,
|
output_attentions: bool = False,
|
||||||
use_cache: bool = False,
|
use_cache: bool = False,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
@@ -253,16 +262,19 @@ class MistralAttention(nn.Module):
|
|||||||
|
|
||||||
kv_seq_len = key_states.shape[-2]
|
kv_seq_len = key_states.shape[-2]
|
||||||
if past_key_value is not None:
|
if past_key_value is not None:
|
||||||
kv_seq_len += past_key_value[0].shape[-2]
|
if self.layer_idx is None:
|
||||||
|
raise ValueError(
|
||||||
|
f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} "
|
||||||
|
"for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
|
||||||
|
"with a layer index."
|
||||||
|
)
|
||||||
|
kv_seq_len += past_key_value.get_seq_length(self.layer_idx)
|
||||||
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
|
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
|
||||||
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
|
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
|
||||||
|
|
||||||
if past_key_value is not None:
|
if past_key_value is not None:
|
||||||
# reuse k, v, self_attention
|
cache_kwargs = {"sin": sin, "cos": cos} # Specific to RoPE models
|
||||||
key_states = torch.cat([past_key_value[0], key_states], dim=2)
|
key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
|
||||||
value_states = torch.cat([past_key_value[1], value_states], dim=2)
|
|
||||||
|
|
||||||
past_key_value = (key_states, value_states) if use_cache else None
|
|
||||||
|
|
||||||
# repeat k/v heads if n_kv_heads < n_heads
|
# repeat k/v heads if n_kv_heads < n_heads
|
||||||
key_states = repeat_kv(key_states, self.num_key_value_groups)
|
key_states = repeat_kv(key_states, self.num_key_value_groups)
|
||||||
@@ -327,7 +339,7 @@ class MistralFlashAttention2(MistralAttention):
|
|||||||
hidden_states: torch.Tensor,
|
hidden_states: torch.Tensor,
|
||||||
attention_mask: Optional[torch.Tensor] = None,
|
attention_mask: Optional[torch.Tensor] = None,
|
||||||
position_ids: Optional[torch.LongTensor] = None,
|
position_ids: Optional[torch.LongTensor] = None,
|
||||||
past_key_value: Optional[Tuple[torch.Tensor]] = None,
|
past_key_value: Optional[Cache] = None,
|
||||||
output_attentions: bool = False,
|
output_attentions: bool = False,
|
||||||
use_cache: bool = False,
|
use_cache: bool = False,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
@@ -351,7 +363,7 @@ class MistralFlashAttention2(MistralAttention):
|
|||||||
|
|
||||||
kv_seq_len = key_states.shape[-2]
|
kv_seq_len = key_states.shape[-2]
|
||||||
if past_key_value is not None:
|
if past_key_value is not None:
|
||||||
kv_seq_len += past_key_value[0].shape[-2]
|
kv_seq_len += past_key_value.get_seq_length(self.layer_idx)
|
||||||
|
|
||||||
# Because the input can be padded, the absolute sequence length depends on the max position id.
|
# Because the input can be padded, the absolute sequence length depends on the max position id.
|
||||||
rotary_seq_len = max(kv_seq_len, position_ids[:, -1].max().item()) + 1
|
rotary_seq_len = max(kv_seq_len, position_ids[:, -1].max().item()) + 1
|
||||||
@@ -394,10 +406,8 @@ class MistralFlashAttention2(MistralAttention):
|
|||||||
attention_mask = attention_mask[:, slicing_tokens:]
|
attention_mask = attention_mask[:, slicing_tokens:]
|
||||||
attention_mask = torch.cat([attention_mask, torch.ones_like(attention_mask[:, -1:])], dim=-1)
|
attention_mask = torch.cat([attention_mask, torch.ones_like(attention_mask[:, -1:])], dim=-1)
|
||||||
|
|
||||||
key_states = torch.cat([past_key_value[0], key_states], dim=2)
|
cache_kwargs = {"sin": sin, "cos": cos} # Specific to RoPE models
|
||||||
value_states = torch.cat([past_key_value[1], value_states], dim=2)
|
key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
|
||||||
|
|
||||||
past_key_value = (key_states, value_states) if use_cache else None
|
|
||||||
|
|
||||||
# repeat k/v heads if n_kv_heads < n_heads
|
# repeat k/v heads if n_kv_heads < n_heads
|
||||||
key_states = repeat_kv(key_states, self.num_key_value_groups)
|
key_states = repeat_kv(key_states, self.num_key_value_groups)
|
||||||
@@ -592,13 +602,13 @@ class MistralFlashAttention2(MistralAttention):
|
|||||||
|
|
||||||
|
|
||||||
class MistralDecoderLayer(nn.Module):
|
class MistralDecoderLayer(nn.Module):
|
||||||
def __init__(self, config: MistralConfig):
|
def __init__(self, config: MistralConfig, layer_idx: int):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.hidden_size = config.hidden_size
|
self.hidden_size = config.hidden_size
|
||||||
self.self_attn = (
|
self.self_attn = (
|
||||||
MistralAttention(config=config)
|
MistralAttention(config=config, layer_idx=layer_idx)
|
||||||
if not getattr(config, "_flash_attn_2_enabled", False)
|
if not getattr(config, "_flash_attn_2_enabled", False)
|
||||||
else MistralFlashAttention2(config)
|
else MistralFlashAttention2(config, layer_idx=layer_idx)
|
||||||
)
|
)
|
||||||
self.mlp = MistralMLP(config)
|
self.mlp = MistralMLP(config)
|
||||||
self.input_layernorm = MistralRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
self.input_layernorm = MistralRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||||
@@ -692,6 +702,7 @@ class MistralPreTrainedModel(PreTrainedModel):
|
|||||||
_no_split_modules = ["MistralDecoderLayer"]
|
_no_split_modules = ["MistralDecoderLayer"]
|
||||||
_skip_keys_device_placement = "past_key_values"
|
_skip_keys_device_placement = "past_key_values"
|
||||||
_supports_flash_attn_2 = True
|
_supports_flash_attn_2 = True
|
||||||
|
_supports_cache_class = True
|
||||||
|
|
||||||
def _init_weights(self, module):
|
def _init_weights(self, module):
|
||||||
std = self.config.initializer_range
|
std = self.config.initializer_range
|
||||||
@@ -740,17 +751,23 @@ MISTRAL_INPUTS_DOCSTRING = r"""
|
|||||||
config.n_positions - 1]`.
|
config.n_positions - 1]`.
|
||||||
|
|
||||||
[What are position IDs?](../glossary#position-ids)
|
[What are position IDs?](../glossary#position-ids)
|
||||||
past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
|
past_key_values (`Cache` or `tuple(tuple(torch.FloatTensor))`, *optional*):
|
||||||
Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
|
Pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
|
||||||
`(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of shape
|
blocks) that can be used to speed up sequential decoding. This typically consists in the `past_key_values`
|
||||||
`(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`.
|
returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`.
|
||||||
|
|
||||||
Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
|
Two formats are allowed:
|
||||||
blocks) that can be used (see `past_key_values` input) to speed up sequential decoding.
|
- a [`~cache_utils.Cache`] instance;
|
||||||
|
- Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of
|
||||||
|
shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`. This is also known as the legacy
|
||||||
|
cache format.
|
||||||
|
|
||||||
If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that
|
The model will output the same cache format that is fed as input. If no `past_key_values` are passed, the
|
||||||
don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all
|
legacy cache format will be returned.
|
||||||
`decoder_input_ids` of shape `(batch_size, sequence_length)`.
|
|
||||||
|
If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that don't
|
||||||
|
have their past key value states given to this model) of shape `(batch_size, 1)` instead of all `input_ids`
|
||||||
|
of shape `(batch_size, sequence_length)`.
|
||||||
inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
|
inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
|
||||||
Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
|
Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
|
||||||
is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
|
is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
|
||||||
@@ -787,7 +804,9 @@ class MistralModel(MistralPreTrainedModel):
|
|||||||
self.vocab_size = config.vocab_size
|
self.vocab_size = config.vocab_size
|
||||||
|
|
||||||
self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
|
self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
|
||||||
self.layers = nn.ModuleList([MistralDecoderLayer(config) for _ in range(config.num_hidden_layers)])
|
self.layers = nn.ModuleList(
|
||||||
|
[MistralDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
|
||||||
|
)
|
||||||
self.norm = MistralRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
self.norm = MistralRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||||
|
|
||||||
self.gradient_checkpointing = False
|
self.gradient_checkpointing = False
|
||||||
@@ -834,8 +853,11 @@ class MistralModel(MistralPreTrainedModel):
|
|||||||
seq_length_with_past = seq_length
|
seq_length_with_past = seq_length
|
||||||
past_key_values_length = 0
|
past_key_values_length = 0
|
||||||
|
|
||||||
if past_key_values is not None:
|
if use_cache:
|
||||||
past_key_values_length = past_key_values[0][0].shape[2]
|
use_legacy_cache = not isinstance(past_key_values, Cache)
|
||||||
|
if use_legacy_cache:
|
||||||
|
past_key_values = DynamicCache.from_legacy_cache(past_key_values)
|
||||||
|
past_key_values_length = past_key_values.get_seq_length()
|
||||||
seq_length_with_past = seq_length_with_past + past_key_values_length
|
seq_length_with_past = seq_length_with_past + past_key_values_length
|
||||||
|
|
||||||
if position_ids is None:
|
if position_ids is None:
|
||||||
@@ -889,21 +911,19 @@ class MistralModel(MistralPreTrainedModel):
|
|||||||
# decoder layers
|
# decoder layers
|
||||||
all_hidden_states = () if output_hidden_states else None
|
all_hidden_states = () if output_hidden_states else None
|
||||||
all_self_attns = () if output_attentions else None
|
all_self_attns = () if output_attentions else None
|
||||||
next_decoder_cache = () if use_cache else None
|
next_decoder_cache = None
|
||||||
|
|
||||||
for idx, decoder_layer in enumerate(self.layers):
|
for decoder_layer in self.layers:
|
||||||
if output_hidden_states:
|
if output_hidden_states:
|
||||||
all_hidden_states += (hidden_states,)
|
all_hidden_states += (hidden_states,)
|
||||||
|
|
||||||
past_key_value = past_key_values[idx] if past_key_values is not None else None
|
|
||||||
|
|
||||||
if self.gradient_checkpointing and self.training:
|
if self.gradient_checkpointing and self.training:
|
||||||
layer_outputs = self._gradient_checkpointing_func(
|
layer_outputs = self._gradient_checkpointing_func(
|
||||||
decoder_layer.__call__,
|
decoder_layer.__call__,
|
||||||
hidden_states,
|
hidden_states,
|
||||||
attention_mask,
|
attention_mask,
|
||||||
position_ids,
|
position_ids,
|
||||||
past_key_value,
|
past_key_values,
|
||||||
output_attentions,
|
output_attentions,
|
||||||
use_cache,
|
use_cache,
|
||||||
)
|
)
|
||||||
@@ -912,7 +932,7 @@ class MistralModel(MistralPreTrainedModel):
|
|||||||
hidden_states,
|
hidden_states,
|
||||||
attention_mask=attention_mask,
|
attention_mask=attention_mask,
|
||||||
position_ids=position_ids,
|
position_ids=position_ids,
|
||||||
past_key_value=past_key_value,
|
past_key_value=past_key_values,
|
||||||
output_attentions=output_attentions,
|
output_attentions=output_attentions,
|
||||||
use_cache=use_cache,
|
use_cache=use_cache,
|
||||||
)
|
)
|
||||||
@@ -920,7 +940,7 @@ class MistralModel(MistralPreTrainedModel):
|
|||||||
hidden_states = layer_outputs[0]
|
hidden_states = layer_outputs[0]
|
||||||
|
|
||||||
if use_cache:
|
if use_cache:
|
||||||
next_decoder_cache += (layer_outputs[2 if output_attentions else 1],)
|
next_decoder_cache = layer_outputs[2 if output_attentions else 1]
|
||||||
|
|
||||||
if output_attentions:
|
if output_attentions:
|
||||||
all_self_attns += (layer_outputs[1],)
|
all_self_attns += (layer_outputs[1],)
|
||||||
@@ -931,7 +951,10 @@ class MistralModel(MistralPreTrainedModel):
|
|||||||
if output_hidden_states:
|
if output_hidden_states:
|
||||||
all_hidden_states += (hidden_states,)
|
all_hidden_states += (hidden_states,)
|
||||||
|
|
||||||
next_cache = next_decoder_cache if use_cache else None
|
next_cache = None
|
||||||
|
if use_cache:
|
||||||
|
next_cache = next_decoder_cache.to_legacy_cache() if use_legacy_cache else next_decoder_cache
|
||||||
|
|
||||||
if not return_dict:
|
if not return_dict:
|
||||||
return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
|
return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
|
||||||
return BaseModelOutputWithPast(
|
return BaseModelOutputWithPast(
|
||||||
@@ -1065,17 +1088,29 @@ class MistralForCausalLM(MistralPreTrainedModel):
|
|||||||
self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs
|
self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs
|
||||||
):
|
):
|
||||||
# Omit tokens covered by past_key_values
|
# Omit tokens covered by past_key_values
|
||||||
if past_key_values:
|
if past_key_values is not None:
|
||||||
past_length = past_key_values[0][0].shape[2]
|
if isinstance(past_key_values, Cache):
|
||||||
|
cache_length = past_key_values.get_seq_length()
|
||||||
# Some generation methods already pass only the last input ID
|
past_length = past_key_values.seen_tokens
|
||||||
if input_ids.shape[1] > past_length:
|
|
||||||
remove_prefix_length = past_length
|
|
||||||
else:
|
else:
|
||||||
# Default to old behavior: keep only final ID
|
cache_length = past_length = past_key_values[0][0].shape[2]
|
||||||
remove_prefix_length = input_ids.shape[1] - 1
|
|
||||||
|
|
||||||
input_ids = input_ids[:, remove_prefix_length:]
|
# Keep only the unprocessed tokens:
|
||||||
|
# 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where
|
||||||
|
# some of the inputs are exclusively passed as part of the cache (e.g. when passing inputs_embeds as
|
||||||
|
# input)
|
||||||
|
if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]:
|
||||||
|
input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :]
|
||||||
|
# 2 - If past_length is smaller than the length of input_ids, then input_ids holds all input tokens. We can discard
|
||||||
|
# input_ids based on the past_length.
|
||||||
|
elif past_length < input_ids.shape[1]:
|
||||||
|
input_ids = input_ids[:, past_length:]
|
||||||
|
# 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens.
|
||||||
|
|
||||||
|
# If the cache has seen more tokens than it can hold, then the cache has a size limit. Let's discard the
|
||||||
|
# older attention values, as their corresponding values are not part of the input.
|
||||||
|
if cache_length < past_length and attention_mask is not None:
|
||||||
|
attention_mask = attention_mask[:, -(cache_length + input_ids.shape[1]) :]
|
||||||
|
|
||||||
position_ids = kwargs.get("position_ids", None)
|
position_ids = kwargs.get("position_ids", None)
|
||||||
if attention_mask is not None and position_ids is None:
|
if attention_mask is not None and position_ids is None:
|
||||||
|
|||||||
@@ -27,6 +27,7 @@ from torch import nn
|
|||||||
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
|
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
|
||||||
|
|
||||||
from ...activations import ACT2FN
|
from ...activations import ACT2FN
|
||||||
|
from ...cache_utils import Cache, DynamicCache
|
||||||
from ...modeling_attn_mask_utils import _prepare_4d_causal_attention_mask
|
from ...modeling_attn_mask_utils import _prepare_4d_causal_attention_mask
|
||||||
from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast, SequenceClassifierOutputWithPast
|
from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast, SequenceClassifierOutputWithPast
|
||||||
from ...modeling_utils import PreTrainedModel
|
from ...modeling_utils import PreTrainedModel
|
||||||
@@ -178,9 +179,17 @@ class PersimmonMLP(nn.Module):
|
|||||||
class PersimmonAttention(nn.Module):
|
class PersimmonAttention(nn.Module):
|
||||||
"""Multi-headed attention from 'Attention Is All You Need' paper"""
|
"""Multi-headed attention from 'Attention Is All You Need' paper"""
|
||||||
|
|
||||||
def __init__(self, config: PersimmonConfig):
|
def __init__(self, config: PersimmonConfig, layer_idx: Optional[int] = None):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.config = config
|
self.config = config
|
||||||
|
self.layer_idx = layer_idx
|
||||||
|
if layer_idx is None:
|
||||||
|
logger.warning_once(
|
||||||
|
f"Instantiating {self.__class__.__name__} without passing `layer_idx` is not recommended and will "
|
||||||
|
"to errors during the forward call, if caching is used. Please make sure to provide a `layer_idx` "
|
||||||
|
"when creating this class."
|
||||||
|
)
|
||||||
|
|
||||||
self.hidden_size = config.hidden_size
|
self.hidden_size = config.hidden_size
|
||||||
self.num_heads = config.num_attention_heads
|
self.num_heads = config.num_attention_heads
|
||||||
self.head_dim = self.hidden_size // self.num_heads
|
self.head_dim = self.hidden_size // self.num_heads
|
||||||
@@ -257,7 +266,7 @@ class PersimmonAttention(nn.Module):
|
|||||||
hidden_states: torch.Tensor,
|
hidden_states: torch.Tensor,
|
||||||
attention_mask: Optional[torch.Tensor] = None,
|
attention_mask: Optional[torch.Tensor] = None,
|
||||||
position_ids: Optional[torch.LongTensor] = None,
|
position_ids: Optional[torch.LongTensor] = None,
|
||||||
past_key_value: Optional[Tuple[torch.Tensor]] = None,
|
past_key_value: Optional[Cache] = None,
|
||||||
output_attentions: bool = False,
|
output_attentions: bool = False,
|
||||||
use_cache: bool = False,
|
use_cache: bool = False,
|
||||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
||||||
@@ -280,7 +289,13 @@ class PersimmonAttention(nn.Module):
|
|||||||
|
|
||||||
kv_seq_len = key_states.shape[-2]
|
kv_seq_len = key_states.shape[-2]
|
||||||
if past_key_value is not None:
|
if past_key_value is not None:
|
||||||
kv_seq_len += past_key_value[0].shape[-2]
|
if self.layer_idx is None:
|
||||||
|
raise ValueError(
|
||||||
|
f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} "
|
||||||
|
"for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
|
||||||
|
"with a layer index."
|
||||||
|
)
|
||||||
|
kv_seq_len += past_key_value.get_seq_length(self.layer_idx)
|
||||||
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
|
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
|
||||||
|
|
||||||
# Partial rotary embedding
|
# Partial rotary embedding
|
||||||
@@ -300,11 +315,9 @@ class PersimmonAttention(nn.Module):
|
|||||||
key_states = torch.cat((key_rot, key_pass), dim=-1)
|
key_states = torch.cat((key_rot, key_pass), dim=-1)
|
||||||
|
|
||||||
if past_key_value is not None:
|
if past_key_value is not None:
|
||||||
# reuse k, v, self_attention
|
# Specific to RoPE models with partial rotation
|
||||||
key_states = torch.cat([past_key_value[0], key_states], dim=2)
|
cache_kwargs = {"sin": sin, "cos": cos, "partial_rotation_size": self.rotary_emb.dim}
|
||||||
value_states = torch.cat([past_key_value[1], value_states], dim=2)
|
key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
|
||||||
|
|
||||||
past_key_value = (key_states, value_states) if use_cache else None
|
|
||||||
|
|
||||||
attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
|
attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
|
||||||
|
|
||||||
@@ -345,10 +358,10 @@ class PersimmonAttention(nn.Module):
|
|||||||
|
|
||||||
|
|
||||||
class PersimmonDecoderLayer(nn.Module):
|
class PersimmonDecoderLayer(nn.Module):
|
||||||
def __init__(self, config: PersimmonConfig):
|
def __init__(self, config: PersimmonConfig, layer_idx: int):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.hidden_size = config.hidden_size
|
self.hidden_size = config.hidden_size
|
||||||
self.self_attn = PersimmonAttention(config=config)
|
self.self_attn = PersimmonAttention(config=config, layer_idx=layer_idx)
|
||||||
self.mlp = PersimmonMLP(config)
|
self.mlp = PersimmonMLP(config)
|
||||||
self.input_layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
|
self.input_layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
|
||||||
self.post_attention_layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
|
self.post_attention_layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
|
||||||
@@ -444,6 +457,7 @@ class PersimmonPreTrainedModel(PreTrainedModel):
|
|||||||
supports_gradient_checkpointing = True
|
supports_gradient_checkpointing = True
|
||||||
_no_split_modules = ["PersimmonDecoderLayer"]
|
_no_split_modules = ["PersimmonDecoderLayer"]
|
||||||
_skip_keys_device_placement = "past_key_values"
|
_skip_keys_device_placement = "past_key_values"
|
||||||
|
_supports_cache_class = True
|
||||||
|
|
||||||
def _init_weights(self, module):
|
def _init_weights(self, module):
|
||||||
std = self.config.initializer_range
|
std = self.config.initializer_range
|
||||||
@@ -492,17 +506,23 @@ PERSIMMON_INPUTS_DOCSTRING = r"""
|
|||||||
config.n_positions - 1]`.
|
config.n_positions - 1]`.
|
||||||
|
|
||||||
[What are position IDs?](../glossary#position-ids)
|
[What are position IDs?](../glossary#position-ids)
|
||||||
past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
|
past_key_values (`Cache` or `tuple(tuple(torch.FloatTensor))`, *optional*):
|
||||||
Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
|
Pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
|
||||||
`(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of shape
|
blocks) that can be used to speed up sequential decoding. This typically consists of the `past_key_values`
|
||||||
`(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`.
|
returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`.
|
||||||
|
|
||||||
Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
|
Two formats are allowed:
|
||||||
blocks) that can be used (see `past_key_values` input) to speed up sequential decoding.
|
- a [`~cache_utils.Cache`] instance;
|
||||||
|
- Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of
|
||||||
|
shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`. This is also known as the legacy
|
||||||
|
cache format.
|
||||||
|
|
||||||
If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that
|
The model will output the same cache format that is fed as input. If no `past_key_values` are passed, the
|
||||||
don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all
|
legacy cache format will be returned.
|
||||||
`decoder_input_ids` of shape `(batch_size, sequence_length)`.
|
|
||||||
|
If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that don't
|
||||||
|
have their past key value states given to this model) of shape `(batch_size, 1)` instead of all `input_ids`
|
||||||
|
of shape `(batch_size, sequence_length)`.
|
||||||
inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
|
inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
|
||||||
Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
|
Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
|
||||||
is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
|
is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
|
||||||
@@ -539,7 +559,9 @@ class PersimmonModel(PersimmonPreTrainedModel):
|
|||||||
self.vocab_size = config.vocab_size
|
self.vocab_size = config.vocab_size
|
||||||
|
|
||||||
self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
|
self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
|
||||||
self.layers = nn.ModuleList([PersimmonDecoderLayer(config) for _ in range(config.num_hidden_layers)])
|
self.layers = nn.ModuleList(
|
||||||
|
[PersimmonDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
|
||||||
|
)
|
||||||
self.final_layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
|
self.final_layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
|
||||||
|
|
||||||
self.gradient_checkpointing = False
|
self.gradient_checkpointing = False
|
||||||
@@ -586,8 +608,11 @@ class PersimmonModel(PersimmonPreTrainedModel):
|
|||||||
seq_length_with_past = seq_length
|
seq_length_with_past = seq_length
|
||||||
past_key_values_length = 0
|
past_key_values_length = 0
|
||||||
|
|
||||||
if past_key_values is not None:
|
if use_cache:
|
||||||
past_key_values_length = past_key_values[0][0].shape[2]
|
use_legacy_cache = not isinstance(past_key_values, Cache)
|
||||||
|
if use_legacy_cache:
|
||||||
|
past_key_values = DynamicCache.from_legacy_cache(past_key_values)
|
||||||
|
past_key_values_length = past_key_values.get_seq_length()
|
||||||
seq_length_with_past = seq_length_with_past + past_key_values_length
|
seq_length_with_past = seq_length_with_past + past_key_values_length
|
||||||
|
|
||||||
if position_ids is None:
|
if position_ids is None:
|
||||||
@@ -620,21 +645,19 @@ class PersimmonModel(PersimmonPreTrainedModel):
|
|||||||
# decoder layers
|
# decoder layers
|
||||||
all_hidden_states = () if output_hidden_states else None
|
all_hidden_states = () if output_hidden_states else None
|
||||||
all_self_attns = () if output_attentions else None
|
all_self_attns = () if output_attentions else None
|
||||||
next_decoder_cache = () if use_cache else None
|
next_decoder_cache = None
|
||||||
|
|
||||||
for idx, decoder_layer in enumerate(self.layers):
|
for decoder_layer in self.layers:
|
||||||
if output_hidden_states:
|
if output_hidden_states:
|
||||||
all_hidden_states += (hidden_states,)
|
all_hidden_states += (hidden_states,)
|
||||||
|
|
||||||
past_key_value = past_key_values[idx] if past_key_values is not None else None
|
|
||||||
|
|
||||||
if self.gradient_checkpointing and self.training:
|
if self.gradient_checkpointing and self.training:
|
||||||
layer_outputs = self._gradient_checkpointing_func(
|
layer_outputs = self._gradient_checkpointing_func(
|
||||||
decoder_layer.__call__,
|
decoder_layer.__call__,
|
||||||
hidden_states,
|
hidden_states,
|
||||||
attention_mask,
|
attention_mask,
|
||||||
position_ids,
|
position_ids,
|
||||||
past_key_value,
|
past_key_values,
|
||||||
output_attentions,
|
output_attentions,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
@@ -642,7 +665,7 @@ class PersimmonModel(PersimmonPreTrainedModel):
|
|||||||
hidden_states,
|
hidden_states,
|
||||||
attention_mask=attention_mask,
|
attention_mask=attention_mask,
|
||||||
position_ids=position_ids,
|
position_ids=position_ids,
|
||||||
past_key_value=past_key_value,
|
past_key_value=past_key_values,
|
||||||
output_attentions=output_attentions,
|
output_attentions=output_attentions,
|
||||||
use_cache=use_cache,
|
use_cache=use_cache,
|
||||||
)
|
)
|
||||||
@@ -650,7 +673,7 @@ class PersimmonModel(PersimmonPreTrainedModel):
|
|||||||
hidden_states = layer_outputs[0]
|
hidden_states = layer_outputs[0]
|
||||||
|
|
||||||
if use_cache:
|
if use_cache:
|
||||||
next_decoder_cache += (layer_outputs[2 if output_attentions else 1],)
|
next_decoder_cache = layer_outputs[2 if output_attentions else 1]
|
||||||
|
|
||||||
if output_attentions:
|
if output_attentions:
|
||||||
all_self_attns += (layer_outputs[1],)
|
all_self_attns += (layer_outputs[1],)
|
||||||
@@ -661,7 +684,10 @@ class PersimmonModel(PersimmonPreTrainedModel):
|
|||||||
if output_hidden_states:
|
if output_hidden_states:
|
||||||
all_hidden_states += (hidden_states,)
|
all_hidden_states += (hidden_states,)
|
||||||
|
|
||||||
next_cache = next_decoder_cache if use_cache else None
|
next_cache = None
|
||||||
|
if use_cache:
|
||||||
|
next_cache = next_decoder_cache.to_legacy_cache() if use_legacy_cache else next_decoder_cache
|
||||||
|
|
||||||
if not return_dict:
|
if not return_dict:
|
||||||
return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
|
return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
|
||||||
return BaseModelOutputWithPast(
|
return BaseModelOutputWithPast(
|
||||||
@@ -802,16 +828,28 @@ class PersimmonForCausalLM(PersimmonPreTrainedModel):
|
|||||||
self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs
|
self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs
|
||||||
):
|
):
|
||||||
if past_key_values is not None:
|
if past_key_values is not None:
|
||||||
past_length = past_key_values[0][0].shape[2]
|
if isinstance(past_key_values, Cache):
|
||||||
|
cache_length = past_key_values.get_seq_length()
|
||||||
# Some generation methods already pass only the last input ID
|
past_length = past_key_values.seen_tokens
|
||||||
if input_ids.shape[1] > past_length:
|
|
||||||
remove_prefix_length = past_length
|
|
||||||
else:
|
else:
|
||||||
# Default to old behavior: keep only final ID
|
cache_length = past_length = past_key_values[0][0].shape[2]
|
||||||
remove_prefix_length = input_ids.shape[1] - 1
|
|
||||||
|
|
||||||
input_ids = input_ids[:, remove_prefix_length:]
|
# Keep only the unprocessed tokens:
|
||||||
|
# 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where
|
||||||
|
# some of the inputs are exclusively passed as part of the cache (e.g. when passing inputs_embeds as
|
||||||
|
# input)
|
||||||
|
if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]:
|
||||||
|
input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :]
|
||||||
|
# 2 - If past_length is smaller than the length of input_ids, then input_ids holds all input tokens. We can discard
|
||||||
|
# input_ids based on the past_length.
|
||||||
|
elif past_length < input_ids.shape[1]:
|
||||||
|
input_ids = input_ids[:, past_length:]
|
||||||
|
# 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens.
|
||||||
|
|
||||||
|
# If the cache has seen more tokens than it can hold, then the cache has a size limit. Let's discard the
|
||||||
|
# older attention values, as their corresponding values are not part of the input.
|
||||||
|
if cache_length < past_length and attention_mask is not None:
|
||||||
|
attention_mask = attention_mask[:, -(cache_length + input_ids.shape[1]) :]
|
||||||
|
|
||||||
position_ids = kwargs.get("position_ids", None)
|
position_ids = kwargs.get("position_ids", None)
|
||||||
if attention_mask is not None and position_ids is None:
|
if attention_mask is not None and position_ids is None:
|
||||||
|
|||||||
@@ -26,6 +26,7 @@ from torch import nn
|
|||||||
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
|
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
|
||||||
|
|
||||||
from ...activations import ACT2FN
|
from ...activations import ACT2FN
|
||||||
|
from ...cache_utils import Cache, DynamicCache
|
||||||
from ...modeling_attn_mask_utils import _prepare_4d_causal_attention_mask
|
from ...modeling_attn_mask_utils import _prepare_4d_causal_attention_mask
|
||||||
from ...modeling_outputs import (
|
from ...modeling_outputs import (
|
||||||
BaseModelOutputWithPast,
|
BaseModelOutputWithPast,
|
||||||
@@ -217,9 +218,17 @@ class PhiMLP(nn.Module):
|
|||||||
class PhiAttention(nn.Module):
|
class PhiAttention(nn.Module):
|
||||||
"""Multi-headed attention from 'Attention Is All You Need' paper"""
|
"""Multi-headed attention from 'Attention Is All You Need' paper"""
|
||||||
|
|
||||||
def __init__(self, config: PhiConfig):
|
def __init__(self, config: PhiConfig, layer_idx: Optional[int] = None):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.config = config
|
self.config = config
|
||||||
|
self.layer_idx = layer_idx
|
||||||
|
if layer_idx is None:
|
||||||
|
logger.warning_once(
|
||||||
|
f"Instantiating {self.__class__.__name__} without passing `layer_idx` is not recommended and will "
|
||||||
|
"to errors during the forward call, if caching is used. Please make sure to provide a `layer_idx` "
|
||||||
|
"when creating this class."
|
||||||
|
)
|
||||||
|
|
||||||
self.hidden_size = config.hidden_size
|
self.hidden_size = config.hidden_size
|
||||||
self.num_heads = config.num_attention_heads
|
self.num_heads = config.num_attention_heads
|
||||||
self.head_dim = self.hidden_size // self.num_heads
|
self.head_dim = self.hidden_size // self.num_heads
|
||||||
@@ -296,7 +305,7 @@ class PhiAttention(nn.Module):
|
|||||||
hidden_states: torch.Tensor,
|
hidden_states: torch.Tensor,
|
||||||
attention_mask: Optional[torch.Tensor] = None,
|
attention_mask: Optional[torch.Tensor] = None,
|
||||||
position_ids: Optional[torch.LongTensor] = None,
|
position_ids: Optional[torch.LongTensor] = None,
|
||||||
past_key_value: Optional[Tuple[torch.Tensor]] = None,
|
past_key_value: Optional[Cache] = None,
|
||||||
output_attentions: bool = False,
|
output_attentions: bool = False,
|
||||||
use_cache: bool = False,
|
use_cache: bool = False,
|
||||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
||||||
@@ -319,7 +328,13 @@ class PhiAttention(nn.Module):
|
|||||||
|
|
||||||
kv_seq_len = key_states.shape[-2]
|
kv_seq_len = key_states.shape[-2]
|
||||||
if past_key_value is not None:
|
if past_key_value is not None:
|
||||||
kv_seq_len += past_key_value[0].shape[-2]
|
if self.layer_idx is None:
|
||||||
|
raise ValueError(
|
||||||
|
f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} "
|
||||||
|
"for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
|
||||||
|
"with a layer index."
|
||||||
|
)
|
||||||
|
kv_seq_len += past_key_value.get_seq_length(self.layer_idx)
|
||||||
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
|
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
|
||||||
|
|
||||||
# Partial rotary embedding
|
# Partial rotary embedding
|
||||||
@@ -339,11 +354,9 @@ class PhiAttention(nn.Module):
|
|||||||
key_states = torch.cat((key_rot, key_pass), dim=-1)
|
key_states = torch.cat((key_rot, key_pass), dim=-1)
|
||||||
|
|
||||||
if past_key_value is not None:
|
if past_key_value is not None:
|
||||||
# reuse k, v, self_attention
|
# Specific to RoPE models with partial rotation
|
||||||
key_states = torch.cat([past_key_value[0], key_states], dim=2)
|
cache_kwargs = {"sin": sin, "cos": cos, "partial_rotation_size": self.rotary_emb.dim}
|
||||||
value_states = torch.cat([past_key_value[1], value_states], dim=2)
|
key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
|
||||||
|
|
||||||
past_key_value = (key_states, value_states) if use_cache else None
|
|
||||||
|
|
||||||
attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
|
attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
|
||||||
|
|
||||||
@@ -404,7 +417,7 @@ class PhiFlashAttention2(PhiAttention):
|
|||||||
hidden_states: torch.Tensor,
|
hidden_states: torch.Tensor,
|
||||||
attention_mask: Optional[torch.Tensor] = None,
|
attention_mask: Optional[torch.Tensor] = None,
|
||||||
position_ids: Optional[torch.LongTensor] = None,
|
position_ids: Optional[torch.LongTensor] = None,
|
||||||
past_key_value: Optional[Tuple[torch.Tensor]] = None,
|
past_key_value: Optional[Cache] = None,
|
||||||
output_attentions: bool = False,
|
output_attentions: bool = False,
|
||||||
use_cache: bool = False,
|
use_cache: bool = False,
|
||||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
||||||
@@ -431,7 +444,7 @@ class PhiFlashAttention2(PhiAttention):
|
|||||||
|
|
||||||
kv_seq_len = key_states.shape[-2]
|
kv_seq_len = key_states.shape[-2]
|
||||||
if past_key_value is not None:
|
if past_key_value is not None:
|
||||||
kv_seq_len += past_key_value[0].shape[-2]
|
kv_seq_len += past_key_value.get_seq_length(self.layer_idx)
|
||||||
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
|
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
|
||||||
|
|
||||||
# Partial rotary embedding
|
# Partial rotary embedding
|
||||||
@@ -451,11 +464,8 @@ class PhiFlashAttention2(PhiAttention):
|
|||||||
key_states = torch.cat((key_rot, key_pass), dim=-1)
|
key_states = torch.cat((key_rot, key_pass), dim=-1)
|
||||||
|
|
||||||
if past_key_value is not None:
|
if past_key_value is not None:
|
||||||
# reuse k, v, self_attention
|
cache_kwargs = {"sin": sin, "cos": cos, "partial_rotation_size": self.rotary_emb.dim}
|
||||||
key_states = torch.cat([past_key_value[0], key_states], dim=2)
|
key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
|
||||||
value_states = torch.cat([past_key_value[1], value_states], dim=2)
|
|
||||||
|
|
||||||
past_key_value = (key_states, value_states) if use_cache else None
|
|
||||||
|
|
||||||
tgt_len = key_states.shape[2]
|
tgt_len = key_states.shape[2]
|
||||||
|
|
||||||
@@ -603,12 +613,12 @@ class PhiFlashAttention2(PhiAttention):
|
|||||||
|
|
||||||
|
|
||||||
class PhiDecoderLayer(nn.Module):
|
class PhiDecoderLayer(nn.Module):
|
||||||
def __init__(self, config: PhiConfig):
|
def __init__(self, config: PhiConfig, layer_idx: int):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.self_attn = (
|
self.self_attn = (
|
||||||
PhiAttention(config=config)
|
PhiAttention(config=config, layer_idx=layer_idx)
|
||||||
if not getattr(config, "_flash_attn_2_enabled", False)
|
if not getattr(config, "_flash_attn_2_enabled", False)
|
||||||
else PhiFlashAttention2(config=config)
|
else PhiFlashAttention2(config=config, layer_idx=layer_idx)
|
||||||
)
|
)
|
||||||
self.mlp = PhiMLP(config)
|
self.mlp = PhiMLP(config)
|
||||||
self.input_layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
|
self.input_layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
|
||||||
@@ -696,6 +706,7 @@ class PhiPreTrainedModel(PreTrainedModel):
|
|||||||
supports_gradient_checkpointing = True
|
supports_gradient_checkpointing = True
|
||||||
_skip_keys_device_placement = "past_key_values"
|
_skip_keys_device_placement = "past_key_values"
|
||||||
_supports_flash_attn_2 = True
|
_supports_flash_attn_2 = True
|
||||||
|
_supports_cache_class = True
|
||||||
|
|
||||||
def _init_weights(self, module):
|
def _init_weights(self, module):
|
||||||
std = self.config.initializer_range
|
std = self.config.initializer_range
|
||||||
@@ -744,17 +755,23 @@ PHI_INPUTS_DOCSTRING = r"""
|
|||||||
config.n_positions - 1]`.
|
config.n_positions - 1]`.
|
||||||
|
|
||||||
[What are position IDs?](../glossary#position-ids)
|
[What are position IDs?](../glossary#position-ids)
|
||||||
past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
|
past_key_values (`Cache` or `tuple(tuple(torch.FloatTensor))`, *optional*):
|
||||||
Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
|
Pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
|
||||||
`(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of shape
|
blocks) that can be used to speed up sequential decoding. This typically consists of the `past_key_values`
|
||||||
`(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`.
|
returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`.
|
||||||
|
|
||||||
Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
|
Two formats are allowed:
|
||||||
blocks) that can be used (see `past_key_values` input) to speed up sequential decoding.
|
- a [`~cache_utils.Cache`] instance;
|
||||||
|
- Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of
|
||||||
|
shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`. This is also known as the legacy
|
||||||
|
cache format.
|
||||||
|
|
||||||
If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that
|
The model will output the same cache format that is fed as input. If no `past_key_values` are passed, the
|
||||||
don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all
|
legacy cache format will be returned.
|
||||||
`decoder_input_ids` of shape `(batch_size, sequence_length)`.
|
|
||||||
|
If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that don't
|
||||||
|
have their past key value states given to this model) of shape `(batch_size, 1)` instead of all `input_ids`
|
||||||
|
of shape `(batch_size, sequence_length)`.
|
||||||
inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
|
inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
|
||||||
Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
|
Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
|
||||||
is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
|
is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
|
||||||
@@ -792,7 +809,9 @@ class PhiModel(PhiPreTrainedModel):
|
|||||||
|
|
||||||
self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
|
self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
|
||||||
self.embed_dropout = nn.Dropout(config.embd_pdrop)
|
self.embed_dropout = nn.Dropout(config.embd_pdrop)
|
||||||
self.layers = nn.ModuleList([PhiDecoderLayer(config) for _ in range(config.num_hidden_layers)])
|
self.layers = nn.ModuleList(
|
||||||
|
[PhiDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
|
||||||
|
)
|
||||||
self.final_layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
|
self.final_layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
|
||||||
|
|
||||||
self.gradient_checkpointing = False
|
self.gradient_checkpointing = False
|
||||||
@@ -839,8 +858,11 @@ class PhiModel(PhiPreTrainedModel):
|
|||||||
seq_length_with_past = seq_length
|
seq_length_with_past = seq_length
|
||||||
past_key_values_length = 0
|
past_key_values_length = 0
|
||||||
|
|
||||||
if past_key_values is not None:
|
if use_cache:
|
||||||
past_key_values_length = past_key_values[0][0].shape[2]
|
use_legacy_cache = not isinstance(past_key_values, Cache)
|
||||||
|
if use_legacy_cache:
|
||||||
|
past_key_values = DynamicCache.from_legacy_cache(past_key_values)
|
||||||
|
past_key_values_length = past_key_values.get_seq_length()
|
||||||
seq_length_with_past = seq_length_with_past + past_key_values_length
|
seq_length_with_past = seq_length_with_past + past_key_values_length
|
||||||
|
|
||||||
if position_ids is None:
|
if position_ids is None:
|
||||||
@@ -877,21 +899,19 @@ class PhiModel(PhiPreTrainedModel):
|
|||||||
# decoder layers
|
# decoder layers
|
||||||
all_hidden_states = () if output_hidden_states else None
|
all_hidden_states = () if output_hidden_states else None
|
||||||
all_self_attns = () if output_attentions else None
|
all_self_attns = () if output_attentions else None
|
||||||
next_decoder_cache = () if use_cache else None
|
next_decoder_cache = None
|
||||||
|
|
||||||
for idx, decoder_layer in enumerate(self.layers):
|
for decoder_layer in self.layers:
|
||||||
if output_hidden_states:
|
if output_hidden_states:
|
||||||
all_hidden_states += (hidden_states,)
|
all_hidden_states += (hidden_states,)
|
||||||
|
|
||||||
past_key_value = past_key_values[idx] if past_key_values is not None else None
|
|
||||||
|
|
||||||
if self.gradient_checkpointing and self.training:
|
if self.gradient_checkpointing and self.training:
|
||||||
layer_outputs = self._gradient_checkpointing_func(
|
layer_outputs = self._gradient_checkpointing_func(
|
||||||
decoder_layer.__call__,
|
decoder_layer.__call__,
|
||||||
hidden_states,
|
hidden_states,
|
||||||
attention_mask,
|
attention_mask,
|
||||||
position_ids,
|
position_ids,
|
||||||
past_key_value,
|
past_key_values,
|
||||||
output_attentions,
|
output_attentions,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
@@ -899,7 +919,7 @@ class PhiModel(PhiPreTrainedModel):
|
|||||||
hidden_states,
|
hidden_states,
|
||||||
attention_mask=attention_mask,
|
attention_mask=attention_mask,
|
||||||
position_ids=position_ids,
|
position_ids=position_ids,
|
||||||
past_key_value=past_key_value,
|
past_key_value=past_key_values,
|
||||||
output_attentions=output_attentions,
|
output_attentions=output_attentions,
|
||||||
use_cache=use_cache,
|
use_cache=use_cache,
|
||||||
)
|
)
|
||||||
@@ -907,7 +927,7 @@ class PhiModel(PhiPreTrainedModel):
|
|||||||
hidden_states = layer_outputs[0]
|
hidden_states = layer_outputs[0]
|
||||||
|
|
||||||
if use_cache:
|
if use_cache:
|
||||||
next_decoder_cache += (layer_outputs[2 if output_attentions else 1],)
|
next_decoder_cache = layer_outputs[2 if output_attentions else 1]
|
||||||
|
|
||||||
if output_attentions:
|
if output_attentions:
|
||||||
all_self_attns += (layer_outputs[1],)
|
all_self_attns += (layer_outputs[1],)
|
||||||
@@ -918,7 +938,9 @@ class PhiModel(PhiPreTrainedModel):
|
|||||||
if output_hidden_states:
|
if output_hidden_states:
|
||||||
all_hidden_states += (hidden_states,)
|
all_hidden_states += (hidden_states,)
|
||||||
|
|
||||||
next_cache = next_decoder_cache if use_cache else None
|
next_cache = None
|
||||||
|
if use_cache:
|
||||||
|
next_cache = next_decoder_cache.to_legacy_cache() if use_legacy_cache else next_decoder_cache
|
||||||
if not return_dict:
|
if not return_dict:
|
||||||
return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
|
return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
|
||||||
return BaseModelOutputWithPast(
|
return BaseModelOutputWithPast(
|
||||||
@@ -1060,16 +1082,28 @@ class PhiForCausalLM(PhiPreTrainedModel):
|
|||||||
self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs
|
self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs
|
||||||
):
|
):
|
||||||
if past_key_values is not None:
|
if past_key_values is not None:
|
||||||
past_length = past_key_values[0][0].shape[2]
|
if isinstance(past_key_values, Cache):
|
||||||
|
cache_length = past_key_values.get_seq_length()
|
||||||
# Some generation methods already pass only the last input ID
|
past_length = past_key_values.seen_tokens
|
||||||
if input_ids.shape[1] > past_length:
|
|
||||||
remove_prefix_length = past_length
|
|
||||||
else:
|
else:
|
||||||
# Default to old behavior: keep only final ID
|
cache_length = past_length = past_key_values[0][0].shape[2]
|
||||||
remove_prefix_length = input_ids.shape[1] - 1
|
|
||||||
|
|
||||||
input_ids = input_ids[:, remove_prefix_length:]
|
# Keep only the unprocessed tokens:
|
||||||
|
# 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where
|
||||||
|
# some of the inputs are exclusively passed as part of the cache (e.g. when passing input_embeds as
|
||||||
|
# input)
|
||||||
|
if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]:
|
||||||
|
input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :]
|
||||||
|
# 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard
|
||||||
|
# input_ids based on the past_length.
|
||||||
|
elif past_length < input_ids.shape[1]:
|
||||||
|
input_ids = input_ids[:, past_length:]
|
||||||
|
# 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens.
|
||||||
|
|
||||||
|
# If the cache has seen more tokens than it can hold, then the cache has a size limit. Let's discard the
|
||||||
|
# older attention values, as their corresponding values are not part of the input.
|
||||||
|
if cache_length < past_length and attention_mask is not None:
|
||||||
|
attention_mask = attention_mask[:, -(cache_length + input_ids.shape[1]) :]
|
||||||
|
|
||||||
position_ids = kwargs.get("position_ids", None)
|
position_ids = kwargs.get("position_ids", None)
|
||||||
if attention_mask is not None and position_ids is None:
|
if attention_mask is not None and position_ids is None:
|
||||||
|
|||||||
@@ -16,6 +16,27 @@ class PyTorchBenchmarkArguments(metaclass=DummyObject):
|
|||||||
requires_backends(self, ["torch"])
|
requires_backends(self, ["torch"])
|
||||||
|
|
||||||
|
|
||||||
|
# Placeholder exported under the name `Cache`; its only declared backend is "torch".
class Cache(metaclass=DummyObject):
    _backends = ["torch"]

    def __init__(self, *args, **kwargs):
        # Accepts any arguments and defers to `requires_backends`.
        # NOTE(review): presumably this fails with a helpful message when torch is not installed; confirm.
        requires_backends(self, ["torch"])
|
||||||
|
|
||||||
|
|
||||||
|
# Placeholder exported under the name `DynamicCache`; its only declared backend is "torch".
class DynamicCache(metaclass=DummyObject):
    _backends = ["torch"]

    def __init__(self, *args, **kwargs):
        # Accepts any arguments and defers to `requires_backends`.
        # NOTE(review): presumably this fails with a helpful message when torch is not installed; confirm.
        requires_backends(self, ["torch"])
|
||||||
|
|
||||||
|
|
||||||
|
# Placeholder exported under the name `SinkCache`; its only declared backend is "torch".
class SinkCache(metaclass=DummyObject):
    _backends = ["torch"]

    def __init__(self, *args, **kwargs):
        # Accepts any arguments and defers to `requires_backends`.
        # NOTE(review): presumably this fails with a helpful message when torch is not installed; confirm.
        requires_backends(self, ["torch"])
|
||||||
|
|
||||||
|
|
||||||
class GlueDataset(metaclass=DummyObject):
|
class GlueDataset(metaclass=DummyObject):
|
||||||
_backends = ["torch"]
|
_backends = ["torch"]
|
||||||
|
|
||||||
|
|||||||
@@ -20,8 +20,9 @@ import unittest
|
|||||||
import warnings
|
import warnings
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
from parameterized import parameterized
|
||||||
|
|
||||||
from transformers import is_torch_available, pipeline
|
from transformers import is_torch_available, pipeline, set_seed
|
||||||
from transformers.testing_utils import (
|
from transformers.testing_utils import (
|
||||||
is_flaky,
|
is_flaky,
|
||||||
require_accelerate,
|
require_accelerate,
|
||||||
@@ -53,6 +54,7 @@ if is_torch_available():
|
|||||||
SpeechEncoderDecoderModel,
|
SpeechEncoderDecoderModel,
|
||||||
top_k_top_p_filtering,
|
top_k_top_p_filtering,
|
||||||
)
|
)
|
||||||
|
from transformers.cache_utils import DynamicCache
|
||||||
from transformers.generation import (
|
from transformers.generation import (
|
||||||
BeamSampleDecoderOnlyOutput,
|
BeamSampleDecoderOnlyOutput,
|
||||||
BeamSampleEncoderDecoderOutput,
|
BeamSampleEncoderDecoderOutput,
|
||||||
@@ -1904,6 +1906,66 @@ class GenerationTesterMixin:
|
|||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@parameterized.expand([(1, False), (1, True), (4, False)])
def test_new_cache_format(self, num_beams, do_sample):
    # Tests that generating with the new format is exactly the same as the legacy one (for models that support it).
    # - tests with and without beam search so that we can test with and without cache reordering.
    # - tests with and without sampling so we can cover the most common use cases.
    #
    # Args:
    #     num_beams: number of beams (also used as `num_return_sequences`) passed to `generate`.
    #     do_sample: whether `generate` samples; the same seed is set before both calls so they are comparable.
    for model_class in self.all_generative_model_classes:
        if not model_class._supports_cache_class:
            self.skipTest("This model does not support the new cache format")

        config, input_ids, attention_mask, _ = self._get_input_ids_and_config()
        config.use_cache = True
        config.is_decoder = True

        model = model_class(config).to(torch_device).eval()
        generation_kwargs = {
            "max_new_tokens": 5,
            "do_sample": do_sample,
            "num_beams": num_beams,
            "num_return_sequences": num_beams,
            "return_dict_in_generate": True,  # Required to return `past_key_values`
        }

        # Sets seed before calling `generate` for the case with do_sample=True
        seed = torch.randint(0, 1000000, (1,)).item()
        set_seed(seed)
        legacy_results = model.generate(input_ids, attention_mask=attention_mask, **generation_kwargs)
        set_seed(seed)
        new_results = model.generate(
            input_ids, attention_mask=attention_mask, past_key_values=DynamicCache(), **generation_kwargs
        )

        # The two sets of generated sequences must match, despite the cache format between forward passes being
        # different
        self.assertListEqual(legacy_results.sequences.tolist(), new_results.sequences.tolist())
        self.assertIsInstance(legacy_results.past_key_values, tuple)
        self.assertIsInstance(new_results.past_key_values, DynamicCache)

        # The contents of the two caches, when converted to the same format (in both directions!), must match.
        # Lengths are compared first: the loops below index the second cache by the first one's length, so a
        # cache with extra layers/entries would otherwise go unnoticed.
        legacy_cache = legacy_results.past_key_values
        new_cache_converted = new_results.past_key_values.to_legacy_cache()
        self.assertEqual(len(legacy_cache), len(new_cache_converted))
        for layer_idx in range(len(legacy_cache)):
            self.assertEqual(len(legacy_cache[layer_idx]), len(new_cache_converted[layer_idx]))
            for kv_idx in range(len(legacy_cache[layer_idx])):
                self.assertTrue(
                    torch.allclose(
                        legacy_cache[layer_idx][kv_idx],
                        new_cache_converted[layer_idx][kv_idx],
                    )
                )

        new_cache = new_results.past_key_values
        legacy_cache_converted = DynamicCache.from_legacy_cache(legacy_results.past_key_values)
        self.assertEqual(len(new_cache), len(legacy_cache_converted))
        for layer_idx in range(len(new_cache)):
            self.assertEqual(len(new_cache[layer_idx]), len(legacy_cache_converted[layer_idx]))
            for kv_idx in range(len(new_cache[layer_idx])):
                self.assertTrue(
                    torch.allclose(
                        new_cache[layer_idx][kv_idx],
                        legacy_cache_converted[layer_idx][kv_idx],
                    )
                )
|
||||||
|
|
||||||
def _check_outputs(self, output, input_ids, config, use_cache=False, num_return_sequences=1):
|
def _check_outputs(self, output, input_ids, config, use_cache=False, num_return_sequences=1):
|
||||||
batch_size, seq_length = input_ids.shape
|
batch_size, seq_length = input_ids.shape
|
||||||
num_sequences_in_output = batch_size * num_return_sequences
|
num_sequences_in_output = batch_size * num_return_sequences
|
||||||
|
|||||||
189
tests/test_cache_utils.py
Normal file
189
tests/test_cache_utils.py
Normal file
@@ -0,0 +1,189 @@
|
|||||||
|
# coding=utf-8
|
||||||
|
# Copyright 2023 HuggingFace Inc.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
import unittest
|
||||||
|
|
||||||
|
from transformers import set_seed
|
||||||
|
from transformers.testing_utils import is_torch_available, require_auto_gptq, require_torch, require_torch_gpu, slow
|
||||||
|
|
||||||
|
|
||||||
|
if is_torch_available():
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from transformers import AutoModelForCausalLM, AutoTokenizer, DynamicCache, LlamaForCausalLM, SinkCache
|
||||||
|
|
||||||
|
|
||||||
|
@require_torch
class CacheTest(unittest.TestCase):
    """Unit tests for `DynamicCache` and its interoperability with the legacy tuple-of-tuples cache format."""

    def test_cache_equivalence(self):
        """Tests that we can convert back and forth between the legacy cache format and DynamicCache"""
        legacy_cache = ()
        new_cache = DynamicCache()

        # Creates a new cache with 10 layers in both formats
        for layer_idx in range(10):
            new_key = torch.rand((2, 4, 8, 16))
            new_value = torch.rand((2, 4, 8, 16))
            new_cache.update(new_key, new_value, layer_idx)
            legacy_cache += ((new_key, new_value),)

        # Sanity check 1: they must have the same shapes.
        # `assertEqual` is used on purpose: `assertTrue(a, b)` treats `b` as the failure message and only checks
        # that `a` is truthy, so it would never detect a length mismatch.
        self.assertEqual(len(legacy_cache), len(new_cache))
        for layer_idx in range(10):
            self.assertEqual(len(legacy_cache[layer_idx]), len(new_cache[layer_idx]))
            for key_value_idx in range(2):
                self.assertEqual(
                    legacy_cache[layer_idx][key_value_idx].shape, new_cache[layer_idx][key_value_idx].shape
                )

        # Sanity check 2: we can get the sequence length in multiple ways with DynamicCache, and they return the
        # expected value
        self.assertTrue(legacy_cache[0][0].shape[-2] == new_cache[0][0].shape[-2] == new_cache.get_seq_length() == 8)

        # Sanity check 3: they must be equal, and both support indexing
        for layer_idx in range(10):
            for key_value_idx in range(2):
                self.assertTrue(
                    torch.allclose(new_cache[layer_idx][key_value_idx], legacy_cache[layer_idx][key_value_idx])
                )

        # Test 1: We can convert from legacy to new with no changes
        from_legacy = DynamicCache.from_legacy_cache(legacy_cache)
        for layer_idx in range(10):
            for key_value_idx in range(2):
                self.assertTrue(
                    torch.allclose(from_legacy[layer_idx][key_value_idx], legacy_cache[layer_idx][key_value_idx])
                )

        # Test 2: We can convert from new to legacy with no changes
        to_legacy = new_cache.to_legacy_cache()
        for layer_idx in range(10):
            for key_value_idx in range(2):
                self.assertTrue(
                    torch.allclose(to_legacy[layer_idx][key_value_idx], new_cache[layer_idx][key_value_idx])
                )

    def test_reorder_cache_retrocompatibility(self):
        """Tests that Cache.reorder_cache is retrocompatible with the legacy code path"""
        legacy_reorder_fn = LlamaForCausalLM._reorder_cache  # An example of a legacy `_reorder_cache` function

        legacy_cache = ()
        new_cache = DynamicCache()

        # Creates a new cache with 10 layers in both formats
        for layer_idx in range(10):
            new_key = torch.rand((4, 4, 8, 16))
            new_value = torch.rand((4, 4, 8, 16))
            new_cache.update(new_key, new_value, layer_idx)
            legacy_cache += ((new_key, new_value),)

        # Let's create some dummy beam indices. From the shape above, it is equivalent to the case where num_beams=4
        # and batch_size=1
        beam_idx = torch.randint(low=0, high=4, size=(4,))

        # The legacy function returns the reordered cache; `reorder_cache`'s return value is not used, the cache
        # object itself is inspected below.
        legacy_cache_reordered = legacy_reorder_fn(legacy_cache, beam_idx)
        new_cache.reorder_cache(beam_idx)

        # Let's check that the results are the same
        for layer_idx in range(10):
            for key_value_idx in range(2):
                self.assertTrue(
                    torch.allclose(
                        new_cache[layer_idx][key_value_idx], legacy_cache_reordered[layer_idx][key_value_idx]
                    )
                )
|
||||||
|
|
||||||
|
|
||||||
|
@require_torch_gpu
@slow
class CacheIntegrationTest(unittest.TestCase):
    """Slow GPU integration tests that run `generate` with `DynamicCache` / `SinkCache` on pretrained checkpoints."""

    def test_dynamic_cache_hard(self):
        """Sampling with a `DynamicCache` must reproduce the legacy-cache output and the pinned reference text."""
        tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf", padding_side="left")
        model = AutoModelForCausalLM.from_pretrained(
            "meta-llama/Llama-2-7b-hf", device_map="auto", torch_dtype=torch.float16
        )
        inputs = tokenizer(["Here's everything I know about cats. Cats"], return_tensors="pt").to(model.device)

        # DynamicCache and the legacy cache format should be equivalent
        set_seed(0)
        gen_out_legacy = model.generate(**inputs, do_sample=True, max_new_tokens=256)
        set_seed(0)
        gen_out = model.generate(**inputs, do_sample=True, max_new_tokens=256, past_key_values=DynamicCache())
        self.assertListEqual(gen_out_legacy.tolist(), gen_out.tolist())

        decoded = tokenizer.batch_decode(gen_out, skip_special_tokens=True)
        expected_text = (
            "Here's everything I know about cats. Cats are mysterious creatures. They can't talk, and they don't like "
            "to be held. They don't play fetch, and they don't like to be hugged. But they do like to be petted.\n"
            "Cats are also very independent. They don't like to be told what to do, and they don't like to be told "
            "what to eat. They are also very territorial. They don't like to share their food or their toys.\nCats "
            "are also very curious. They like to explore, and they like to play. They are also very fast. They can "
            "run very fast, and they can jump very high.\nCats are also very smart. They can learn tricks, and they "
            "can solve problems. They are also very playful. They like to play with toys, and they like to play with "
            "other cats.\nCats are also very affectionate. They like to be petted, and they like to be held. They "
            "also like to be scratched.\nCats are also very clean. They like to groom themselves, and they like to "
            "clean their litter box.\nCats are also very independent. They don't"
        )
        self.assertEqual(decoded[0], expected_text)

    def test_dynamic_cache_batched(self):
        """Greedy generation on a left-padded batch of two prompts with a `DynamicCache` matches the reference."""
        tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf", padding_side="left")
        # The EOS token doubles as the pad token so that `padding=True` can be used below.
        tokenizer.pad_token = tokenizer.eos_token
        model = AutoModelForCausalLM.from_pretrained(
            "meta-llama/Llama-2-7b-hf", device_map="auto", torch_dtype=torch.float16
        )
        inputs = tokenizer(["A sequence: 1, 2, 3, 4, 5", "A sequence: A, B, C"], padding=True, return_tensors="pt").to(
            model.device
        )

        gen_out = model.generate(**inputs, do_sample=False, max_new_tokens=10, past_key_values=DynamicCache())
        decoded = tokenizer.batch_decode(gen_out, skip_special_tokens=True)
        expected_text = ["A sequence: 1, 2, 3, 4, 5, 6, 7, 8,", "A sequence: A, B, C, D, E, F, G, H"]
        self.assertListEqual(decoded, expected_text)

    def test_dynamic_cache_beam_search(self):
        """Beam search (2 beams, 2 returned sequences) with a `DynamicCache` matches the reference text."""
        tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf", padding_side="left")
        model = AutoModelForCausalLM.from_pretrained(
            "meta-llama/Llama-2-7b-hf", device_map="auto", torch_dtype=torch.float16
        )

        inputs = tokenizer(["The best color is"], return_tensors="pt").to(model.device)
        # An explicit `DynamicCache` is passed, as in the sibling tests, so that beam search actually runs through
        # the new cache class rather than the default cache path.
        gen_out = model.generate(
            **inputs,
            do_sample=False,
            max_new_tokens=20,
            num_beams=2,
            num_return_sequences=2,
            past_key_values=DynamicCache(),
        )
        decoded = tokenizer.batch_decode(gen_out, skip_special_tokens=True)
        expected_text = [
            "The best color is the one that makes you feel good.\nThe best color is the one that makes you feel good",
            "The best color is the one that suits you.\nThe best color is the one that suits you. The",
        ]
        self.assertListEqual(decoded, expected_text)

    @require_auto_gptq
    def test_sink_cache_hard(self):
        """Long greedy generation (3000 new tokens) with a `SinkCache` still ends in the pinned, coherent text."""
        tokenizer = AutoTokenizer.from_pretrained("TheBloke/LLaMa-7B-GPTQ")
        model = AutoModelForCausalLM.from_pretrained("TheBloke/LLaMa-7B-GPTQ", device_map="auto")

        inputs = tokenizer(["Vaswani et al. (2017) introduced the Transformers"], return_tensors="pt").to(model.device)

        # Set up the SinkCache. Using a small window length to contain computational complexity. If this example is run
        # without a SinkCache, the last few tokens are gibberish (ends in "of the of the of a of a of")
        cache = SinkCache(window_length=508, num_sink_tokens=4)
        gen_out = model.generate(**inputs, do_sample=False, max_new_tokens=3000, past_key_values=cache)
        decoded = tokenizer.batch_decode(gen_out, skip_special_tokens=True)
        self.assertTrue(decoded[0].endswith("to perform a variety of tasks. The Transformer is a neural network"))
|
||||||
@@ -557,10 +557,6 @@ class ModelTesterMixin:
|
|||||||
return
|
return
|
||||||
|
|
||||||
for model_class in self.all_model_classes:
|
for model_class in self.all_model_classes:
|
||||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
|
||||||
config.use_cache = False
|
|
||||||
config.return_dict = True
|
|
||||||
|
|
||||||
if (
|
if (
|
||||||
model_class.__name__
|
model_class.__name__
|
||||||
in [*get_values(MODEL_MAPPING_NAMES), *get_values(MODEL_FOR_BACKBONE_MAPPING_NAMES)]
|
in [*get_values(MODEL_MAPPING_NAMES), *get_values(MODEL_FOR_BACKBONE_MAPPING_NAMES)]
|
||||||
@@ -569,6 +565,8 @@ class ModelTesterMixin:
|
|||||||
continue
|
continue
|
||||||
|
|
||||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||||
|
config.use_cache = False
|
||||||
|
config.return_dict = True
|
||||||
model = model_class(config)
|
model = model_class(config)
|
||||||
|
|
||||||
model.to(torch_device)
|
model.to(torch_device)
|
||||||
|
|||||||
Reference in New Issue
Block a user