[FA-2] Add fa2 support for from_config (#26914)
* add fa2 support for from_config * Update test_modeling_common.py
This commit is contained in:
@@ -1173,14 +1173,20 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
|||||||
Args:
|
Args:
|
||||||
torch_dtype (`torch.dtype`, *optional*):
|
torch_dtype (`torch.dtype`, *optional*):
|
||||||
Override the default `torch.dtype` and load the model under this dtype.
|
Override the default `torch.dtype` and load the model under this dtype.
|
||||||
|
use_flash_attention_2 (`bool`, *optional*):
|
||||||
|
Whether to load the model with Flash Attention 2 modules.
|
||||||
"""
|
"""
|
||||||
torch_dtype = kwargs.pop("torch_dtype", None)
|
torch_dtype = kwargs.pop("torch_dtype", None)
|
||||||
|
use_flash_attention_2 = kwargs.pop("use_flash_attention_2", False)
|
||||||
|
|
||||||
# override default dtype if needed
|
# override default dtype if needed
|
||||||
dtype_orig = None
|
dtype_orig = None
|
||||||
if torch_dtype is not None:
|
if torch_dtype is not None:
|
||||||
dtype_orig = cls._set_default_torch_dtype(torch_dtype)
|
dtype_orig = cls._set_default_torch_dtype(torch_dtype)
|
||||||
|
|
||||||
|
if use_flash_attention_2:
|
||||||
|
config = cls._check_and_enable_flash_attn_2(config, torch_dtype)
|
||||||
|
|
||||||
if is_deepspeed_zero3_enabled():
|
if is_deepspeed_zero3_enabled():
|
||||||
import deepspeed
|
import deepspeed
|
||||||
|
|
||||||
|
|||||||
@@ -33,6 +33,7 @@ from pytest import mark
|
|||||||
import transformers
|
import transformers
|
||||||
from transformers import (
|
from transformers import (
|
||||||
AutoModel,
|
AutoModel,
|
||||||
|
AutoModelForCausalLM,
|
||||||
AutoModelForSequenceClassification,
|
AutoModelForSequenceClassification,
|
||||||
PretrainedConfig,
|
PretrainedConfig,
|
||||||
is_torch_available,
|
is_torch_available,
|
||||||
@@ -3269,6 +3270,53 @@ class ModelTesterMixin:
|
|||||||
# Check models are equal
|
# Check models are equal
|
||||||
self.assertTrue(check_models_equal(flax_model_1, flax_model_2))
|
self.assertTrue(check_models_equal(flax_model_1, flax_model_2))
|
||||||
|
|
||||||
|
@require_flash_attn
|
||||||
|
@require_torch_gpu
|
||||||
|
@mark.flash_attn_test
|
||||||
|
@slow
|
||||||
|
def test_flash_attn_2_from_config(self):
|
||||||
|
import torch
|
||||||
|
|
||||||
|
for model_class in self.all_generative_model_classes:
|
||||||
|
if not model_class._supports_flash_attn_2:
|
||||||
|
return
|
||||||
|
|
||||||
|
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
|
||||||
|
# TODO: to change it in the future with other relevant auto classes
|
||||||
|
fa2_model = AutoModelForCausalLM.from_config(
|
||||||
|
config, use_flash_attention_2=True, torch_dtype=torch.bfloat16
|
||||||
|
).to(torch_device)
|
||||||
|
|
||||||
|
dummy_input = torch.LongTensor([[0, 2, 3, 4], [0, 2, 3, 4]]).to(torch_device)
|
||||||
|
dummy_attention_mask = torch.LongTensor([[1, 1, 1, 1], [0, 1, 1, 1]]).to(torch_device)
|
||||||
|
|
||||||
|
fa2_correctly_converted = False
|
||||||
|
|
||||||
|
for _, module in fa2_model.named_modules():
|
||||||
|
if "FlashAttention" in module.__class__.__name__:
|
||||||
|
fa2_correctly_converted = True
|
||||||
|
break
|
||||||
|
|
||||||
|
self.assertTrue(fa2_correctly_converted)
|
||||||
|
|
||||||
|
_ = fa2_model(input_ids=dummy_input, attention_mask=dummy_attention_mask)
|
||||||
|
|
||||||
|
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||||
|
fa2_model.save_pretrained(tmpdirname)
|
||||||
|
|
||||||
|
model_from_pretrained = AutoModelForCausalLM.from_pretrained(tmpdirname)
|
||||||
|
|
||||||
|
self.assertFalse(getattr(model_from_pretrained.config, "_flash_attn_2_enabled", False))
|
||||||
|
|
||||||
|
fa2_correctly_converted = False
|
||||||
|
|
||||||
|
for _, module in model_from_pretrained.named_modules():
|
||||||
|
if "FlashAttention" in module.__class__.__name__:
|
||||||
|
fa2_correctly_converted = True
|
||||||
|
break
|
||||||
|
|
||||||
|
self.assertFalse(fa2_correctly_converted)
|
||||||
|
|
||||||
|
|
||||||
global_rng = random.Random()
|
global_rng = random.Random()
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user