Add torch.compile for Mistral (#30642)

* first version

* fix sliding window

* fix style

* add sliding window cache

* fix style

* address comments

* fix test

* fix style

* move sliding window check inside cache init

* revert changes on irrelevant files & add comment on SlidingWindowCache

* address comments & fix style

fix style

* update causal mask

* [run-slow] mistral

* [run-slow] mistral

* [run-slow] mistral

* [run-slow] mistral

* [run-slow] mistral

* [run-slow] llama

* [run-slow] mistral

* [run-slow] mistral

* [run-slow] mistral

* revert CI from a10 to t4

* wrap up
This commit is contained in:
Longjie Zheng
2024-05-20 10:27:24 -04:00
committed by GitHub
parent 92d1d97c05
commit 616bb11d48
19 changed files with 512 additions and 237 deletions

View File

@@ -29,7 +29,7 @@ To optimize this, you can use a kv-cache to store the past keys and values inste
The *static kv-cache* solves this issue by pre-allocating the kv-cache size to a maximum value which allows you to combine it with torch.compile for up to a 4x speed up.
> [!WARNING]
> Currently, only [Command R](./model_doc/cohere), [Gemma](./model_doc/gemma) and [Llama](./model_doc/llama2) models support static kv-cache and torch.compile.
> Currently, only [Llama](./model_doc/llama2) and a few other models support static kv-cache and torch.compile. Check [this issue](https://github.com/huggingface/transformers/issues/28981) for a live model compatibility list.
For this example, let's load the [Gemma](https://hf.co/google/gemma-2b) model.

View File

@@ -448,3 +448,124 @@ class StaticCache(Cache):
# In-place ops prevent breaking the static address
self.key_cache[layer_idx].zero_()
self.value_cache[layer_idx].zero_()
class SlidingWindowCache(Cache):
"""
Sliding Window Cache class to be used with `torch.compile` for models like Mistral that support sliding window attention.
Every time when we try to update the cache, we compute the `indices` based on `cache_position >= self.config.sliding_window_size - 1`,
if true(which means the cache can not hold all the old key value states and new states together because of the sliding window constraint),
we need to do a cycle shift based on `indices` to replace the oldest states by the new key value states passed in.
The `to_shift` is only true once we are above sliding_window_size. Thus with `sliding_window_size==64`:
indices = (slicing + to_shift[-1].int()-1) % self.config.sliding_window_size
tensor([ 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18,
19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36,
37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54,
55, 56, 57, 58, 59, 60, 61, 62, 63, 0])
We overwrite the cache using these, then we always write at cache_position (clamped to `sliding_window_size`)
Parameters:
config (`PretrainedConfig):
The configuration file defining the shape-related attributes required to initialize the static cache.
max_batch_size (`int`):
The maximum batch size with which the model will be used.
max_cache_len (`int`):
The maximum sequence length with which the model will be used.
device (`torch.device`):
The device on which the cache should be initialized. Should be the same as the layer.
dtype (*optional*, defaults to `torch.float32`):
The default `dtype` to use when initializing the layer.
"""
def __init__(self, config: PretrainedConfig, max_batch_size: int, max_cache_len: int, device, dtype=None) -> None:
if not hasattr(config, "sliding_window") or config.sliding_window is None:
raise ValueError(
"Setting `cache_implementation` to 'sliding_window' requires the model config supporting "
"sliding window attention, please check if there is a `sliding_window` field in the model "
"config and it's not set to None."
)
super().__init__()
self.max_batch_size = max_batch_size
# take the minimum of max_cache_len and config.sliding_window so that we allocate less memory
# when we do short-sentence generation
self.max_cache_len = config.max_position_embeddings if max_cache_len is None else max_cache_len
self.model_sliding_window_size = config.sliding_window
self.sliding_window_size = min(self.max_cache_len, self.model_sliding_window_size)
# Some model define a custom `head_dim` != config.hidden_size // config.num_attention_heads
self.head_dim = (
config.head_dim if hasattr(config, "head_dim") else config.hidden_size // config.num_attention_heads
)
self.dtype = dtype if dtype is not None else torch.float32
self.num_key_value_heads = (
config.num_attention_heads if config.num_key_value_heads is None else config.num_key_value_heads
)
cache_shape = (
config.num_hidden_layers,
max_batch_size,
self.num_key_value_heads,
self.sliding_window_size,
self.head_dim,
)
self.key_cache = torch.zeros(cache_shape, dtype=self.dtype, device=device)
self.value_cache = torch.zeros(cache_shape, dtype=self.dtype, device=device)
torch._dynamo.mark_static_address(self.key_cache)
torch._dynamo.mark_static_address(self.value_cache)
def update(
self,
key_states: torch.Tensor,
value_states: torch.Tensor,
layer_idx: int,
cache_kwargs: Optional[Dict[str, Any]] = None,
) -> Tuple[torch.Tensor]:
cache_position = cache_kwargs.get("cache_position")
k_out = self.key_cache[layer_idx]
v_out = self.value_cache[layer_idx]
# assume this only happens in prefill phase when prompt length > sliding_window_size
if cache_position.shape[0] > self.sliding_window_size:
k_out = key_states[:, :, -self.sliding_window_size :, :]
v_out = value_states[:, :, -self.sliding_window_size :, :]
self.key_cache[layer_idx] = k_out
self.value_cache[layer_idx] = v_out
# we should return the whole states instead of k_out, v_out to take the whole prompt
# into consideration when building kv cache instead of just throwing away tokens outside of the window
return key_states, value_states
slicing = torch.ones(self.sliding_window_size, dtype=torch.long, device=value_states.device).cumsum(0)
cache_position = cache_position.clamp(0, self.sliding_window_size - 1)
to_shift = cache_position >= self.sliding_window_size - 1
indices = (slicing + to_shift[-1].int() - 1) % self.sliding_window_size
k_out = k_out[:, :, indices]
v_out = v_out[:, :, indices]
k_out[:, :, cache_position] = key_states
v_out[:, :, cache_position] = value_states
self.key_cache[layer_idx] = k_out
self.value_cache[layer_idx] = v_out
return k_out, v_out
def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
# assume this will be called only in the first generation step
# `cache_postion` will be used in other cases
return 0
def get_max_length(self) -> Optional[int]:
# in theory there is no limit because the sliding window size is fixed
# no matter how long the sentence is
return None
def reset(self):
self.key_cache.zero_()
self.value_cache.zero_()

View File

@@ -24,7 +24,7 @@ import torch
import torch.distributed as dist
from torch import nn
from ..cache_utils import Cache, DynamicCache, StaticCache
from ..cache_utils import Cache, DynamicCache, SlidingWindowCache, StaticCache
from ..integrations.deepspeed import is_deepspeed_zero3_enabled
from ..modeling_outputs import CausalLMOutputWithPast, Seq2SeqLMOutput
from ..models.auto import (
@@ -96,9 +96,7 @@ logger = logging.get_logger(__name__)
if is_accelerate_available():
from accelerate.hooks import AlignDevicesHook, add_hook_to_module
NEED_SETUP_CACHE_CLASSES_MAPPING = {
"static": StaticCache,
}
NEED_SETUP_CACHE_CLASSES_MAPPING = {"static": StaticCache, "sliding_window": SlidingWindowCache}
@dataclass
@@ -1326,24 +1324,33 @@ class GenerationMixin:
model_kwargs["cache_position"] = torch.arange(past_length, cur_len, device=input_ids.device)
return model_kwargs
def _get_static_cache(self, max_batch_size: int, max_cache_len: int) -> StaticCache:
def _get_cache(self, cache_implementation: str, max_batch_size: int, max_cache_len: int) -> Cache:
"""
Sets a static cache for `generate`, that will persist across calls. A new cache will only be initialized a
Sets a cache for `generate`, that will persist across calls. A new cache will only be initialized a
new `generate` call requires a larger cache.
Returns the resulting static cache object.
Returns the resulting cache object.
"""
needs_new_cache = (
not hasattr(self, "_static_cache")
or self._static_cache.max_batch_size < max_batch_size
or self._static_cache.max_cache_len < max_cache_len
cache_cls: Cache = NEED_SETUP_CACHE_CLASSES_MAPPING[cache_implementation]
need_new_cache = (
not hasattr(self, "_cache")
or (not isinstance(self._cache, cache_cls))
or self._cache.max_batch_size < max_batch_size
)
if needs_new_cache:
if cache_implementation == "sliding_window":
need_new_cache = need_new_cache or (
self._cache.sliding_window_size < self._cache.model_sliding_window_size
and max_cache_len > self._cache.max_cache_len
)
elif cache_implementation == "static":
need_new_cache = need_new_cache or self._cache.max_cache_len < max_cache_len
if need_new_cache:
if hasattr(self.config, "_pre_quantization_dtype"):
cache_dtype = self.config._pre_quantization_dtype
else:
cache_dtype = self.dtype
self._static_cache = StaticCache(
self._cache = cache_cls(
config=self.config,
max_batch_size=max_batch_size,
max_cache_len=max_cache_len,
@@ -1351,8 +1358,8 @@ class GenerationMixin:
dtype=cache_dtype,
)
else:
self._static_cache.reset() # reset the cache for a new generation
return self._static_cache
self._cache.reset()
return self._cache
def _prepare_special_tokens(
self,
@@ -1615,14 +1622,14 @@ class GenerationMixin:
"This model does not support the `cache_implementation` argument. Please check the following "
"issue: https://github.com/huggingface/transformers/issues/28981."
)
if generation_config.cache_implementation == "static":
if not self._supports_static_cache:
if generation_config.cache_implementation == "static" and not self._supports_static_cache:
raise ValueError(
"This model does not support `cache_implementation='static'`. Please check the following "
"issue: https://github.com/huggingface/transformers/issues/28981"
)
model_kwargs["past_key_values"] = self._get_static_cache(batch_size, generation_config.max_length)
model_kwargs["past_key_values"] = self._get_cache(
generation_config.cache_implementation, batch_size, generation_config.max_length
)
self._validate_generated_length(generation_config, input_ids_length, has_default_max_length)
# 7. determine generation mode

View File

@@ -63,7 +63,7 @@ class OpenLlamaRMSNorm(nn.Module):
return self.weight * hidden_states.to(input_dtype)
# Copied from transformers.models.mistral.modeling_mistral.MistralRotaryEmbedding with Mistral->OpenLlama
# Copied from transformers.models.mixtral.modeling_mixtral.MixtralRotaryEmbedding with Mixtral->OpenLlama
class OpenLlamaRotaryEmbedding(nn.Module):
def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
super().__init__()
@@ -154,7 +154,7 @@ def rotate_half(x):
return torch.cat((-x2, x1), dim=-1)
# Copied from transformers.models.mistral.modeling_mistral.apply_rotary_pos_emb
# Copied from transformers.models.mixtral.modeling_mixtral.apply_rotary_pos_emb
def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1):
"""Applies Rotary Position Embedding to the query and key tensors.

View File

@@ -81,7 +81,7 @@ def rotate_half(x):
return torch.cat((-x2, x1), dim=-1)
# Copied from transformers.models.mistral.modeling_mistral.apply_rotary_pos_emb
# Copied from transformers.models.mixtral.modeling_mixtral.apply_rotary_pos_emb
def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1):
"""Applies Rotary Position Embedding to the query and key tensors.
@@ -123,7 +123,7 @@ def _get_unpad_data(attention_mask):
)
# Copied from transformers.models.mistral.modeling_mistral.MistralRotaryEmbedding with Mistral->Falcon
# Copied from transformers.models.mixtral.modeling_mixtral.MixtralRotaryEmbedding with Mixtral->Falcon
class FalconRotaryEmbedding(nn.Module):
def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
super().__init__()

View File

@@ -253,7 +253,6 @@ class GemmaAttention(nn.Module):
output_attentions: bool = False,
use_cache: bool = False,
cache_position: Optional[torch.LongTensor] = None,
**kwargs,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
bsz, q_len, _ = hidden_states.size()
@@ -265,8 +264,8 @@ class GemmaAttention(nn.Module):
key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
cos, sin = self.rotary_emb(value_states, position_ids, seq_len=None)
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, None)
cos, sin = self.rotary_emb(value_states, position_ids)
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
if past_key_value is not None:
# sin and cos are specific to RoPE models; cache_position needed for the static cache

View File

@@ -522,7 +522,7 @@ def attention_mask_func(attention_scores, ltor_mask):
class GPTNeoXRotaryEmbedding(nn.Module):
# Copied from transformers.models.mistral.modeling_mistral.MistralRotaryEmbedding.__init__
# Copied from transformers.models.mixtral.modeling_mixtral.MixtralRotaryEmbedding.__init__
def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
super().__init__()
@@ -614,7 +614,7 @@ def rotate_half(x):
return torch.cat((-x2, x1), dim=-1)
# Copied from transformers.models.mistral.modeling_mistral.apply_rotary_pos_emb
# Copied from transformers.models.mixtral.modeling_mixtral.apply_rotary_pos_emb
def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1):
"""Applies Rotary Position Embedding to the query and key tensors.

View File

@@ -230,7 +230,7 @@ class GPTNeoXJapaneseAttention(nn.Module):
# Copied from transformers.models.gpt_neox.modeling_gpt_neox.GPTNeoXRotaryEmbedding with GPTNeoXRotaryEmbedding->RotaryEmbedding
class RotaryEmbedding(nn.Module):
# Copied from transformers.models.mistral.modeling_mistral.MistralRotaryEmbedding.__init__
# Copied from transformers.models.mixtral.modeling_mixtral.MixtralRotaryEmbedding.__init__
def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
super().__init__()

View File

@@ -478,7 +478,7 @@ def rotate_half(x):
return torch.cat((-x2, x1), dim=-1)
# Copied from transformers.models.mistral.modeling_mistral.apply_rotary_pos_emb
# Copied from transformers.models.mixtral.modeling_mixtral.apply_rotary_pos_emb
def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1):
"""Applies Rotary Position Embedding to the query and key tensors.

View File

@@ -303,7 +303,6 @@ class LlamaAttention(nn.Module):
output_attentions: bool = False,
use_cache: bool = False,
cache_position: Optional[torch.LongTensor] = None,
**kwargs,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
bsz, q_len, _ = hidden_states.size()
@@ -402,7 +401,6 @@ class LlamaFlashAttention2(LlamaAttention):
output_attentions: bool = False,
use_cache: bool = False,
cache_position: Optional[torch.LongTensor] = None,
**kwargs,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
if isinstance(past_key_value, StaticCache):
raise ValueError(

View File

@@ -18,6 +18,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""PyTorch Mistral model."""
import inspect
import math
from typing import List, Optional, Tuple, Union
@@ -29,8 +30,8 @@ from torch import nn
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
from ...activations import ACT2FN
from ...cache_utils import Cache, DynamicCache
from ...modeling_attn_mask_utils import _prepare_4d_causal_attention_mask, _prepare_4d_causal_attention_mask_for_sdpa
from ...cache_utils import Cache, DynamicCache, SlidingWindowCache, StaticCache
from ...modeling_attn_mask_utils import AttentionMaskConverter
from ...modeling_outputs import (
BaseModelOutputWithPast,
CausalLMOutputWithPast,
@@ -92,8 +93,6 @@ class MistralRMSNorm(nn.Module):
return self.weight * hidden_states.to(input_dtype)
# copied from transformers.models.llama.modeling_llama.LlamaRotaryEmbedding with Llama->Mistral
# TODO @Arthur no longer copied from LLama after static cache
class MistralRotaryEmbedding(nn.Module):
def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
super().__init__()
@@ -104,30 +103,22 @@ class MistralRotaryEmbedding(nn.Module):
inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float().to(device) / self.dim))
self.register_buffer("inv_freq", inv_freq, persistent=False)
# Build here to make `torch.jit.trace` work.
self._set_cos_sin_cache(
seq_len=max_position_embeddings, device=self.inv_freq.device, dtype=torch.get_default_dtype()
)
def _set_cos_sin_cache(self, seq_len, device, dtype):
self.max_seq_len_cached = seq_len
t = torch.arange(self.max_seq_len_cached, device=device, dtype=torch.int64).type_as(self.inv_freq)
freqs = torch.outer(t, self.inv_freq)
# Different from paper, but it uses a different permutation in order to obtain the same calculation
emb = torch.cat((freqs, freqs), dim=-1)
self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)
def forward(self, x, seq_len=None):
@torch.no_grad()
# Copied from transformers.models.llama.modeling_llama.LlamaRotaryEmbedding.forward
def forward(self, x, position_ids):
# x: [bs, num_attention_heads, seq_len, head_size]
if seq_len > self.max_seq_len_cached:
self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype)
return (
self.cos_cached[:seq_len].to(dtype=x.dtype),
self.sin_cached[:seq_len].to(dtype=x.dtype),
)
inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
position_ids_expanded = position_ids[:, None, :].float()
# Force float32 since bfloat16 loses precision on long contexts
# See https://github.com/huggingface/transformers/pull/29285
device_type = x.device.type
device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu"
with torch.autocast(device_type=device_type, enabled=False):
freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
emb = torch.cat((freqs, freqs), dim=-1)
cos = emb.cos()
sin = emb.sin()
return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
# Copied from transformers.models.llama.modeling_llama.rotate_half
@@ -138,9 +129,8 @@ def rotate_half(x):
return torch.cat((-x2, x1), dim=-1)
# copied from transformers.models.llama.modeling_llama.apply_rotary_pos_emb
# TODO @Arthur no longer copied from LLama after static cache
def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1):
# Copied from transformers.models.llama.modeling_llama.apply_rotary_pos_emb
def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
"""Applies Rotary Position Embedding to the query and key tensors.
Args:
@@ -148,9 +138,8 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1):
k (`torch.Tensor`): The key tensor.
cos (`torch.Tensor`): The cosine part of the rotary embedding.
sin (`torch.Tensor`): The sine part of the rotary embedding.
position_ids (`torch.Tensor`):
The position indices of the tokens corresponding to the query and key tensors. For example, this can be
used to pass offsetted position ids when working with a KV-cache.
position_ids (`torch.Tensor`, *optional*):
Deprecated and unused.
unsqueeze_dim (`int`, *optional*, defaults to 1):
The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
@@ -161,8 +150,8 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1):
Returns:
`tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
"""
cos = cos[position_ids].unsqueeze(unsqueeze_dim)
sin = sin[position_ids].unsqueeze(unsqueeze_dim)
cos = cos.unsqueeze(unsqueeze_dim)
sin = sin.unsqueeze(unsqueeze_dim)
q_embed = (q * cos) + (rotate_half(q) * sin)
k_embed = (k * cos) + (rotate_half(k) * sin)
return q_embed, k_embed
@@ -213,6 +202,7 @@ class MistralAttention(nn.Module):
"when creating this class."
)
self.attention_dropout = config.attention_dropout
self.hidden_size = config.hidden_size
self.num_heads = config.num_attention_heads
self.head_dim = self.hidden_size // self.num_heads
@@ -221,7 +211,6 @@ class MistralAttention(nn.Module):
self.max_position_embeddings = config.max_position_embeddings
self.rope_theta = config.rope_theta
self.is_causal = True
self.attention_dropout = config.attention_dropout
if (self.head_dim * self.num_heads) != self.hidden_size:
raise ValueError(
@@ -231,7 +220,7 @@ class MistralAttention(nn.Module):
self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)
self.o_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=False)
self.rotary_emb = MistralRotaryEmbedding(
self.head_dim,
@@ -239,9 +228,7 @@ class MistralAttention(nn.Module):
base=self.rope_theta,
)
def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
# Copied from transformers.models.gemma.modeling_gemma.GemmaAttention.forward with Gemma->Mistral
def forward(
self,
hidden_states: torch.Tensor,
@@ -250,6 +237,7 @@ class MistralAttention(nn.Module):
past_key_value: Optional[Cache] = None,
output_attentions: bool = False,
use_cache: bool = False,
cache_position: Optional[torch.LongTensor] = None,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
bsz, q_len, _ = hidden_states.size()
@@ -261,41 +249,22 @@ class MistralAttention(nn.Module):
key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
kv_seq_len = key_states.shape[-2]
if past_key_value is not None:
if self.layer_idx is None:
raise ValueError(
f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} "
"for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
"with a layer index."
)
kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
cos, sin = self.rotary_emb(value_states, position_ids)
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
if past_key_value is not None:
cache_kwargs = {"sin": sin, "cos": cos} # Specific to RoPE models
# sin and cos are specific to RoPE models; cache_position needed for the static cache
cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
# repeat k/v heads if n_kv_heads < n_heads
key_states = repeat_kv(key_states, self.num_key_value_groups)
value_states = repeat_kv(value_states, self.num_key_value_groups)
attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
raise ValueError(
f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is"
f" {attn_weights.size()}"
)
if attention_mask is not None:
if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
raise ValueError(
f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
)
attn_weights = attn_weights + attention_mask
if attention_mask is not None: # no matter the length, we just slice it
causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
attn_weights = attn_weights + causal_mask
# upcast attention to fp32
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
@@ -309,8 +278,8 @@ class MistralAttention(nn.Module):
)
attn_output = attn_output.transpose(1, 2).contiguous()
attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
attn_output = attn_output.view(bsz, q_len, -1)
attn_output = self.o_proj(attn_output)
if not output_attentions:
@@ -343,7 +312,16 @@ class MistralFlashAttention2(MistralAttention):
past_key_value: Optional[Cache] = None,
output_attentions: bool = False,
use_cache: bool = False,
cache_position: Optional[torch.LongTensor] = None,
):
if isinstance(past_key_value, StaticCache):
raise ValueError(
"`static` cache implementation is not compatible with `attn_implementation==flash_attention_2` "
"make sure to use `sdpa` in the mean time, and open an issue at https://github.com/huggingface/transformers"
)
output_attentions = False
bsz, q_len, _ = hidden_states.size()
query_states = self.q_proj(hidden_states)
@@ -356,19 +334,10 @@ class MistralFlashAttention2(MistralAttention):
kv_seq_len = key_states.shape[-2]
if past_key_value is not None:
if self.layer_idx is None:
raise ValueError(
f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} "
"for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
"with a layer index."
)
kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
kv_seq_len += cache_position[0]
# Because the input can be padded, the absolute sequence length depends on the max position id.
rotary_seq_len = max(kv_seq_len, position_ids[:, -1].max().item()) + 1
cos, sin = self.rotary_emb(value_states, seq_len=rotary_seq_len)
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
cos, sin = self.rotary_emb(value_states, position_ids)
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
use_sliding_windows = (
_flash_supports_window_size
@@ -605,8 +574,7 @@ class MistralFlashAttention2(MistralAttention):
)
# copied from transformers.models.llama.modeling_llama.LlamaSdpaAttention with Llama->Mistral
# TODO @Arthur no longer copied from LLama after static cache
# Copied from transformers.models.llama.modeling_llama.LlamaSdpaAttention with Llama->Mistral
class MistralSdpaAttention(MistralAttention):
"""
Mistral attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from
@@ -623,6 +591,7 @@ class MistralSdpaAttention(MistralAttention):
past_key_value: Optional[Cache] = None,
output_attentions: bool = False,
use_cache: bool = False,
cache_position: Optional[torch.LongTensor] = None,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
if output_attentions:
# TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented.
@@ -637,6 +606,7 @@ class MistralSdpaAttention(MistralAttention):
past_key_value=past_key_value,
output_attentions=output_attentions,
use_cache=use_cache,
cache_position=cache_position,
)
bsz, q_len, _ = hidden_states.size()
@@ -649,43 +619,38 @@ class MistralSdpaAttention(MistralAttention):
key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
kv_seq_len = key_states.shape[-2]
if past_key_value is not None:
kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
cos, sin = self.rotary_emb(value_states, position_ids)
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
if past_key_value is not None:
cache_kwargs = {"sin": sin, "cos": cos} # Specific to RoPE models
# sin and cos are specific to RoPE models; cache_position needed for the static cache
cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
key_states = repeat_kv(key_states, self.num_key_value_groups)
value_states = repeat_kv(value_states, self.num_key_value_groups)
causal_mask = attention_mask
if attention_mask is not None:
if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
raise ValueError(
f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
)
causal_mask = causal_mask[:, :, :, : key_states.shape[-2]]
# SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask,
# Reference: https://github.com/pytorch/pytorch/issues/112577.
if query_states.device.type == "cuda" and attention_mask is not None:
if query_states.device.type == "cuda" and causal_mask is not None:
query_states = query_states.contiguous()
key_states = key_states.contiguous()
value_states = value_states.contiguous()
# We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment
# in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling.
# The q_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case q_len == 1.
is_causal = True if self.is_causal and attention_mask is None and q_len > 1 else False
is_causal = True if causal_mask is None and q_len > 1 else False
attn_output = torch.nn.functional.scaled_dot_product_attention(
query_states,
key_states,
value_states,
attn_mask=attention_mask,
attn_mask=causal_mask,
dropout_p=self.attention_dropout if self.training else 0.0,
is_causal=is_causal,
)
@@ -705,12 +670,13 @@ MISTRAL_ATTENTION_CLASSES = {
}
# Copied from transformers.models.llama.modeling_llama.LlamaDecoderLayer with Llama->Mistral, LLAMA->MISTRAL
class MistralDecoderLayer(nn.Module):
def __init__(self, config: MistralConfig, layer_idx: int):
super().__init__()
self.hidden_size = config.hidden_size
self.self_attn = MISTRAL_ATTENTION_CLASSES[config._attn_implementation](config, layer_idx)
self.self_attn = MISTRAL_ATTENTION_CLASSES[config._attn_implementation](config=config, layer_idx=layer_idx)
self.mlp = MistralMLP(config)
self.input_layernorm = MistralRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
@@ -721,15 +687,17 @@ class MistralDecoderLayer(nn.Module):
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional[Tuple[torch.Tensor]] = None,
past_key_value: Optional[Cache] = None,
output_attentions: Optional[bool] = False,
use_cache: Optional[bool] = False,
cache_position: Optional[torch.LongTensor] = None,
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
"""
Args:
hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
attention_mask (`torch.FloatTensor`, *optional*): attention mask of size
`(batch, sequence_length)` where padding elements are indicated by 0.
attention_mask (`torch.FloatTensor`, *optional*):
attention mask of size `(batch_size, sequence_length)` if flash attention is used or `(batch_size, 1,
query_sequence_length, key_sequence_length)` if default attention is used.
output_attentions (`bool`, *optional*):
Whether or not to return the attentions tensors of all attention layers. See `attentions` under
returned tensors for more detail.
@@ -751,6 +719,7 @@ class MistralDecoderLayer(nn.Module):
past_key_value=past_key_value,
output_attentions=output_attentions,
use_cache=use_cache,
cache_position=cache_position,
)
hidden_states = residual + hidden_states
@@ -801,6 +770,7 @@ class MistralPreTrainedModel(PreTrainedModel):
_supports_flash_attn_2 = True
_supports_sdpa = True
_supports_cache_class = True
_supports_static_cache = True
def _init_weights(self, module):
std = self.config.initializer_range
@@ -924,12 +894,13 @@ class MistralModel(MistralPreTrainedModel):
input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
) -> Union[Tuple, BaseModelOutputWithPast]:
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
@@ -940,72 +911,36 @@ class MistralModel(MistralPreTrainedModel):
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
# retrieve input_ids and inputs_embeds
if input_ids is not None and inputs_embeds is not None:
raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time")
elif input_ids is not None:
batch_size, seq_length = input_ids.shape
elif inputs_embeds is not None:
batch_size, seq_length, _ = inputs_embeds.shape
else:
raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")
if (input_ids is None) ^ (inputs_embeds is not None):
raise ValueError(
"You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one"
)
if self.gradient_checkpointing and self.training:
if use_cache:
if self.gradient_checkpointing and self.training and use_cache:
logger.warning_once(
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
)
use_cache = False
past_key_values_length = 0
if use_cache:
use_legacy_cache = not isinstance(past_key_values, Cache)
if use_legacy_cache:
past_key_values = DynamicCache.from_legacy_cache(past_key_values)
past_key_values_length = past_key_values.get_usable_length(seq_length)
if position_ids is None:
device = input_ids.device if input_ids is not None else inputs_embeds.device
position_ids = torch.arange(
past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device
)
position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
else:
position_ids = position_ids.view(-1, seq_length).long()
if inputs_embeds is None:
inputs_embeds = self.embed_tokens(input_ids)
if attention_mask is not None and self._attn_implementation == "flash_attention_2" and use_cache:
is_padding_right = attention_mask[:, -1].sum().item() != batch_size
if is_padding_right:
raise ValueError(
"You are attempting to perform batched generation with padding_side='right'"
" this may lead to unexpected behaviour for Flash Attention version of Mistral. Make sure to "
" call `tokenizer.padding_side = 'left'` before tokenizing the input. "
return_legacy_cache = False
if use_cache and not isinstance(past_key_values, Cache):
past_key_values = DynamicCache.from_legacy_cache(past_key_values)
return_legacy_cache = True
if cache_position is None:
past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
cache_position = torch.arange(
past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
)
if self._attn_implementation == "flash_attention_2":
# 2d mask is passed through the layers
attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
elif self._attn_implementation == "sdpa" and not output_attentions:
# output_attentions=True can not be supported when using SDPA, and we fall back on
# the manual implementation that requires a 4D causal mask in all cases.
attention_mask = _prepare_4d_causal_attention_mask_for_sdpa(
attention_mask,
(batch_size, seq_length),
inputs_embeds,
past_key_values_length,
sliding_window=self.config.sliding_window,
)
else:
# 4d mask is passed through the layers
attention_mask = _prepare_4d_causal_attention_mask(
attention_mask,
(batch_size, seq_length),
inputs_embeds,
past_key_values_length,
sliding_window=self.config.sliding_window,
if position_ids is None:
position_ids = cache_position.unsqueeze(0)
causal_mask = self._update_causal_mask(
attention_mask, inputs_embeds, cache_position, past_key_values, use_cache, output_attentions
)
hidden_states = inputs_embeds
@@ -1023,20 +958,22 @@ class MistralModel(MistralPreTrainedModel):
layer_outputs = self._gradient_checkpointing_func(
decoder_layer.__call__,
hidden_states,
attention_mask,
causal_mask,
position_ids,
past_key_values,
output_attentions,
use_cache,
cache_position,
)
else:
layer_outputs = decoder_layer(
hidden_states,
attention_mask=attention_mask,
attention_mask=causal_mask,
position_ids=position_ids,
past_key_value=past_key_values,
output_attentions=output_attentions,
use_cache=use_cache,
cache_position=cache_position,
)
hidden_states = layer_outputs[0]
@@ -1053,9 +990,9 @@ class MistralModel(MistralPreTrainedModel):
if output_hidden_states:
all_hidden_states += (hidden_states,)
next_cache = None
if use_cache:
next_cache = next_decoder_cache.to_legacy_cache() if use_legacy_cache else next_decoder_cache
next_cache = next_decoder_cache if use_cache else None
if return_legacy_cache:
next_cache = next_cache.to_legacy_cache()
if not return_dict:
return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
@@ -1066,6 +1003,113 @@ class MistralModel(MistralPreTrainedModel):
attentions=all_self_attns,
)
def _update_causal_mask(
self,
attention_mask: torch.Tensor,
input_tensor: torch.Tensor,
cache_position: torch.Tensor,
past_key_values: Cache,
use_cache: bool,
output_attentions: bool,
):
# TODO: As of torch==2.2.0, the `attention_mask` passed to the model in `generate` is 2D and of dynamic length even when the static
# KV cache is used. This is an issue for torch.compile which then recaptures cudagraphs at each decode steps due to the dynamic shapes.
# (`recording cudagraph tree for symint key 13`, etc.), which is VERY slow. A workaround is `@torch.compiler.disable`, but this prevents using
# `fullgraph=True`. See more context in https://github.com/huggingface/transformers/pull/29114
if self._attn_implementation == "flash_attention_2":
if attention_mask is not None and use_cache:
is_padding_right = attention_mask[:, -1].sum().item() != input_tensor.size()[0]
if is_padding_right:
raise ValueError(
"You are attempting to perform batched generation with padding_side='right'"
" this may lead to unexpected behaviour for Flash Attention version of Mistral. Make sure to "
" call `tokenizer.padding_side = 'left'` before tokenizing the input. "
)
if attention_mask is not None and 0.0 in attention_mask:
return attention_mask
return None
# For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in
# order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail
# to infer the attention mask.
# cache_position must be valid here no matter which cache we use
past_seen_tokens = cache_position[0] if past_key_values is not None else 0
using_static_cache = isinstance(past_key_values, StaticCache)
using_sliding_window_cache = isinstance(past_key_values, SlidingWindowCache)
if (
self.config._attn_implementation == "sdpa"
and not (using_static_cache or using_sliding_window_cache)
and not output_attentions
):
if AttentionMaskConverter._ignore_causal_mask_sdpa(
attention_mask,
inputs_embeds=input_tensor,
past_key_values_length=past_seen_tokens,
sliding_window=self.config.sliding_window,
is_training=self.training,
):
return None
dtype, device = input_tensor.dtype, input_tensor.device
min_dtype = torch.finfo(dtype).min
sequence_length = input_tensor.shape[1]
# SlidingWindowCache
if using_sliding_window_cache:
target_length = max(sequence_length, self.config.sliding_window)
# StaticCache
elif using_static_cache:
target_length = past_key_values.get_max_length()
# DynamicCache or no cache
else:
target_length = (
attention_mask.shape[-1]
if isinstance(attention_mask, torch.Tensor)
else past_seen_tokens + sequence_length + 1
)
if attention_mask is not None and attention_mask.dim() == 4:
# in this case we assume that the mask comes already in inverted form and requires no inversion or slicing
if attention_mask.max() != 0:
raise ValueError("Custom 4D attention mask should be passed in inverted form with max==0`")
causal_mask = attention_mask
else:
causal_mask = torch.full(
(sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device
)
exclude_mask = torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
if self.config.sliding_window is not None:
if not using_sliding_window_cache or sequence_length > self.config.sliding_window:
exclude_mask |= torch.arange(target_length, device=device) <= (
cache_position.reshape(-1, 1) - self.config.sliding_window
)
causal_mask *= exclude_mask
causal_mask = causal_mask[None, None, :, :].expand(input_tensor.shape[0], 1, -1, -1)
if attention_mask is not None:
causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit
if attention_mask.dim() == 2:
mask_length = attention_mask.shape[-1]
padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :]
padding_mask = padding_mask == 0
causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
padding_mask, min_dtype
)
if (
self.config._attn_implementation == "sdpa"
and attention_mask is not None
and attention_mask.device.type == "cuda"
and not output_attentions
):
# Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
# using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
# Details: https://github.com/pytorch/pytorch/issues/110213
causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype)
return causal_mask
class MistralForCausalLM(MistralPreTrainedModel):
_tied_weights_keys = ["lm_head.weight"]
@@ -1104,13 +1148,14 @@ class MistralForCausalLM(MistralPreTrainedModel):
input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
) -> Union[Tuple, CausalLMOutputWithPast]:
r"""
Args:
@@ -1155,6 +1200,7 @@ class MistralForCausalLM(MistralPreTrainedModel):
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
cache_position=cache_position,
)
hidden_states = outputs[0]
@@ -1187,14 +1233,27 @@ class MistralForCausalLM(MistralPreTrainedModel):
)
def prepare_inputs_for_generation(
self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs
self,
input_ids,
past_key_values=None,
attention_mask=None,
inputs_embeds=None,
cache_position=None,
use_cache=True,
**kwargs,
):
# Omit tokens covered by past_key_values
past_length = 0
if past_key_values is not None:
if isinstance(past_key_values, Cache):
cache_length = past_key_values.get_seq_length()
past_length = past_key_values.seen_tokens
max_cache_length = past_key_values.get_max_length()
past_length = cache_position[0] if cache_position is not None else past_key_values.get_seq_length()
max_cache_length = (
torch.tensor(past_key_values.get_max_length(), device=input_ids.device)
if past_key_values.get_max_length() is not None
else None
)
cache_length = past_length if max_cache_length is None else torch.min(max_cache_length, past_length)
# TODO joao: remove this `else` after `generate` prioritizes `Cache` objects
else:
cache_length = past_length = past_key_values[0][0].shape[2]
max_cache_length = None
@@ -1227,17 +1286,33 @@ class MistralForCausalLM(MistralPreTrainedModel):
if past_key_values:
position_ids = position_ids[:, -input_ids.shape[1] :]
# crop the attention_mask to sliding window size during decode phase if using SlidingWindowCache
if (
past_length > 0
and attention_mask is not None
and isinstance(past_key_values, SlidingWindowCache)
and attention_mask.shape[1] > past_key_values.sliding_window_size
):
attention_mask = attention_mask[:, -past_key_values.sliding_window_size :]
# if `inputs_embeds` are passed, we only want to use them in the 1st generation step
if inputs_embeds is not None and past_key_values is None:
model_inputs = {"inputs_embeds": inputs_embeds}
else:
model_inputs = {"input_ids": input_ids}
model_inputs = {"input_ids": input_ids.contiguous()}
input_length = position_ids.shape[-1] if position_ids is not None else input_ids.shape[-1]
if cache_position is None:
cache_position = torch.arange(past_length, past_length + input_length, device=input_ids.device)
elif use_cache:
cache_position = cache_position[-input_length:]
model_inputs.update(
{
"position_ids": position_ids,
"cache_position": cache_position,
"past_key_values": past_key_values,
"use_cache": kwargs.get("use_cache"),
"use_cache": use_cache,
"attention_mask": attention_mask,
}
)

View File

@@ -181,7 +181,8 @@ class MixtralRMSNorm(nn.Module):
return self.weight * hidden_states.to(input_dtype)
# Copied from transformers.models.mistral.modeling_mistral.MistralRotaryEmbedding with Mistral->Mixtral
# copied from transformers.models.mistral.modeling_mistral.MistralRotaryEmbedding with Mistral->Mixtral
# TODO @longjie no longer copied from Mistral after static cache
class MixtralRotaryEmbedding(nn.Module):
def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
super().__init__()
@@ -226,7 +227,8 @@ def rotate_half(x):
return torch.cat((-x2, x1), dim=-1)
# Copied from transformers.models.mistral.modeling_mistral.apply_rotary_pos_emb
# copied from transformers.models.mistral.modeling_mistral.apply_rotary_pos_emb
# TODO @longjie no longer copied from Mistral after static cache
def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1):
"""Applies Rotary Position Embedding to the query and key tensors.
@@ -268,7 +270,8 @@ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
# Copied from transformers.models.mistral.modeling_mistral.MistralAttention with Mistral->Mixtral
# copied from transformers.models.mistral.modeling_mistral.MistralAttention with Mistral->Mixtral
# TODO @longjie no longer copied from Mistral after static cache
class MixtralAttention(nn.Module):
"""
Multi-headed attention from 'Attention Is All You Need' paper. Modified to use sliding window attention: Longformer
@@ -392,7 +395,8 @@ class MixtralAttention(nn.Module):
return attn_output, attn_weights, past_key_value
# Copied from transformers.models.mistral.modeling_mistral.MistralFlashAttention2 with Mistral->Mixtral
# copied from transformers.models.mistral.modeling_mistral.MistralFlashAttention2 with Mistral->Mixtral
# TODO @longjie no longer copied from Mistral after static cache
class MixtralFlashAttention2(MixtralAttention):
"""
Mixtral flash attention module. This module inherits from `MixtralAttention` as the weights of the module stays
@@ -679,7 +683,8 @@ class MixtralFlashAttention2(MixtralAttention):
)
# Copied from transformers.models.mistral.modeling_mistral.MistralSdpaAttention with Mistral->Mixtral
# copied from transformers.models.mistral.modeling_mistral.MistralSdpaAttention with Mistral->Mixtral
# TODO @longjie no longer copied from Mistral after static cache
class MixtralSdpaAttention(MixtralAttention):
"""
Mixtral attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from
@@ -958,7 +963,7 @@ MIXTRAL_START_DOCSTRING = r"""
"The bare Mixtral Model outputting raw hidden-states without any specific head on top.",
MIXTRAL_START_DOCSTRING,
)
# Copied from transformers.models.mistral.modeling_mistral.MistralPreTrainedModel with Mistral->Mixtral
# Copied from transformers.models.qwen2.modeling_qwen2.Qwen2PreTrainedModel with Qwen2->Mixtral
class MixtralPreTrainedModel(PreTrainedModel):
config_class = MixtralConfig
base_model_prefix = "model"
@@ -1052,7 +1057,8 @@ MIXTRAL_INPUTS_DOCSTRING = r"""
"The bare Mixtral Model outputting raw hidden-states without any specific head on top.",
MIXTRAL_START_DOCSTRING,
)
# Copied from transformers.models.mistral.modeling_mistral.MistralModel with MISTRAL->MIXTRAL,Mistral->Mixtral
# copied from transformers.models.mistral.modeling_mistral.MistralModel with MISTRAL->MIXTRAL,Mistral->Mixtral
# TODO @longjie no longer copied from Mistral after static cache
class MixtralModel(MixtralPreTrainedModel):
"""
Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`MixtralDecoderLayer`]

