[ESM] Add flash-attention-2 backend for ESM-2 (#38023)

* Add flash-attention-2 backend for ESM-2

Signed-off-by: Peter St. John <pstjohn@nvidia.com>

* update extended_attention_mask for fa2

Signed-off-by: Peter St. John <pstjohn@nvidia.com>

* add test_flash_attn_2_equivalence test

Signed-off-by: Peter St. John <pstjohn@nvidia.com>

---------

Signed-off-by: Peter St. John <pstjohn@nvidia.com>
This commit is contained in:
Peter St. John
2025-05-16 07:11:56 -06:00
committed by GitHub
parent 7b5e327c6e
commit d69945e5fc
3 changed files with 188 additions and 8 deletions

View File

@@ -13,10 +13,22 @@
# limitations under the License.
"""Testing suite for the PyTorch ESM model."""
import tempfile
import unittest
import pytest
from transformers import EsmConfig, is_torch_available
from transformers.testing_utils import TestCasePlus, require_bitsandbytes, require_torch, slow, torch_device
from transformers.testing_utils import (
TestCasePlus,
is_flaky,
require_bitsandbytes,
require_flash_attn,
require_torch,
require_torch_gpu,
slow,
torch_device,
)
from ...test_configuration_common import ConfigTester
from ...test_modeling_common import ModelTesterMixin, ids_tensor, random_attention_mask
@@ -59,6 +71,7 @@ class EsmModelTester:
num_labels=3,
num_choices=4,
scope=None,
position_embedding_type="rotary",
):
self.parent = parent
self.batch_size = batch_size
@@ -82,6 +95,7 @@ class EsmModelTester:
self.num_labels = num_labels
self.num_choices = num_choices
self.scope = scope
self.position_embedding_type = position_embedding_type
def prepare_config_and_inputs(self):
input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)
@@ -116,6 +130,7 @@ class EsmModelTester:
max_position_embeddings=self.max_position_embeddings,
type_vocab_size=self.type_vocab_size,
initializer_range=self.initializer_range,
position_embedding_type=self.position_embedding_type,
)
def create_and_check_model(self, config, input_ids, input_mask, sequence_labels, token_labels, choice_labels):
@@ -296,6 +311,39 @@ class EsmModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
def test_resize_tokens_embeddings(self):
pass
@require_flash_attn
@require_torch_gpu
@pytest.mark.flash_attn_test
@is_flaky()
@slow
def test_flash_attn_2_equivalence(self):
    """Check that the flash-attention-2 and eager backends agree.

    For every class in ``self.all_model_classes`` a model is built from the
    tester config, saved to a temporary directory, and reloaded twice in
    float16: once with ``attn_implementation="flash_attention_2"`` and once
    with ``attn_implementation="eager"``. Both copies are run on the same
    input and their last hidden states must match within
    ``atol=1e-2, rtol=1e-3``.

    The test is skipped as soon as a model class reports that it does not
    support Flash Attention 2.
    """
    for model_class in self.all_model_classes:
        if not model_class._supports_flash_attn_2:
            self.skipTest(reason="Model does not support Flash Attention 2")

        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
        model = model_class(config)

        with tempfile.TemporaryDirectory() as tmpdirname:
            # Round-trip through disk so both backends load identical weights.
            model.save_pretrained(tmpdirname)
            model_fa = model_class.from_pretrained(
                tmpdirname, torch_dtype=torch.float16, attn_implementation="flash_attention_2"
            )
            model_fa.to(torch_device)

            model = model_class.from_pretrained(tmpdirname, torch_dtype=torch.float16, attn_implementation="eager")
            model.to(torch_device)

            dummy_input = inputs_dict[model_class.main_input_name]
            dummy_input = dummy_input.to(torch_device)

            # Only forward activations are compared, so skip autograd
            # bookkeeping for both passes.
            with torch.no_grad():
                outputs = model(dummy_input, output_hidden_states=True)
                outputs_fa = model_fa(dummy_input, output_hidden_states=True)

            # Compare the final entry of hidden_states rather than a
            # task-specific head output.
            last_hidden = outputs.hidden_states[-1]
            last_hidden_fa = outputs_fa.hidden_states[-1]
            torch.testing.assert_close(last_hidden_fa, last_hidden, atol=1e-2, rtol=1e-3)
@slow
@require_torch