[ESM] Add flash-attention-2 backend for ESM-2 (#38023)
* Add flash-attention-2 backend for ESM-2 Signed-off-by: Peter St. John <pstjohn@nvidia.com> * update extended_attention_mask for fa2 Signed-off-by: Peter St. John <pstjohn@nvidia.com> * add test_flash_attn_2_equivalence test Signed-off-by: Peter St. John <pstjohn@nvidia.com> --------- Signed-off-by: Peter St. John <pstjohn@nvidia.com>
This commit is contained in:
@@ -13,10 +13,22 @@
|
||||
# limitations under the License.
|
||||
"""Testing suite for the PyTorch ESM model."""
|
||||
|
||||
import tempfile
|
||||
import unittest
|
||||
|
||||
import pytest
|
||||
|
||||
from transformers import EsmConfig, is_torch_available
|
||||
from transformers.testing_utils import TestCasePlus, require_bitsandbytes, require_torch, slow, torch_device
|
||||
from transformers.testing_utils import (
|
||||
TestCasePlus,
|
||||
is_flaky,
|
||||
require_bitsandbytes,
|
||||
require_flash_attn,
|
||||
require_torch,
|
||||
require_torch_gpu,
|
||||
slow,
|
||||
torch_device,
|
||||
)
|
||||
|
||||
from ...test_configuration_common import ConfigTester
|
||||
from ...test_modeling_common import ModelTesterMixin, ids_tensor, random_attention_mask
|
||||
@@ -59,6 +71,7 @@ class EsmModelTester:
|
||||
num_labels=3,
|
||||
num_choices=4,
|
||||
scope=None,
|
||||
position_embedding_type="rotary",
|
||||
):
|
||||
self.parent = parent
|
||||
self.batch_size = batch_size
|
||||
@@ -82,6 +95,7 @@ class EsmModelTester:
|
||||
self.num_labels = num_labels
|
||||
self.num_choices = num_choices
|
||||
self.scope = scope
|
||||
self.position_embedding_type = position_embedding_type
|
||||
|
||||
def prepare_config_and_inputs(self):
|
||||
input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)
|
||||
@@ -116,6 +130,7 @@ class EsmModelTester:
|
||||
max_position_embeddings=self.max_position_embeddings,
|
||||
type_vocab_size=self.type_vocab_size,
|
||||
initializer_range=self.initializer_range,
|
||||
position_embedding_type=self.position_embedding_type,
|
||||
)
|
||||
|
||||
def create_and_check_model(self, config, input_ids, input_mask, sequence_labels, token_labels, choice_labels):
|
||||
@@ -296,6 +311,39 @@ class EsmModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
|
||||
def test_resize_tokens_embeddings(self):
|
||||
pass
|
||||
|
||||
@require_flash_attn
|
||||
@require_torch_gpu
|
||||
@pytest.mark.flash_attn_test
|
||||
@is_flaky()
|
||||
@slow
|
||||
def test_flash_attn_2_equivalence(self):
|
||||
for model_class in self.all_model_classes:
|
||||
if not model_class._supports_flash_attn_2:
|
||||
self.skipTest(reason="Model does not support Flash Attention 2")
|
||||
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
model = model_class(config)
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
model.save_pretrained(tmpdirname)
|
||||
model_fa = model_class.from_pretrained(
|
||||
tmpdirname, torch_dtype=torch.float16, attn_implementation="flash_attention_2"
|
||||
)
|
||||
model_fa.to(torch_device)
|
||||
|
||||
model = model_class.from_pretrained(tmpdirname, torch_dtype=torch.float16, attn_implementation="eager")
|
||||
model.to(torch_device)
|
||||
|
||||
dummy_input = inputs_dict[model_class.main_input_name]
|
||||
dummy_input = dummy_input.to(torch_device)
|
||||
outputs = model(dummy_input, output_hidden_states=True)
|
||||
outputs_fa = model_fa(dummy_input, output_hidden_states=True)
|
||||
|
||||
logits = outputs.hidden_states[-1]
|
||||
logits_fa = outputs_fa.hidden_states[-1]
|
||||
|
||||
torch.testing.assert_close(logits_fa, logits, atol=1e-2, rtol=1e-3)
|
||||
|
||||
|
||||
@slow
|
||||
@require_torch
|
||||
|
||||
Reference in New Issue
Block a user