View File

@@ -45,7 +45,7 @@ logger = logging.get_logger(__name__)
_CONFIG_FOR_DOC = "PersimmonConfig"
# Copied from transformers.models.mistral.modeling_mistral.MistralRotaryEmbedding with Mistral->Persimmon
# Copied from transformers.models.mixtral.modeling_mixtral.MixtralRotaryEmbedding with Mixtral->Persimmon
class PersimmonRotaryEmbedding(nn.Module):
def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
super().__init__()
@@ -137,7 +137,7 @@ def rotate_half(x):
return torch.cat((-x2, x1), dim=-1)
# Copied from transformers.models.mistral.modeling_mistral.apply_rotary_pos_emb
# Copied from transformers.models.mixtral.modeling_mixtral.apply_rotary_pos_emb
def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1):
"""Applies Rotary Position Embedding to the query and key tensors.

View File

@@ -76,7 +76,7 @@ def _get_unpad_data(attention_mask):
)
# Copied from transformers.models.mistral.modeling_mistral.MistralRotaryEmbedding with Mistral->Phi
# Copied from transformers.models.mixtral.modeling_mixtral.MixtralRotaryEmbedding with Mixtral->Phi
class PhiRotaryEmbedding(nn.Module):
def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
super().__init__()
@@ -168,7 +168,7 @@ def rotate_half(x):
return torch.cat((-x2, x1), dim=-1)
# Copied from transformers.models.mistral.modeling_mistral.apply_rotary_pos_emb
# Copied from transformers.models.mixtral.modeling_mixtral.apply_rotary_pos_emb
def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1):
"""Applies Rotary Position Embedding to the query and key tensors.

