Fix FA2 integration (#28142)
* fix fa2 * fix FA2 for popular models * improve warning and add Younes as co-author Co-Authored-By: Younes Belkada <49240599+younesbelkada@users.noreply.github.com> * Update src/transformers/modeling_utils.py Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com> * fix the warning * Add Tip * typo fix * nit --------- Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com> Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
This commit is contained in:
committed by
GitHub
parent
b134f6857e
commit
def581ef51
@@ -67,6 +67,8 @@ come in several checkpoints they each contain a part of each weight of the model
|
|||||||
|
|
||||||
- The LLaMA tokenizer is a BPE model based on [sentencepiece](https://github.com/google/sentencepiece). One quirk of sentencepiece is that when decoding a sequence, if the first token is the start of the word (e.g. "Banana"), the tokenizer does not prepend the prefix space to the string.
|
- The LLaMA tokenizer is a BPE model based on [sentencepiece](https://github.com/google/sentencepiece). One quirk of sentencepiece is that when decoding a sequence, if the first token is the start of the word (e.g. "Banana"), the tokenizer does not prepend the prefix space to the string.
|
||||||
|
|
||||||
|
- When using Flash Attention 2 via `attn_implementation="flash_attention_2"`, don't pass `torch_dtype` to the `from_pretrained` class method and use Automatic Mixed-Precision training. When using `Trainer`, it is simply specifying either `fp16` or `bf16` to `True`. Otherwise, make sure you are using `torch.autocast`. This is required because the Flash Attention only support `fp16` and `bf16` data type.
|
||||||
|
|
||||||
|
|
||||||
## Resources
|
## Resources
|
||||||
|
|
||||||
|
|||||||
@@ -1419,9 +1419,9 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
|||||||
"You are attempting to use Flash Attention 2.0 without specifying a torch dtype. This might lead to unexpected behaviour"
|
"You are attempting to use Flash Attention 2.0 without specifying a torch dtype. This might lead to unexpected behaviour"
|
||||||
)
|
)
|
||||||
elif torch_dtype is not None and torch_dtype not in [torch.float16, torch.bfloat16]:
|
elif torch_dtype is not None and torch_dtype not in [torch.float16, torch.bfloat16]:
|
||||||
raise ValueError(
|
logger.warning(
|
||||||
f"Flash Attention 2.0 only supports torch.float16 and torch.bfloat16 dtypes. You passed {torch_dtype}, this might lead to"
|
"Flash Attention 2.0 only supports torch.float16 and torch.bfloat16 dtypes. "
|
||||||
" unexpected behaviour."
|
"No dtype was provided, you should run training or inference using Automatic Mixed-Precision via the `with torch.autocast(device_type='torch_device'):` decorator."
|
||||||
)
|
)
|
||||||
|
|
||||||
# The check `torch.empty(0).device.type != "cuda"` is needed as the model may be initialized after `torch.set_default_device` has been called,
|
# The check `torch.empty(0).device.type != "cuda"` is needed as the model may be initialized after `torch.set_default_device` has been called,
|
||||||
|
|||||||
@@ -620,6 +620,8 @@ class FalconFlashAttention2(FalconAttention):
|
|||||||
# Handle the case where the model is quantized
|
# Handle the case where the model is quantized
|
||||||
if hasattr(self.config, "_pre_quantization_dtype"):
|
if hasattr(self.config, "_pre_quantization_dtype"):
|
||||||
target_dtype = self.config._pre_quantization_dtype
|
target_dtype = self.config._pre_quantization_dtype
|
||||||
|
elif torch.is_autocast_enabled():
|
||||||
|
target_dtype = torch.get_autocast_gpu_dtype()
|
||||||
else:
|
else:
|
||||||
target_dtype = self.query_key_value.weight.dtype
|
target_dtype = self.query_key_value.weight.dtype
|
||||||
|
|
||||||
|
|||||||
@@ -378,6 +378,8 @@ class GPTBigCodeFlashAttention2(GPTBigCodeAttention):
|
|||||||
# Handle the case where the model is quantized
|
# Handle the case where the model is quantized
|
||||||
if hasattr(self.config, "_pre_quantization_dtype"):
|
if hasattr(self.config, "_pre_quantization_dtype"):
|
||||||
target_dtype = self.config._pre_quantization_dtype
|
target_dtype = self.config._pre_quantization_dtype
|
||||||
|
elif torch.is_autocast_enabled():
|
||||||
|
target_dtype = torch.get_autocast_gpu_dtype()
|
||||||
else:
|
else:
|
||||||
target_dtype = self.c_attn.weight.dtype
|
target_dtype = self.c_attn.weight.dtype
|
||||||
|
|
||||||
|
|||||||
@@ -531,6 +531,8 @@ class LlamaFlashAttention2(LlamaAttention):
|
|||||||
# Handle the case where the model is quantized
|
# Handle the case where the model is quantized
|
||||||
if hasattr(self.config, "_pre_quantization_dtype"):
|
if hasattr(self.config, "_pre_quantization_dtype"):
|
||||||
target_dtype = self.config._pre_quantization_dtype
|
target_dtype = self.config._pre_quantization_dtype
|
||||||
|
elif torch.is_autocast_enabled():
|
||||||
|
target_dtype = torch.get_autocast_gpu_dtype()
|
||||||
else:
|
else:
|
||||||
target_dtype = self.q_proj.weight.dtype
|
target_dtype = self.q_proj.weight.dtype
|
||||||
|
|
||||||
|
|||||||
@@ -431,6 +431,8 @@ class MistralFlashAttention2(MistralAttention):
|
|||||||
# Handle the case where the model is quantized
|
# Handle the case where the model is quantized
|
||||||
if hasattr(self.config, "_pre_quantization_dtype"):
|
if hasattr(self.config, "_pre_quantization_dtype"):
|
||||||
target_dtype = self.config._pre_quantization_dtype
|
target_dtype = self.config._pre_quantization_dtype
|
||||||
|
elif torch.is_autocast_enabled():
|
||||||
|
target_dtype = torch.get_autocast_gpu_dtype()
|
||||||
else:
|
else:
|
||||||
target_dtype = self.q_proj.weight.dtype
|
target_dtype = self.q_proj.weight.dtype
|
||||||
|
|
||||||
|
|||||||
@@ -479,6 +479,8 @@ class MixtralFlashAttention2(MixtralAttention):
|
|||||||
# Handle the case where the model is quantized
|
# Handle the case where the model is quantized
|
||||||
if hasattr(self.config, "_pre_quantization_dtype"):
|
if hasattr(self.config, "_pre_quantization_dtype"):
|
||||||
target_dtype = self.config._pre_quantization_dtype
|
target_dtype = self.config._pre_quantization_dtype
|
||||||
|
elif torch.is_autocast_enabled():
|
||||||
|
target_dtype = torch.get_autocast_gpu_dtype()
|
||||||
else:
|
else:
|
||||||
target_dtype = self.q_proj.weight.dtype
|
target_dtype = self.q_proj.weight.dtype
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user