Add Flash Attention 2 to M2M100 model (#30256)
* Added flash attention 2.
* Fixes.
* Fix inheritance.
* Fixed init.
* Remove stuff.
* Added documentation.
* Add FA2 to M2M100 documentation.
* Add test.
* Fixed documentation.
* Update src/transformers/models/m2m_100/modeling_m2m_100.py

Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com>

* Update docs/source/en/model_doc/nllb.md

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* Fixed variable name.

---------

Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com>
Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
Committed by GitHub
parent ec92f983af
commit b65df514d1
@@ -121,3 +121,45 @@ Hindi to French and Chinese to English using the *facebook/m2m100_418M* checkpoi
 
 [[autodoc]] M2M100ForConditionalGeneration
     - forward
+
+## Using Flash Attention 2
+
+Flash Attention 2 is a faster, optimized version of the attention scores computation which relies on `cuda` kernels.
+
+### Installation
+
+First, check whether your hardware is compatible with Flash Attention 2. The latest list of compatible hardware can be found in the [official documentation](https://github.com/Dao-AILab/flash-attention#installation-and-features).
+
+Next, [install](https://github.com/Dao-AILab/flash-attention#installation-and-features) the latest version of Flash Attention 2:
+
+```bash
+pip install -U flash-attn --no-build-isolation
+```
+
+### Usage
+
+To load a model using Flash Attention 2, we can pass the argument `attn_implementation="flash_attention_2"` to [`.from_pretrained`](https://huggingface.co/docs/transformers/main/en/main_classes/model#transformers.PreTrainedModel.from_pretrained). You can use either `torch.float16` or `torch.bfloat16` precision.
+
+```python
+>>> import torch
+>>> from transformers import M2M100ForConditionalGeneration, M2M100Tokenizer
+
+>>> model = M2M100ForConditionalGeneration.from_pretrained("facebook/m2m100_418M", torch_dtype=torch.float16, attn_implementation="flash_attention_2").to("cuda").eval()
+>>> tokenizer = M2M100Tokenizer.from_pretrained("facebook/m2m100_418M")
+
+>>> # translate Hindi to French
+>>> hi_text = "जीवन एक चॉकलेट बॉक्स की तरह है।"
+>>> tokenizer.src_lang = "hi"
+>>> encoded_hi = tokenizer(hi_text, return_tensors="pt").to("cuda")
+>>> generated_tokens = model.generate(**encoded_hi, forced_bos_token_id=tokenizer.get_lang_id("fr"))
+>>> tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)
+"La vie est comme une boîte de chocolat."
+```
+
+### Expected speedups
+
+Below is an expected speedup diagram that compares pure inference time between the native implementation and the Flash Attention 2.
+
+<div style="text-align: center">
+<img src="https://huggingface.co/datasets/visheratin/documentation-images/resolve/main/nllb-speedup.webp">
+</div>
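The usage section added above passes `attn_implementation="flash_attention_2"` unconditionally. A minimal sketch of guarding that choice follows. `pick_attn_implementation` is a hypothetical helper, not part of Transformers, and it assumes that falling back to the `"eager"` implementation is acceptable when the `flash_attn` package or a CUDA device is missing.

```python
import importlib.util
from typing import Optional


def pick_attn_implementation(cuda_available: bool, flash_attn_installed: Optional[bool] = None) -> str:
    """Return a value for the `attn_implementation` argument of `from_pretrained`.

    Hypothetical helper for illustration only; the fallback policy is an assumption.
    """
    if flash_attn_installed is None:
        # Only checks that the package can be located; it is not imported here.
        flash_attn_installed = importlib.util.find_spec("flash_attn") is not None
    if cuda_available and flash_attn_installed:
        return "flash_attention_2"
    return "eager"


print(pick_attn_implementation(cuda_available=False, flash_attn_installed=True))  # eager
```

In practice `cuda_available` would typically come from `torch.cuda.is_available()`, and the returned string would be forwarded to `from_pretrained`. When Flash Attention 2 is selected, the model should also be loaded in `torch.float16` or `torch.bfloat16`, as the documentation above states.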
@@ -145,3 +145,46 @@ UN-Chef sagt, es gibt keine militärische Lösung in Syrien
 ## NllbTokenizerFast
 
 [[autodoc]] NllbTokenizerFast
+
+## Using Flash Attention 2
+
+Flash Attention 2 is a faster, optimized version of the attention scores computation which relies on `cuda` kernels.
+
+### Installation
+
+First, check whether your hardware is compatible with Flash Attention 2. The latest list of compatible hardware can be found in the [official documentation](https://github.com/Dao-AILab/flash-attention#installation-and-features).
+
+Next, [install](https://github.com/Dao-AILab/flash-attention#installation-and-features) the latest version of Flash Attention 2:
+
+```bash
+pip install -U flash-attn --no-build-isolation
+```
+
+### Usage
+
+To load a model using Flash Attention 2, we can pass the argument `attn_implementation="flash_attention_2"` to [`.from_pretrained`](https://huggingface.co/docs/transformers/main/en/main_classes/model#transformers.PreTrainedModel.from_pretrained). You can use either `torch.float16` or `torch.bfloat16` precision.
+
+```python
+>>> import torch
+>>> from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
+
+>>> model = AutoModelForSeq2SeqLM.from_pretrained("facebook/nllb-200-distilled-600M", torch_dtype=torch.float16, attn_implementation="flash_attention_2").to("cuda").eval()
+>>> tokenizer = AutoTokenizer.from_pretrained("facebook/nllb-200-distilled-600M")
+
+>>> article = "Şeful ONU spune că nu există o soluţie militară în Siria"
+>>> inputs = tokenizer(article, return_tensors="pt").to("cuda")
+
+>>> translated_tokens = model.generate(
+...     **inputs, forced_bos_token_id=tokenizer.lang_code_to_id["deu_Latn"], max_length=30
+... )
+>>> tokenizer.batch_decode(translated_tokens, skip_special_tokens=True)[0]
+"UN-Chef sagt, es gibt keine militärische Lösung in Syrien"
+```
+
+### Expected speedups
+
+Below is an expected speedup diagram that compares pure inference time between the native implementation and the Flash Attention 2.
+
+<div style="text-align: center">
+<img src="https://huggingface.co/datasets/visheratin/documentation-images/resolve/main/nllb-speedup.webp">
+</div>
@@ -53,11 +53,13 @@ FlashAttention-2 is currently supported for the following architectures:
 * [Llava](https://huggingface.co/docs/transformers/model_doc/llava)
 * [Llava-NeXT](https://huggingface.co/docs/transformers/model_doc/llava_next)
 * [VipLlava](https://huggingface.co/docs/transformers/model_doc/vipllava)
+* [M2M100](https://huggingface.co/docs/transformers/model_doc/m2m_100)
 * [MBart](https://huggingface.co/docs/transformers/model_doc/mbart#transformers.MBartModel)
 * [Mistral](https://huggingface.co/docs/transformers/model_doc/mistral#transformers.MistralModel)
 * [Mixtral](https://huggingface.co/docs/transformers/model_doc/mixtral#transformers.MixtralModel)
 * [Musicgen](https://huggingface.co/docs/transformers/model_doc/musicgen#transformers.MusicgenModel)
 * [MusicGen Melody](https://huggingface.co/docs/transformers/model_doc/musicgen_melody#transformers.MusicgenMelodyModel)
+* [NLLB](https://huggingface.co/docs/transformers/model_doc/nllb)
 * [OLMo](https://huggingface.co/docs/transformers/model_doc/olmo#transformers.OlmoModel)
 * [OPT](https://huggingface.co/docs/transformers/model_doc/opt#transformers.OPTModel)
 * [Phi](https://huggingface.co/docs/transformers/model_doc/phi#transformers.PhiModel)
@@ -12,13 +12,13 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-""" PyTorch M2M100 model."""
+"""PyTorch M2M100 model."""
 
 
 import math
 from typing import List, Optional, Tuple, Union
 
 import torch
+import torch.nn.functional as F
 from torch import nn
 from torch.nn import CrossEntropyLoss
 
@@ -37,12 +37,19 @@ from ...utils import (
     add_end_docstrings,
     add_start_docstrings,
     add_start_docstrings_to_model_forward,
+    is_flash_attn_2_available,
+    is_flash_attn_greater_or_equal_2_10,
     logging,
     replace_return_docstrings,
 )
 from .configuration_m2m_100 import M2M100Config
 
 
+if is_flash_attn_2_available():
+    from flash_attn import flash_attn_func, flash_attn_varlen_func
+    from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input
+
+
 logger = logging.get_logger(__name__)
 
 _CONFIG_FOR_DOC = "M2M100Config"
@@ -317,6 +324,208 @@ class M2M100Attention(nn.Module):
         return attn_output, attn_weights_reshaped, past_key_value
 
 
+# Copied from transformers.models.llama.modeling_llama._get_unpad_data
+def _get_unpad_data(attention_mask):
+    seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
+    indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
+    max_seqlen_in_batch = seqlens_in_batch.max().item()
+    cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
+    return (
+        indices,
+        cu_seqlens,
+        max_seqlen_in_batch,
+    )
+
+
+class M2M100FlashAttention2(M2M100Attention):
+    def __init__(
+        self,
+        embed_dim: int,
+        num_heads: int,
+        dropout: float = 0.0,
+        is_decoder: bool = False,
+        bias: bool = True,
+        is_causal: bool = False,
+        config: Optional[M2M100Config] = None,
+    ):
+        super().__init__(embed_dim, num_heads, dropout, is_decoder, bias, is_causal, config)
+        self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()
+
+    def _reshape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
+        return tensor.view(bsz, seq_len, self.num_heads, self.head_dim)
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        key_value_states: Optional[torch.Tensor] = None,
+        past_key_value: Optional[Tuple[torch.Tensor]] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        layer_head_mask: Optional[torch.Tensor] = None,
+        output_attentions: bool = False,
+    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+        """Input shape: Batch x Time x Channel"""
+
+        # if key_value_states are provided this layer is used as a cross-attention layer
+        # for the decoder
+        is_cross_attention = key_value_states is not None
+
+        bsz, q_len, _ = hidden_states.size()
+
+        # get query proj
+        query_states = self._reshape(self.q_proj(hidden_states), -1, bsz)
+        # get key, value proj
+        # `past_key_value[0].shape[2] == key_value_states.shape[1]`
+        # is checking that the `sequence_length` of the `past_key_value` is the same as
+        # the provided `key_value_states` to support prefix tuning
+        if (
+            is_cross_attention
+            and past_key_value is not None
+            and past_key_value[0].shape[2] == key_value_states.shape[1]
+        ):
+            # reuse k,v, cross_attentions
+            key_states = past_key_value[0].transpose(1, 2)
+            value_states = past_key_value[1].transpose(1, 2)
+        elif is_cross_attention:
+            # cross_attentions
+            key_states = self._reshape(self.k_proj(key_value_states), -1, bsz)
+            value_states = self._reshape(self.v_proj(key_value_states), -1, bsz)
+        elif past_key_value is not None:
+            # reuse k, v, self_attention
+            key_states = self._reshape(self.k_proj(hidden_states), -1, bsz)
+            value_states = self._reshape(self.v_proj(hidden_states), -1, bsz)
+            key_states = torch.cat([past_key_value[0].transpose(1, 2), key_states], dim=1)
+            value_states = torch.cat([past_key_value[1].transpose(1, 2), value_states], dim=1)
+        else:
+            # self_attention
+            key_states = self._reshape(self.k_proj(hidden_states), -1, bsz)
+            value_states = self._reshape(self.v_proj(hidden_states), -1, bsz)
+
+        if self.is_decoder:
+            # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
+            # Further calls to cross_attention layer can then reuse all cross-attention
+            # key/value_states (first "if" case)
+            # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of
+            # all previous decoder key/value_states. Further calls to uni-directional self-attention
+            # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
+            # if encoder bi-directional self-attention `past_key_value` is always `None`
+            past_key_value = (key_states.transpose(1, 2), value_states.transpose(1, 2))
+
+        kv_seq_len = key_states.shape[-2]
+        if past_key_value is not None:
+            kv_seq_len += past_key_value[0].shape[-2]
+
+        attn_output = self._flash_attention_forward(
+            query_states, key_states, value_states, attention_mask, q_len, dropout=self.dropout, softmax_scale=None
+        )
+
+        # Use the `embed_dim` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be
+        # partitioned across GPUs when using tensor-parallelism.
+        attn_output = attn_output.reshape(bsz, q_len, self.embed_dim)
+
+        attn_output = self.out_proj(attn_output)
+
+        return attn_output, None, past_key_value
+
+    # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2._flash_attention_forward
+    def _flash_attention_forward(
+        self, query_states, key_states, value_states, attention_mask, query_length, dropout=0.0, softmax_scale=None
+    ):
+        """
+        Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token
+        first unpad the input, then computes the attention scores and pad the final attention scores.
+
+        Args:
+            query_states (`torch.Tensor`):
+                Input query states to be passed to Flash Attention API
+            key_states (`torch.Tensor`):
+                Input key states to be passed to Flash Attention API
+            value_states (`torch.Tensor`):
+                Input value states to be passed to Flash Attention API
+            attention_mask (`torch.Tensor`):
+                The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the
+                position of padding tokens and 1 for the position of non-padding tokens.
+            dropout (`float`):
+                Attention dropout
+            softmax_scale (`float`, *optional*):
+                The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim)
+        """
+        if not self._flash_attn_uses_top_left_mask:
+            causal = self.is_causal
+        else:
+            # TODO: Remove the `query_length != 1` check once Flash Attention for RoCm is bumped to 2.1. For details, please see the comment in LlamaFlashAttention2 __init__.
+            causal = self.is_causal and query_length != 1
+
+        # Contains at least one padding token in the sequence
+        if attention_mask is not None:
+            batch_size = query_states.shape[0]
+            query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = self._upad_input(
+                query_states, key_states, value_states, attention_mask, query_length
+            )
+
+            cu_seqlens_q, cu_seqlens_k = cu_seq_lens
+            max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens
+
+            attn_output_unpad = flash_attn_varlen_func(
+                query_states,
+                key_states,
+                value_states,
+                cu_seqlens_q=cu_seqlens_q,
+                cu_seqlens_k=cu_seqlens_k,
+                max_seqlen_q=max_seqlen_in_batch_q,
+                max_seqlen_k=max_seqlen_in_batch_k,
+                dropout_p=dropout,
+                softmax_scale=softmax_scale,
+                causal=causal,
+            )
+
+            attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length)
+        else:
+            attn_output = flash_attn_func(
+                query_states, key_states, value_states, dropout, softmax_scale=softmax_scale, causal=causal
+            )
+
+        return attn_output
+
+    # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2._upad_input
+    def _upad_input(self, query_layer, key_layer, value_layer, attention_mask, query_length):
+        indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask)
+        batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape
+
+        key_layer = index_first_axis(
+            key_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k
+        )
+        value_layer = index_first_axis(
+            value_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k
+        )
+        if query_length == kv_seq_len:
+            query_layer = index_first_axis(
+                query_layer.reshape(batch_size * kv_seq_len, self.num_heads, head_dim), indices_k
+            )
+            cu_seqlens_q = cu_seqlens_k
+            max_seqlen_in_batch_q = max_seqlen_in_batch_k
+            indices_q = indices_k
+        elif query_length == 1:
+            max_seqlen_in_batch_q = 1
+            cu_seqlens_q = torch.arange(
+                batch_size + 1, dtype=torch.int32, device=query_layer.device
+            )  # There is a memcpy here, that is very bad.
+            indices_q = cu_seqlens_q[:-1]
+            query_layer = query_layer.squeeze(1)
+        else:
+            # The -q_len: slice assumes left padding.
+            attention_mask = attention_mask[:, -query_length:]
+            query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(query_layer, attention_mask)
+
+        return (
+            query_layer,
+            key_layer,
+            value_layer,
+            indices_q,
+            (cu_seqlens_q, cu_seqlens_k),
+            (max_seqlen_in_batch_q, max_seqlen_in_batch_k),
+        )
+
+
 # Copied from transformers.models.mbart.modeling_mbart.MBartEncoderLayer with MBart->M2M100, MBART->M2M100
 class M2M100EncoderLayer(nn.Module):
     def __init__(self, config: M2M100Config):
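The bookkeeping that `_get_unpad_data` performs in the hunk above can be followed without PyTorch. This is a minimal sketch that assumes the padding mask is a plain list of 0/1 rows. `get_unpad_data` is an illustrative re-implementation with Python lists, not the function from the diff.

```python
from itertools import accumulate
from typing import List, Tuple


def get_unpad_data(attention_mask: List[List[int]]) -> Tuple[List[int], List[int], int]:
    # Number of real (non-padding) tokens in each sequence of the batch.
    seqlens_in_batch = [sum(row) for row in attention_mask]
    # Positions of the real tokens in the flattened (batch * seq_len) layout.
    flat = [value for row in attention_mask for value in row]
    indices = [i for i, value in enumerate(flat) if value != 0]
    max_seqlen_in_batch = max(seqlens_in_batch)
    # Cumulative lengths with a leading 0: sequence b occupies the half-open
    # range [cu_seqlens[b], cu_seqlens[b + 1]) once the tokens are packed.
    cu_seqlens = [0] + list(accumulate(seqlens_in_batch))
    return indices, cu_seqlens, max_seqlen_in_batch


mask = [[1, 1, 0], [1, 1, 1]]  # first sequence has one padding token
indices, cu_seqlens, max_len = get_unpad_data(mask)
print(indices, cu_seqlens, max_len)  # [0, 1, 3, 4, 5] [0, 2, 5] 3
```

The three return values correspond to what `_upad_input` forwards to `index_first_axis` and `flash_attn_varlen_func` in the diff: which rows to keep, where each sequence starts in the packed tensor, and the longest sequence length.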
@@ -388,7 +597,10 @@ class M2M100EncoderLayer(nn.Module):
         return outputs
 
 
-M2M100_ATTENTION_CLASSES = {"eager": M2M100Attention}
+M2M100_ATTENTION_CLASSES = {
+    "eager": M2M100Attention,
+    "flash_attention_2": M2M100FlashAttention2,
+}
 
 
 # Copied from transformers.models.mbart.modeling_mbart.MBartDecoderLayer with MBart->M2M100, MBART->M2M100
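The dictionary extended above acts as a registry keyed by implementation name. The lookup itself is not part of this diff; presumably the encoder and decoder layers index it with `config._attn_implementation`. A minimal sketch of that pattern follows, using stand-in classes. `EagerAttention`, `FlashAttention2`, and `build_attention` are hypothetical names introduced only for illustration.

```python
class EagerAttention:
    """Stand-in for M2M100Attention."""

    name = "eager"


class FlashAttention2(EagerAttention):
    """Stand-in for M2M100FlashAttention2; it subclasses the eager class, as in the diff."""

    name = "flash_attention_2"


ATTENTION_CLASSES = {
    "eager": EagerAttention,
    "flash_attention_2": FlashAttention2,
}


def build_attention(attn_implementation: str) -> EagerAttention:
    # Resolve the class from the registry, failing loudly on unknown names.
    try:
        attention_cls = ATTENTION_CLASSES[attn_implementation]
    except KeyError:
        raise ValueError(f"Unsupported attention implementation: {attn_implementation!r}") from None
    return attention_cls()


print(build_attention("flash_attention_2").name)  # flash_attention_2
```

Because the Flash Attention class inherits from the eager one, both share the projection layers and constructor signature, so the calling code does not need to know which variant it received.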
@@ -517,6 +729,7 @@ class M2M100PreTrainedModel(PreTrainedModel):
     base_model_prefix = "model"
     supports_gradient_checkpointing = True
     _no_split_modules = ["M2M100Attention"]
+    _supports_flash_attn_2 = True
 
     def _init_weights(self, module):
         std = self.config.init_std
@@ -687,6 +900,7 @@ class M2M100Encoder(M2M100PreTrainedModel):
         )
         self.layers = nn.ModuleList([M2M100EncoderLayer(config) for _ in range(config.encoder_layers)])
         self.layer_norm = nn.LayerNorm(config.d_model)
+        self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2"
 
         self.gradient_checkpointing = False
         # Initialize weights and apply final processing
@@ -767,6 +981,9 @@ class M2M100Encoder(M2M100PreTrainedModel):
 
         # expand attention_mask
         if attention_mask is not None:
+            if self._use_flash_attention_2:
+                attention_mask = attention_mask if 0 in attention_mask else None
+            else:
                 # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
                 attention_mask = _prepare_4d_attention_mask(attention_mask, inputs_embeds.dtype)
 
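In the Flash Attention 2 branch above, the 2D mask is kept only when it contains at least one 0. Otherwise it is replaced by `None`, which sends `_flash_attention_forward` down the cheaper `flash_attn_func` path instead of the unpad/varlen path. A minimal sketch of that rule with plain Python lists; `mask_for_flash_attention` is a hypothetical name used only here.

```python
from typing import List, Optional


def mask_for_flash_attention(attention_mask: Optional[List[List[int]]]) -> Optional[List[List[int]]]:
    # No mask supplied: nothing to unpad.
    if attention_mask is None:
        return None
    # Plain-list equivalent of the tensor expression `0 in attention_mask`.
    has_padding = any(value == 0 for row in attention_mask for value in row)
    # Keep the mask only if some position is actually padded.
    return attention_mask if has_padding else None


print(mask_for_flash_attention([[1, 1], [1, 1]]))  # None
print(mask_for_flash_attention([[1, 0], [1, 1]]))  # [[1, 0], [1, 1]]
```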
@@ -857,6 +1074,7 @@ class M2M100Decoder(M2M100PreTrainedModel):
             self.padding_idx,
         )
         self.layers = nn.ModuleList([M2M100DecoderLayer(config) for _ in range(config.decoder_layers)])
+        self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2"
         self.layer_norm = nn.LayerNorm(config.d_model)
 
         self.gradient_checkpointing = False
@@ -967,14 +1185,20 @@ class M2M100Decoder(M2M100PreTrainedModel):
         if inputs_embeds is None:
             inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale
 
-        # create causal mask
-        # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
+        if self._use_flash_attention_2:
+            # 2d mask is passed through the layers
+            combined_attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
+        else:
+            # 4d mask is passed through the layers
             combined_attention_mask = _prepare_4d_causal_attention_mask(
                 attention_mask, input_shape, inputs_embeds, past_key_values_length
             )
 
         # expand encoder attention mask
         if encoder_hidden_states is not None and encoder_attention_mask is not None:
+            if self._use_flash_attention_2:
+                encoder_attention_mask = encoder_attention_mask if 0 in encoder_attention_mask else None
+            else:
                 # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
                 encoder_attention_mask = _prepare_4d_attention_mask(
                     encoder_attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]
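With Flash Attention 2 the decoder no longer builds a 4D causal mask. Only the 2D padding mask is passed down, and causality is requested from the kernel through the `causal` flag computed in `_flash_attention_forward`. The sketch below mirrors that branch as a standalone function. `resolve_causal` is a hypothetical name, and the explanation of the top-left case (older flash-attn builds aligning the causal mask to the top-left corner, which would be wrong for a single new query attending to a longer key cache) is my reading of the TODO comment in the diff rather than something the diff states outright.

```python
def resolve_causal(is_causal: bool, query_length: int, uses_top_left_mask: bool) -> bool:
    # Newer kernels: pass the layer's causal setting through unchanged.
    if not uses_top_left_mask:
        return is_causal
    # Older kernels: skip causal masking when decoding one token at a time.
    return is_causal and query_length != 1


# Single-token decoding step on an older kernel: causal masking is switched off.
print(resolve_causal(is_causal=True, query_length=1, uses_top_left_mask=True))  # False
```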
@@ -1102,6 +1326,11 @@ class M2M100Model(M2M100PreTrainedModel):
         self.encoder = M2M100Encoder(config, self.shared)
         self.decoder = M2M100Decoder(config, self.shared)
 
+        if config._attn_implementation == "flash_attention_2":
+            logger.warning_once(
+                "Attention with Flash Attention 2 does not support `layer_head_mask`. If you need this feature, please use standard attention."
+            )
+
         # Initialize weights and apply final processing
         self.post_init()
 
@@ -19,12 +19,16 @@ import copy
 import tempfile
 import unittest
 
+import pytest
+
 from transformers import M2M100Config, is_torch_available
 from transformers.testing_utils import (
+    require_flash_attn,
     require_sentencepiece,
     require_tokenizers,
     require_torch,
     require_torch_fp16,
+    require_torch_gpu,
     slow,
     torch_device,
 )
@@ -412,3 +416,48 @@ class M2M100ModelIntegrationTests(unittest.TestCase):
             hypotheses_batch.tolist(), clean_up_tokenization_spaces=True, skip_special_tokens=True
         )
         assert generated == expected_en
+
+    @require_flash_attn
+    @require_torch_gpu
+    @pytest.mark.flash_attn_test
+    @slow
+    def test_flash_attn_2_seq_to_seq_generation(self):
+        """
+        Overwritting the common test as the test is flaky on tiny models
+        """
+        model = M2M100ForConditionalGeneration.from_pretrained(
+            "facebook/m2m100_418M", attn_implementation="flash_attention_2"
+        ).to(torch_device)
+
+        tokenizer = M2M100Tokenizer.from_pretrained("facebook/m2m100_418M", src_lang="fr", tgt_lang="en")
+
+        src_fr = [
+            "L'affaire NSA souligne l'absence totale de débat sur le renseignement",
+            "Selon moi, il y a deux niveaux de réponse de la part du gouvernement français.",
+            "Lorsque François Hollande téléphone à Barack Obama ou quand le ministre des affaires étrangères Laurent"
+            " Fabius convoque l'ambassadeur des Etats-Unis, ils réagissent à une vraie découverte, qui est celle de"
+            " l'ampleur de la surveillance américaine sur l'ensemble des communications en France.",
+        ]
+
+        # The below article tests that we don't add any hypotheses outside of the top n_beams
+        dct = tokenizer(src_fr, padding=True, return_tensors="pt")
+
+        hypotheses_batch = model.generate(
+            input_ids=dct["input_ids"].to(torch_device),
+            attention_mask=dct["attention_mask"].to(torch_device),
+            num_beams=5,
+            forced_bos_token_id=tokenizer.get_lang_id("en"),
+        )
+
+        expected_en = [
+            "The NSA case highlights the total absence of intelligence debate",
+            "I think there are two levels of response from the French government.",
+            "When François Hollande calls Barack Obama or when Foreign Minister Laurent Fabius calls the U.S."
+            " Ambassador, they respond to a real discovery, which is that of the scale of U.S. surveillance on all"
+            " communications in France.",
+        ]
+
+        generated = tokenizer.batch_decode(
+            hypotheses_batch.tolist(), clean_up_tokenization_spaces=True, skip_special_tokens=True
+        )
+        assert generated == expected_en