View File

@@ -94,7 +94,7 @@ class Qwen2RMSNorm(nn.Module):
return self.weight * hidden_states.to(input_dtype)
# Copied from transformers.models.mistral.modeling_mistral.MistralRotaryEmbedding with Mistral->Qwen2
# Copied from transformers.models.mixtral.modeling_mixtral.MixtralRotaryEmbedding with Mixtral->Qwen2
class Qwen2RotaryEmbedding(nn.Module):
def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
super().__init__()
@@ -139,7 +139,7 @@ def rotate_half(x):
return torch.cat((-x2, x1), dim=-1)
# Copied from transformers.models.mistral.modeling_mistral.apply_rotary_pos_emb
# Copied from transformers.models.mixtral.modeling_mixtral.apply_rotary_pos_emb
def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1):
"""Applies Rotary Position Embedding to the query and key tensors.
@@ -611,7 +611,7 @@ class Qwen2FlashAttention2(Qwen2Attention):
)
# Copied from transformers.models.mistral.modeling_mistral.MistralSdpaAttention with Mistral->Qwen2
# Copied from transformers.models.mixtral.modeling_mixtral.MixtralSdpaAttention with Mixtral->Qwen2
class Qwen2SdpaAttention(Qwen2Attention):
"""
Qwen2 attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from

