[Mistral] Add Flash Attention-2 support for mistral (#26464)
* add FA-2 support for mistral * fixup * add sliding windows * fixing few nits * v1 slicing cache - logits do not match * add comment * fix bugs * more mem efficient * add warning once * add warning once * oops * fixup * more comments * copy * add safety checker * fixup * Update src/transformers/models/mistral/modeling_mistral.py Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com> * copied from * up * raise when padding side is right * fixup * add doc + few minor changes * fixup --------- Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
This commit is contained in:
@@ -82,6 +82,51 @@ tokenizer = LlamaTokenizer.from_pretrained("/output/path")
|
|||||||
model = MistralForCausalLM.from_pretrained("/output/path")
|
model = MistralForCausalLM.from_pretrained("/output/path")
|
||||||
```
|
```
|
||||||
|
|
||||||
|
## Combining Mistral and Flash Attention 2
|
||||||
|
|
||||||
|
First, make sure to install the latest version of Flash Attention 2 to include the sliding window attention feature.
|
||||||
|
|
||||||
|
```bash
|
||||||
|
pip install -U flash-attn --no-build-isolation
|
||||||
|
```
|
||||||
|
|
||||||
|
Make also sure that you have a hardware that is compatible with Flash-Attention 2. Read more about it in the official documentation of [`flash-attn`](https://github.com/Dao-AILab/flash-attention) repository. Make also sure to load your model in half-precision (e.g. `torch.float16`)
|
||||||
|
|
||||||
|
To load and run a model using Flash Attention 2, refer to the snippet below:
|
||||||
|
|
||||||
|
```python
|
||||||
|
>>> import torch
|
||||||
|
>>> from transformers import AutoModelForCausalLM, AutoTokenizer
|
||||||
|
>>> device = "cuda" # the device to load the model onto
|
||||||
|
|
||||||
|
>>> model = AutoModelForCausalLM.from_pretrained("mistralai/Mistral-7B-v0.1", torch_dtype=torch.float16, use_flash_attention_2=True)
|
||||||
|
>>> tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-v0.1")
|
||||||
|
|
||||||
|
>>> prompt = "My favourite condiment is"
|
||||||
|
|
||||||
|
>>> model_inputs = tokenizer([prompt], return_tensors="pt").to(device)
|
||||||
|
>>> model.to(device)
|
||||||
|
|
||||||
|
>>> generated_ids = model.generate(**model_inputs, max_new_tokens=100, do_sample=True)
|
||||||
|
>>> tokenizer.batch_decode(generated_ids)[0]
|
||||||
|
"The expected outupt"
|
||||||
|
```
|
||||||
|
|
||||||
|
### Expected speedups
|
||||||
|
|
||||||
|
Below is a expected speedup diagram that compares pure inference time between the native implementation in transformers using `mistralai/Mistral-7B-v0.1` checkpoint and the Flash Attention 2 version of the model.
|
||||||
|
|
||||||
|
<div style="text-align: center">
|
||||||
|
<img src="https://huggingface.co/datasets/ybelkada/documentation-images/resolve/main/mistral-7b-inference-large-seqlen.png">
|
||||||
|
</div>
|
||||||
|
|
||||||
|
### Sliding window Attention
|
||||||
|
|
||||||
|
The current implementation supports the sliding window attention mechanism and memory efficient cache management.
|
||||||
|
To enable sliding window attention, just make sure to have a `flash-attn` version that is compatible with sliding window attention (`>=2.3.0`).
|
||||||
|
|
||||||
|
The Flash Attention-2 model uses also a more memory efficient cache slicing mechanism - as recommended per the official implementation of Mistral model that use rolling cache mechanism we keep the cache size fixed (`self.config.sliding_window`), support batched generation only for `padding_side="left"` and use the absolute position of the current token to compute the positional embedding.
|
||||||
|
|
||||||
## The Mistral Team
|
## The Mistral Team
|
||||||
|
|
||||||
Albert Jiang, Alexandre Sablayrolles, Arthur Mensch, Chris Bamford, Devendra Singh Chaplot, Diego de las Casas, Florian Bressand, Gianna Lengyel, Guillaume Lample, Lélio Renard Lavaud, Lucile Saulnier, Marie-Anne Lachaux, Pierre Stock, Teven Le Scao, Thibaut Lavril, Thomas Wang, Timothée Lacroix, William El Sayed.
|
Albert Jiang, Alexandre Sablayrolles, Arthur Mensch, Chris Bamford, Devendra Singh Chaplot, Diego de las Casas, Florian Bressand, Gianna Lengyel, Guillaume Lample, Lélio Renard Lavaud, Lucile Saulnier, Marie-Anne Lachaux, Pierre Stock, Teven Le Scao, Thibaut Lavril, Thomas Wang, Timothée Lacroix, William El Sayed.
|
||||||
|
|||||||
@@ -32,6 +32,7 @@ Make sure to follow the installation guide on the repository mentioned above to
|
|||||||
We natively support Flash Attention 2 for the following models:
|
We natively support Flash Attention 2 for the following models:
|
||||||
|
|
||||||
- Llama
|
- Llama
|
||||||
|
- Mistral
|
||||||
- Falcon
|
- Falcon
|
||||||
|
|
||||||
You can request to add Flash Attention 2 support for more models by opening an issue on GitHub, and even open a Pull Request to integrate the changes. The supported models can be used for inference and training, including training with padding tokens - *which is currently not supported for `BetterTransformer` API below.*
|
You can request to add Flash Attention 2 support for more models by opening an issue on GitHub, and even open a Pull Request to integrate the changes. The supported models can be used for inference and training, including training with padding tokens - *which is currently not supported for `BetterTransformer` API below.*
|
||||||
|
|||||||
@@ -18,10 +18,12 @@
|
|||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
""" PyTorch Mistral model."""
|
""" PyTorch Mistral model."""
|
||||||
|
import inspect
|
||||||
import math
|
import math
|
||||||
from typing import List, Optional, Tuple, Union
|
from typing import List, Optional, Tuple, Union
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
import torch.nn.functional as F
|
||||||
import torch.utils.checkpoint
|
import torch.utils.checkpoint
|
||||||
from torch import nn
|
from torch import nn
|
||||||
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
|
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
|
||||||
@@ -29,15 +31,41 @@ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
|
|||||||
from ...activations import ACT2FN
|
from ...activations import ACT2FN
|
||||||
from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast, SequenceClassifierOutputWithPast
|
from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast, SequenceClassifierOutputWithPast
|
||||||
from ...modeling_utils import PreTrainedModel
|
from ...modeling_utils import PreTrainedModel
|
||||||
from ...utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging, replace_return_docstrings
|
from ...utils import (
|
||||||
|
add_start_docstrings,
|
||||||
|
add_start_docstrings_to_model_forward,
|
||||||
|
is_flash_attn_available,
|
||||||
|
logging,
|
||||||
|
replace_return_docstrings,
|
||||||
|
)
|
||||||
from .configuration_mistral import MistralConfig
|
from .configuration_mistral import MistralConfig
|
||||||
|
|
||||||
|
|
||||||
|
if is_flash_attn_available():
|
||||||
|
from flash_attn import flash_attn_func, flash_attn_varlen_func
|
||||||
|
from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa
|
||||||
|
|
||||||
|
_flash_supports_window_size = "window_size" in list(inspect.signature(flash_attn_func).parameters)
|
||||||
|
|
||||||
|
|
||||||
logger = logging.get_logger(__name__)
|
logger = logging.get_logger(__name__)
|
||||||
|
|
||||||
_CONFIG_FOR_DOC = "MistralConfig"
|
_CONFIG_FOR_DOC = "MistralConfig"
|
||||||
|
|
||||||
|
|
||||||
|
# Copied from transformers.models.llama.modeling_llama._get_unpad_data
|
||||||
|
def _get_unpad_data(padding_mask):
|
||||||
|
seqlens_in_batch = padding_mask.sum(dim=-1, dtype=torch.int32)
|
||||||
|
indices = torch.nonzero(padding_mask.flatten(), as_tuple=False).flatten()
|
||||||
|
max_seqlen_in_batch = seqlens_in_batch.max().item()
|
||||||
|
cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.torch.int32), (1, 0))
|
||||||
|
return (
|
||||||
|
indices,
|
||||||
|
cu_seqlens,
|
||||||
|
max_seqlen_in_batch,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def _make_sliding_window_causal_mask(
|
def _make_sliding_window_causal_mask(
|
||||||
input_ids_shape: torch.Size,
|
input_ids_shape: torch.Size,
|
||||||
dtype: torch.dtype,
|
dtype: torch.dtype,
|
||||||
@@ -226,6 +254,7 @@ class MistralAttention(nn.Module):
|
|||||||
past_key_value: Optional[Tuple[torch.Tensor]] = None,
|
past_key_value: Optional[Tuple[torch.Tensor]] = None,
|
||||||
output_attentions: bool = False,
|
output_attentions: bool = False,
|
||||||
use_cache: bool = False,
|
use_cache: bool = False,
|
||||||
|
padding_mask: Optional[torch.Tensor] = None,
|
||||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
||||||
bsz, q_len, _ = hidden_states.size()
|
bsz, q_len, _ = hidden_states.size()
|
||||||
|
|
||||||
@@ -291,11 +320,271 @@ class MistralAttention(nn.Module):
|
|||||||
return attn_output, attn_weights, past_key_value
|
return attn_output, attn_weights, past_key_value
|
||||||
|
|
||||||
|
|
||||||
|
class MistralFlashAttention2(MistralAttention):
|
||||||
|
"""
|
||||||
|
Mistral flash attention module. This module inherits from `MistralAttention` as the weights of the module stays
|
||||||
|
untouched. The only required change would be on the forward pass where it needs to correctly call the public API of
|
||||||
|
flash attention and deal with padding tokens in case the input contains any of them.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
hidden_states: torch.Tensor,
|
||||||
|
attention_mask: Optional[torch.Tensor] = None,
|
||||||
|
position_ids: Optional[torch.LongTensor] = None,
|
||||||
|
past_key_value: Optional[Tuple[torch.Tensor]] = None,
|
||||||
|
output_attentions: bool = False,
|
||||||
|
use_cache: bool = False,
|
||||||
|
padding_mask: Optional[torch.LongTensor] = None,
|
||||||
|
):
|
||||||
|
bsz, q_len, _ = hidden_states.size()
|
||||||
|
|
||||||
|
query_states = self.q_proj(hidden_states)
|
||||||
|
key_states = self.k_proj(hidden_states)
|
||||||
|
value_states = self.v_proj(hidden_states)
|
||||||
|
|
||||||
|
query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
|
||||||
|
key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
|
||||||
|
value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
|
||||||
|
|
||||||
|
kv_seq_len = key_states.shape[-2]
|
||||||
|
if past_key_value is not None:
|
||||||
|
kv_seq_len += past_key_value[0].shape[-2]
|
||||||
|
|
||||||
|
# Because the input can be padded, the absolute sequence length depends on the max position id.
|
||||||
|
rotary_seq_len = max(kv_seq_len, position_ids[:, -1].max().item()) + 1
|
||||||
|
cos, sin = self.rotary_emb(value_states, seq_len=rotary_seq_len)
|
||||||
|
|
||||||
|
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
|
||||||
|
|
||||||
|
use_sliding_windows = (
|
||||||
|
_flash_supports_window_size
|
||||||
|
and hasattr(self.config, "sliding_window") is not None
|
||||||
|
and kv_seq_len > self.config.sliding_window
|
||||||
|
)
|
||||||
|
|
||||||
|
if not _flash_supports_window_size:
|
||||||
|
logger.warning_once(
|
||||||
|
"The current flash attention version does not support sliding window attention, for a more memory efficient implementation"
|
||||||
|
" make sure to upgrade flash-attn library."
|
||||||
|
)
|
||||||
|
|
||||||
|
if past_key_value is not None:
|
||||||
|
# Activate slicing cache only if the config has a value `sliding_windows` attribute
|
||||||
|
if hasattr(self.config, "sliding_window") and kv_seq_len > self.config.sliding_window:
|
||||||
|
slicing_tokens = kv_seq_len - self.config.sliding_window
|
||||||
|
|
||||||
|
past_key = past_key_value[0]
|
||||||
|
past_value = past_key_value[1]
|
||||||
|
|
||||||
|
past_key = past_key[:, :, slicing_tokens:, :].contiguous()
|
||||||
|
past_value = past_value[:, :, slicing_tokens:, :].contiguous()
|
||||||
|
|
||||||
|
if past_key.shape[-2] != self.config.sliding_window - 1:
|
||||||
|
raise ValueError(
|
||||||
|
f"past key much have a shape of (`batch_size, num_heads, self.config.sliding_window-1, head_dim`), got"
|
||||||
|
f" {past_key.shape}"
|
||||||
|
)
|
||||||
|
|
||||||
|
past_key_value = (past_key, past_value)
|
||||||
|
|
||||||
|
if padding_mask is not None:
|
||||||
|
padding_mask = padding_mask[:, slicing_tokens:]
|
||||||
|
padding_mask = torch.cat([padding_mask, torch.ones_like(padding_mask[:, -1:])], dim=-1)
|
||||||
|
|
||||||
|
key_states = torch.cat([past_key_value[0], key_states], dim=2)
|
||||||
|
value_states = torch.cat([past_key_value[1], value_states], dim=2)
|
||||||
|
|
||||||
|
past_key_value = (key_states, value_states) if use_cache else None
|
||||||
|
|
||||||
|
# repeat k/v heads if n_kv_heads < n_heads
|
||||||
|
key_states = repeat_kv(key_states, self.num_key_value_groups)
|
||||||
|
value_states = repeat_kv(value_states, self.num_key_value_groups)
|
||||||
|
|
||||||
|
# TODO: Mistral does not have dropout in the config??
|
||||||
|
# It is recommended to use dropout with FA according to the docs
|
||||||
|
# when training.
|
||||||
|
dropout_rate = 0.0 # if not self.training else self.attn_dropout
|
||||||
|
|
||||||
|
# In PEFT, usually we cast the layer norms in float32 for training stability reasons
|
||||||
|
# therefore the input hidden states gets silently casted in float32. Hence, we need
|
||||||
|
# cast them back in float16 just to be sure everything works as expected.
|
||||||
|
input_dtype = query_states.dtype
|
||||||
|
if input_dtype == torch.float32:
|
||||||
|
logger.warning_once(
|
||||||
|
"The input hidden states seems to be silently casted in float32, this might be related to"
|
||||||
|
" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
|
||||||
|
" float16."
|
||||||
|
)
|
||||||
|
|
||||||
|
query_states = query_states.to(torch.float16)
|
||||||
|
key_states = key_states.to(torch.float16)
|
||||||
|
value_states = value_states.to(torch.float16)
|
||||||
|
|
||||||
|
# Reashape to the expected shape for Flash Attention
|
||||||
|
query_states = query_states.transpose(1, 2)
|
||||||
|
key_states = key_states.transpose(1, 2)
|
||||||
|
value_states = value_states.transpose(1, 2)
|
||||||
|
|
||||||
|
attn_output = self._flash_attention_forward(
|
||||||
|
query_states,
|
||||||
|
key_states,
|
||||||
|
value_states,
|
||||||
|
padding_mask,
|
||||||
|
q_len,
|
||||||
|
dropout=dropout_rate,
|
||||||
|
use_sliding_windows=use_sliding_windows,
|
||||||
|
)
|
||||||
|
|
||||||
|
attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous()
|
||||||
|
attn_output = self.o_proj(attn_output)
|
||||||
|
|
||||||
|
if not output_attentions:
|
||||||
|
attn_weights = None
|
||||||
|
|
||||||
|
return attn_output, attn_weights, past_key_value
|
||||||
|
|
||||||
|
def _flash_attention_forward(
|
||||||
|
self,
|
||||||
|
query_states,
|
||||||
|
key_states,
|
||||||
|
value_states,
|
||||||
|
padding_mask,
|
||||||
|
query_length,
|
||||||
|
dropout=0.0,
|
||||||
|
softmax_scale=None,
|
||||||
|
use_sliding_windows=False,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token
|
||||||
|
first unpad the input, then computes the attention scores and pad the final attention scores.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
query_states (`torch.Tensor`):
|
||||||
|
Input query states to be passed to Flash Attention API
|
||||||
|
key_states (`torch.Tensor`):
|
||||||
|
Input key states to be passed to Flash Attention API
|
||||||
|
value_states (`torch.Tensor`):
|
||||||
|
Input value states to be passed to Flash Attention API
|
||||||
|
padding_mask (`torch.Tensor`):
|
||||||
|
The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the
|
||||||
|
position of padding tokens and 1 for the position of non-padding tokens.
|
||||||
|
dropout (`int`, *optional*):
|
||||||
|
Attention dropout
|
||||||
|
softmax_scale (`float`, *optional*):
|
||||||
|
The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim)
|
||||||
|
use_sliding_windows (`bool`, *optional*):
|
||||||
|
Whether to activate sliding window attention.
|
||||||
|
"""
|
||||||
|
# Contains at least one padding token in the sequence
|
||||||
|
if padding_mask is not None:
|
||||||
|
batch_size = query_states.shape[0]
|
||||||
|
query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = self._upad_input(
|
||||||
|
query_states, key_states, value_states, padding_mask, query_length
|
||||||
|
)
|
||||||
|
|
||||||
|
cu_seqlens_q, cu_seqlens_k = cu_seq_lens
|
||||||
|
max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens
|
||||||
|
|
||||||
|
if not use_sliding_windows:
|
||||||
|
attn_output_unpad = flash_attn_varlen_func(
|
||||||
|
query_states,
|
||||||
|
key_states,
|
||||||
|
value_states,
|
||||||
|
cu_seqlens_q=cu_seqlens_q,
|
||||||
|
cu_seqlens_k=cu_seqlens_k,
|
||||||
|
max_seqlen_q=max_seqlen_in_batch_q,
|
||||||
|
max_seqlen_k=max_seqlen_in_batch_k,
|
||||||
|
dropout_p=dropout,
|
||||||
|
softmax_scale=softmax_scale,
|
||||||
|
causal=True,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
attn_output_unpad = flash_attn_varlen_func(
|
||||||
|
query_states,
|
||||||
|
key_states,
|
||||||
|
value_states,
|
||||||
|
cu_seqlens_q=cu_seqlens_q,
|
||||||
|
cu_seqlens_k=cu_seqlens_k,
|
||||||
|
max_seqlen_q=max_seqlen_in_batch_q,
|
||||||
|
max_seqlen_k=max_seqlen_in_batch_k,
|
||||||
|
dropout_p=dropout,
|
||||||
|
softmax_scale=softmax_scale,
|
||||||
|
causal=True,
|
||||||
|
window_size=(self.config.sliding_window, self.config.sliding_window),
|
||||||
|
)
|
||||||
|
|
||||||
|
attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length)
|
||||||
|
else:
|
||||||
|
if not use_sliding_windows:
|
||||||
|
attn_output = flash_attn_func(
|
||||||
|
query_states, key_states, value_states, dropout, softmax_scale=softmax_scale, causal=True
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
attn_output = flash_attn_func(
|
||||||
|
query_states,
|
||||||
|
key_states,
|
||||||
|
value_states,
|
||||||
|
dropout,
|
||||||
|
softmax_scale=softmax_scale,
|
||||||
|
causal=True,
|
||||||
|
window_size=(self.config.sliding_window, self.config.sliding_window),
|
||||||
|
)
|
||||||
|
|
||||||
|
return attn_output
|
||||||
|
|
||||||
|
def _upad_input(self, query_layer, key_layer, value_layer, padding_mask, query_length):
|
||||||
|
batch_size, kv_seq_len, num_heads, head_dim = key_layer.shape
|
||||||
|
|
||||||
|
# On the first iteration we need to properly re-create the padding mask
|
||||||
|
# by slicing it on the proper place
|
||||||
|
if kv_seq_len != padding_mask.shape[-1]:
|
||||||
|
padding_mask_num_tokens = padding_mask.shape[-1]
|
||||||
|
padding_mask = padding_mask[:, padding_mask_num_tokens - kv_seq_len :]
|
||||||
|
|
||||||
|
indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(padding_mask)
|
||||||
|
|
||||||
|
key_layer = index_first_axis(key_layer.reshape(batch_size * kv_seq_len, num_heads, head_dim), indices_k)
|
||||||
|
value_layer = index_first_axis(value_layer.reshape(batch_size * kv_seq_len, num_heads, head_dim), indices_k)
|
||||||
|
|
||||||
|
if query_length == kv_seq_len:
|
||||||
|
query_layer = index_first_axis(
|
||||||
|
query_layer.reshape(batch_size * kv_seq_len, num_heads, head_dim), indices_k
|
||||||
|
)
|
||||||
|
cu_seqlens_q = cu_seqlens_k
|
||||||
|
max_seqlen_in_batch_q = max_seqlen_in_batch_k
|
||||||
|
indices_q = indices_k
|
||||||
|
elif query_length == 1:
|
||||||
|
max_seqlen_in_batch_q = 1
|
||||||
|
cu_seqlens_q = torch.arange(
|
||||||
|
batch_size + 1, dtype=torch.int32, device=query_layer.device
|
||||||
|
) # There is a memcpy here, that is very bad.
|
||||||
|
indices_q = cu_seqlens_q[:-1]
|
||||||
|
query_layer = query_layer.squeeze(1)
|
||||||
|
else:
|
||||||
|
# The -q_len: slice assumes left padding.
|
||||||
|
padding_mask = padding_mask[:, -query_length:]
|
||||||
|
query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(query_layer, padding_mask)
|
||||||
|
|
||||||
|
return (
|
||||||
|
query_layer,
|
||||||
|
key_layer,
|
||||||
|
value_layer,
|
||||||
|
indices_q,
|
||||||
|
(cu_seqlens_q, cu_seqlens_k),
|
||||||
|
(max_seqlen_in_batch_q, max_seqlen_in_batch_k),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class MistralDecoderLayer(nn.Module):
|
class MistralDecoderLayer(nn.Module):
|
||||||
def __init__(self, config: MistralConfig):
|
def __init__(self, config: MistralConfig):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.hidden_size = config.hidden_size
|
self.hidden_size = config.hidden_size
|
||||||
self.self_attn = MistralAttention(config=config)
|
self.self_attn = (
|
||||||
|
MistralAttention(config=config)
|
||||||
|
if not getattr(config, "_flash_attn_2_enabled", False)
|
||||||
|
else MistralFlashAttention2(config)
|
||||||
|
)
|
||||||
self.mlp = MistralMLP(config)
|
self.mlp = MistralMLP(config)
|
||||||
self.input_layernorm = MistralRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
self.input_layernorm = MistralRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||||
self.post_attention_layernorm = MistralRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
self.post_attention_layernorm = MistralRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||||
@@ -308,6 +597,7 @@ class MistralDecoderLayer(nn.Module):
|
|||||||
past_key_value: Optional[Tuple[torch.Tensor]] = None,
|
past_key_value: Optional[Tuple[torch.Tensor]] = None,
|
||||||
output_attentions: Optional[bool] = False,
|
output_attentions: Optional[bool] = False,
|
||||||
use_cache: Optional[bool] = False,
|
use_cache: Optional[bool] = False,
|
||||||
|
padding_mask: Optional[torch.Tensor] = None,
|
||||||
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
|
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
|
||||||
"""
|
"""
|
||||||
Args:
|
Args:
|
||||||
@@ -335,6 +625,7 @@ class MistralDecoderLayer(nn.Module):
|
|||||||
past_key_value=past_key_value,
|
past_key_value=past_key_value,
|
||||||
output_attentions=output_attentions,
|
output_attentions=output_attentions,
|
||||||
use_cache=use_cache,
|
use_cache=use_cache,
|
||||||
|
padding_mask=padding_mask,
|
||||||
)
|
)
|
||||||
hidden_states = residual + hidden_states
|
hidden_states = residual + hidden_states
|
||||||
|
|
||||||
@@ -382,6 +673,7 @@ class MistralPreTrainedModel(PreTrainedModel):
|
|||||||
supports_gradient_checkpointing = True
|
supports_gradient_checkpointing = True
|
||||||
_no_split_modules = ["MistralDecoderLayer"]
|
_no_split_modules = ["MistralDecoderLayer"]
|
||||||
_skip_keys_device_placement = "past_key_values"
|
_skip_keys_device_placement = "past_key_values"
|
||||||
|
_supports_flash_attn_2 = True
|
||||||
|
|
||||||
def _init_weights(self, module):
|
def _init_weights(self, module):
|
||||||
std = self.config.initializer_range
|
std = self.config.initializer_range
|
||||||
@@ -569,11 +861,30 @@ class MistralModel(MistralPreTrainedModel):
|
|||||||
|
|
||||||
if inputs_embeds is None:
|
if inputs_embeds is None:
|
||||||
inputs_embeds = self.embed_tokens(input_ids)
|
inputs_embeds = self.embed_tokens(input_ids)
|
||||||
|
|
||||||
|
padding_mask = None
|
||||||
|
|
||||||
# embed positions
|
# embed positions
|
||||||
if attention_mask is None:
|
if attention_mask is None:
|
||||||
attention_mask = torch.ones(
|
attention_mask = torch.ones(
|
||||||
(batch_size, seq_length_with_past), dtype=torch.bool, device=inputs_embeds.device
|
(batch_size, seq_length_with_past), dtype=torch.bool, device=inputs_embeds.device
|
||||||
)
|
)
|
||||||
|
elif 0 in attention_mask:
|
||||||
|
padding_mask = attention_mask
|
||||||
|
|
||||||
|
if (
|
||||||
|
padding_mask is not None
|
||||||
|
and hasattr(self.config, "_flash_attn_2_enabled")
|
||||||
|
and self.config._flash_attn_2_enabled
|
||||||
|
):
|
||||||
|
is_padding_right = padding_mask[:, -1].sum().item() != batch_size
|
||||||
|
if is_padding_right:
|
||||||
|
raise ValueError(
|
||||||
|
"You are attempting to perform batched generation with padding_side='right'"
|
||||||
|
" this may lead to unexpected behaviour for Flash Attention version of Mistral. Make sure to "
|
||||||
|
" call `tokenizer.padding_side = 'left'` before tokenizing the input. "
|
||||||
|
)
|
||||||
|
|
||||||
attention_mask = self._prepare_decoder_attention_mask(
|
attention_mask = self._prepare_decoder_attention_mask(
|
||||||
attention_mask,
|
attention_mask,
|
||||||
(batch_size, seq_length),
|
(batch_size, seq_length),
|
||||||
@@ -607,7 +918,7 @@ class MistralModel(MistralPreTrainedModel):
|
|||||||
def create_custom_forward(module):
|
def create_custom_forward(module):
|
||||||
def custom_forward(*inputs):
|
def custom_forward(*inputs):
|
||||||
# None for past_key_value
|
# None for past_key_value
|
||||||
return module(*inputs, past_key_value, output_attentions)
|
return module(*inputs, past_key_value, output_attentions, padding_mask=padding_mask)
|
||||||
|
|
||||||
return custom_forward
|
return custom_forward
|
||||||
|
|
||||||
@@ -625,6 +936,7 @@ class MistralModel(MistralPreTrainedModel):
|
|||||||
past_key_value=past_key_value,
|
past_key_value=past_key_value,
|
||||||
output_attentions=output_attentions,
|
output_attentions=output_attentions,
|
||||||
use_cache=use_cache,
|
use_cache=use_cache,
|
||||||
|
padding_mask=padding_mask,
|
||||||
)
|
)
|
||||||
|
|
||||||
hidden_states = layer_outputs[0]
|
hidden_states = layer_outputs[0]
|
||||||
|
|||||||
@@ -15,10 +15,13 @@
|
|||||||
""" Testing suite for the PyTorch Mistral model. """
|
""" Testing suite for the PyTorch Mistral model. """
|
||||||
|
|
||||||
|
|
||||||
|
import tempfile
|
||||||
import unittest
|
import unittest
|
||||||
|
|
||||||
|
from pytest import mark
|
||||||
|
|
||||||
from transformers import AutoTokenizer, MistralConfig, is_torch_available
|
from transformers import AutoTokenizer, MistralConfig, is_torch_available
|
||||||
from transformers.testing_utils import require_torch, slow, torch_device
|
from transformers.testing_utils import require_flash_attn, require_torch, require_torch_gpu, slow, torch_device
|
||||||
|
|
||||||
from ...generation.test_utils import GenerationTesterMixin
|
from ...generation.test_utils import GenerationTesterMixin
|
||||||
from ...test_configuration_common import ConfigTester
|
from ...test_configuration_common import ConfigTester
|
||||||
@@ -351,6 +354,75 @@ class MistralModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMi
|
|||||||
def test_past_key_values_format(self):
|
def test_past_key_values_format(self):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
@require_flash_attn
|
||||||
|
@require_torch_gpu
|
||||||
|
@mark.flash_attn_test
|
||||||
|
@slow
|
||||||
|
def test_flash_attn_2_generate_padding_right(self):
|
||||||
|
import torch
|
||||||
|
|
||||||
|
for model_class in self.all_generative_model_classes:
|
||||||
|
if not model_class._supports_flash_attn_2:
|
||||||
|
return
|
||||||
|
|
||||||
|
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
|
||||||
|
model = model_class(config)
|
||||||
|
|
||||||
|
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||||
|
model.save_pretrained(tmpdirname)
|
||||||
|
model = model_class.from_pretrained(
|
||||||
|
tmpdirname, torch_dtype=torch.float16, use_flash_attention_2=False, low_cpu_mem_usage=True
|
||||||
|
).to(torch_device)
|
||||||
|
|
||||||
|
dummy_input = torch.LongTensor([[0, 2, 3, 4], [0, 2, 3, 4]]).to(torch_device)
|
||||||
|
dummy_attention_mask = torch.LongTensor([[1, 1, 1, 1], [1, 1, 1, 0]]).to(torch_device)
|
||||||
|
|
||||||
|
model.generate(dummy_input, attention_mask=dummy_attention_mask, max_new_tokens=1, do_sample=False)
|
||||||
|
|
||||||
|
model = model_class.from_pretrained(
|
||||||
|
tmpdirname, torch_dtype=torch.float16, use_flash_attention_2=True, low_cpu_mem_usage=True
|
||||||
|
).to(torch_device)
|
||||||
|
|
||||||
|
with self.assertRaises(ValueError):
|
||||||
|
_ = model.generate(
|
||||||
|
dummy_input, attention_mask=dummy_attention_mask, max_new_tokens=1, do_sample=False
|
||||||
|
)
|
||||||
|
|
||||||
|
@require_flash_attn
|
||||||
|
@require_torch_gpu
|
||||||
|
@mark.flash_attn_test
|
||||||
|
@slow
|
||||||
|
def test_flash_attn_2_inference_padding_right(self):
|
||||||
|
import torch
|
||||||
|
|
||||||
|
for model_class in self.all_model_classes:
|
||||||
|
if not model_class._supports_flash_attn_2:
|
||||||
|
return
|
||||||
|
|
||||||
|
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
|
||||||
|
model = model_class(config)
|
||||||
|
|
||||||
|
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||||
|
model.save_pretrained(tmpdirname)
|
||||||
|
model_fa = model_class.from_pretrained(
|
||||||
|
tmpdirname, torch_dtype=torch.bfloat16, use_flash_attention_2=True
|
||||||
|
)
|
||||||
|
model_fa.to(torch_device)
|
||||||
|
|
||||||
|
model = model_class.from_pretrained(
|
||||||
|
tmpdirname, torch_dtype=torch.bfloat16, use_flash_attention_2=False
|
||||||
|
)
|
||||||
|
model.to(torch_device)
|
||||||
|
|
||||||
|
dummy_input = torch.LongTensor([[1, 2, 3, 4, 5]]).to(torch_device)
|
||||||
|
dummy_attention_mask = torch.LongTensor([[1, 1, 1, 1, 0]]).to(torch_device)
|
||||||
|
|
||||||
|
_ = model(dummy_input, output_hidden_states=True).hidden_states[-1]
|
||||||
|
with self.assertRaises(ValueError):
|
||||||
|
_ = model_fa(
|
||||||
|
dummy_input, attention_mask=dummy_attention_mask, output_hidden_states=True
|
||||||
|
).hidden_states[-1]
|
||||||
|
|
||||||
|
|
||||||
@require_torch
|
@require_torch
|
||||||
class MistralIntegrationTest(unittest.TestCase):
|
class MistralIntegrationTest(unittest.TestCase):
|
||||||
|
|||||||
@@ -2926,7 +2926,7 @@ class ModelTesterMixin:
|
|||||||
model.save_pretrained(tmpdirname)
|
model.save_pretrained(tmpdirname)
|
||||||
|
|
||||||
dummy_input = torch.LongTensor([[0, 2, 3, 4], [0, 2, 3, 4]]).to(torch_device)
|
dummy_input = torch.LongTensor([[0, 2, 3, 4], [0, 2, 3, 4]]).to(torch_device)
|
||||||
dummy_attention_mask = torch.LongTensor([[1, 1, 1, 1], [1, 1, 1, 0]]).to(torch_device)
|
dummy_attention_mask = torch.LongTensor([[1, 1, 1, 1], [0, 1, 1, 1]]).to(torch_device)
|
||||||
|
|
||||||
model = model_class.from_pretrained(
|
model = model_class.from_pretrained(
|
||||||
tmpdirname, torch_dtype=torch.float16, use_flash_attention_2=True, low_cpu_mem_usage=True
|
tmpdirname, torch_dtype=torch.float16, use_flash_attention_2=True, low_cpu_mem_usage=True
|
||||||
|
|||||||
Reference in New Issue
Block a user