View File

@@ -170,7 +170,7 @@ class Qwen2MoeRMSNorm(nn.Module):
return self.weight * hidden_states.to(input_dtype)
# Copied from transformers.models.mistral.modeling_mistral.MistralRotaryEmbedding with Mistral->Qwen2Moe
# Copied from transformers.models.mixtral.modeling_mixtral.MixtralRotaryEmbedding with Mixtral->Qwen2Moe
class Qwen2MoeRotaryEmbedding(nn.Module):
def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
super().__init__()
@@ -215,7 +215,7 @@ def rotate_half(x):
return torch.cat((-x2, x1), dim=-1)
# Copied from transformers.models.mistral.modeling_mistral.apply_rotary_pos_emb
# Copied from transformers.models.mixtral.modeling_mixtral.apply_rotary_pos_emb
def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1):
"""Applies Rotary Position Embedding to the query and key tensors.
@@ -689,7 +689,7 @@ class Qwen2MoeFlashAttention2(Qwen2MoeAttention):
)
# Copied from transformers.models.mistral.modeling_mistral.MistralSdpaAttention with Mistral->Qwen2Moe
# Copied from transformers.models.mixtral.modeling_mixtral.MixtralSdpaAttention with Mixtral->Qwen2Moe
class Qwen2MoeSdpaAttention(Qwen2MoeAttention):
"""
Qwen2Moe attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from

View File

@@ -71,7 +71,7 @@ def _get_unpad_data(attention_mask):
)
# Copied from transformers.models.mistral.modeling_mistral.MistralRotaryEmbedding with Mistral->StableLm
# Copied from transformers.models.mixtral.modeling_mixtral.MixtralRotaryEmbedding with Mixtral->StableLm
class StableLmRotaryEmbedding(nn.Module):
def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
super().__init__()
@@ -163,7 +163,7 @@ def rotate_half(x):
return torch.cat((-x2, x1), dim=-1)
# Copied from transformers.models.mistral.modeling_mistral.apply_rotary_pos_emb
# Copied from transformers.models.mixtral.modeling_mixtral.apply_rotary_pos_emb
def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1):
"""Applies Rotary Position Embedding to the query and key tensors.

View File

@@ -74,7 +74,7 @@ def _get_unpad_data(attention_mask):
)
# Copied from transformers.models.mistral.modeling_mistral.MistralRotaryEmbedding with Mistral->Starcoder2
# Copied from transformers.models.mixtral.modeling_mixtral.MixtralRotaryEmbedding with Mixtral->Starcoder2
class Starcoder2RotaryEmbedding(nn.Module):
def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
super().__init__()
@@ -119,7 +119,7 @@ def rotate_half(x):
return torch.cat((-x2, x1), dim=-1)
# Copied from transformers.models.mistral.modeling_mistral.apply_rotary_pos_emb
# Copied from transformers.models.mixtral.modeling_mixtral.apply_rotary_pos_emb
def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1):
"""Applies Rotary Position Embedding to the query and key tensors.
@@ -590,7 +590,7 @@ class Starcoder2FlashAttention2(Starcoder2Attention):
)
# Copied from transformers.models.mistral.modeling_mistral.MistralSdpaAttention with Mistral->Starcoder2
# Copied from transformers.models.mixtral.modeling_mixtral.MixtralSdpaAttention with Mixtral->Starcoder2
class Starcoder2SdpaAttention(Starcoder2Attention):
"""
Starcoder2 attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from
@@ -703,7 +703,7 @@ class Starcoder2DecoderLayer(nn.Module):
self.input_layernorm = nn.LayerNorm(config.hidden_size, eps=config.norm_epsilon)
self.post_attention_layernorm = nn.LayerNorm(config.hidden_size, eps=config.norm_epsilon)
# Copied from transformers.models.mistral.modeling_mistral.MistralDecoderLayer.forward
# Copied from transformers.models.qwen2.modeling_qwen2.Qwen2DecoderLayer.forward
def forward(
self,
hidden_states: torch.Tensor,
@@ -780,7 +780,7 @@ STARCODER2_START_DOCSTRING = r"""
"The bare Starcoder2 Model outputting raw hidden-states without any specific head on top.",
STARCODER2_START_DOCSTRING,
)
# Copied from transformers.models.mistral.modeling_mistral.MistralPreTrainedModel with Mistral->Starcoder2
# Copied from transformers.models.qwen2.modeling_qwen2.Qwen2PreTrainedModel with Qwen2->Starcoder2
class Starcoder2PreTrainedModel(PreTrainedModel):
config_class = Starcoder2Config
base_model_prefix = "model"
@@ -1057,7 +1057,7 @@ class Starcoder2Model(Starcoder2PreTrainedModel):
)
# Copied from transformers.models.mistral.modeling_mistral.MistralForCausalLM with MISTRAL->STARCODER2,Mistral-7B-v0.1->starcoder2-7b_16k,Mistral->Starcoder2,mistralai->bigcode
# Copied from transformers.models.qwen2.modeling_qwen2.Qwen2ForCausalLM with QWEN2->STARCODER2,Qwen2->Starcoder2
class Starcoder2ForCausalLM(Starcoder2PreTrainedModel):
_tied_weights_keys = ["lm_head.weight"]
@@ -1090,6 +1090,7 @@ class Starcoder2ForCausalLM(Starcoder2PreTrainedModel):
@add_start_docstrings_to_model_forward(STARCODER2_INPUTS_DOCSTRING)
@replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
# Ignore copy
def forward(
self,
input_ids: torch.LongTensor = None,

View File

@@ -14,12 +14,12 @@
# limitations under the License.
"""Testing suite for the PyTorch Mistral model."""
import gc
import tempfile
import unittest
import pytest
from packaging import version
from transformers import AutoTokenizer, MistralConfig, is_torch_available, set_seed
from transformers.testing_utils import (
@@ -648,6 +648,74 @@ class MistralIntegrationTest(unittest.TestCase):
backend_empty_cache(torch_device)
gc.collect()
@slow
def test_compile_static_cache(self):
# `torch==2.2` will throw an error on this test (as in other compilation tests), but torch==2.1.2 and torch>2.2
# work as intended. See https://github.com/pytorch/pytorch/issues/121943
if version.parse(torch.__version__) < version.parse("2.3.0"):
self.skipTest("This test requires torch >= 2.3 to run.")
NUM_TOKENS_TO_GENERATE = 40
EXPECTED_TEXT_COMPLETION = {
8: [
"My favourite condiment is 100% ketchup. I love it on everything. "
"Im not a big fan of mustard, mayo, or relish. Im not a fan of pickles"
],
7: [
"My favourite condiment is 100% ketchup. I love it on everything. "
"Im not a big fan of mustard, mayo, or relish. Im not a fan of pickles"
],
}
prompts = ["My favourite condiment is "]
tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-v0.1", use_fast=False)
tokenizer.pad_token = tokenizer.eos_token
model = MistralForCausalLM.from_pretrained(
"mistralai/Mistral-7B-v0.1", device_map="sequential", torch_dtype=torch.float16
)
inputs = tokenizer(prompts, return_tensors="pt", padding=True).to(model.device)
# Dynamic Cache
generated_ids = model.generate(**inputs, max_new_tokens=NUM_TOKENS_TO_GENERATE, do_sample=False)
dynamic_text = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
self.assertEqual(EXPECTED_TEXT_COMPLETION[self.cuda_compute_capability_major_version], dynamic_text)
# Static Cache
generated_ids = model.generate(
**inputs, max_new_tokens=NUM_TOKENS_TO_GENERATE, do_sample=False, cache_implementation="static"
)
static_text = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
self.assertEqual(EXPECTED_TEXT_COMPLETION[self.cuda_compute_capability_major_version], static_text)
# Sliding Window Cache
generated_ids = model.generate(
**inputs, max_new_tokens=NUM_TOKENS_TO_GENERATE, do_sample=False, cache_implementation="sliding_window"
)
static_text = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
self.assertEqual(EXPECTED_TEXT_COMPLETION[self.cuda_compute_capability_major_version], static_text)
# Static Cache + compile
forward_function = model.forward
model.forward = torch.compile(forward_function, mode="reduce-overhead", fullgraph=True)
generated_ids = model.generate(
**inputs, max_new_tokens=NUM_TOKENS_TO_GENERATE, do_sample=False, cache_implementation="static"
)
static_compiled_text = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
self.assertEqual(EXPECTED_TEXT_COMPLETION[self.cuda_compute_capability_major_version], static_compiled_text)
# Sliding Window Cache + compile
torch._dynamo.reset()
model.forward = torch.compile(forward_function, mode="reduce-overhead", fullgraph=True)
generated_ids = model.generate(
**inputs, max_new_tokens=NUM_TOKENS_TO_GENERATE, do_sample=False, cache_implementation="sliding_window"
)
static_compiled_text = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
self.assertEqual(EXPECTED_TEXT_COMPLETION[self.cuda_compute_capability_major_version], static_compiled_text)
del model
backend_empty_cache(torch_device)
gc.collect()
@slow
@require_torch_gpu