[FA2] Add flash attention for DistilBert (#26489)

* flash attention added for DistilBert

* fixes

* removed padding_masks

* Update modeling_distilbert.py

* Update test_modeling_distilbert.py

* style fix
This commit is contained in:
Susnato Dhar
2023-11-03 21:37:54 +05:30
committed by GitHub
parent 5964f820db
commit 1ac2463dfe
3 changed files with 348 additions and 5 deletions

View File

@@ -16,8 +16,10 @@ import os
import tempfile
import unittest
from pytest import mark
from transformers import DistilBertConfig, is_torch_available
from transformers.testing_utils import require_torch, require_torch_accelerator, slow, torch_device
from transformers.testing_utils import require_flash_attn, require_torch, require_torch_accelerator, slow, torch_device
from ...test_configuration_common import ConfigTester
from ...test_modeling_common import ModelTesterMixin, ids_tensor, random_attention_mask
@@ -285,6 +287,114 @@ class DistilBertModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCa
loaded = torch.jit.load(os.path.join(tmp, "traced_model.pt"), map_location=torch_device)
loaded(inputs_dict["input_ids"].to(torch_device), inputs_dict["attention_mask"].to(torch_device))
# Because DistilBertForMultipleChoice requires inputs with different shapes we need to override this test.
@require_flash_attn
@require_torch_accelerator
@mark.flash_attn_test
@slow
def test_flash_attn_2_inference(self):
    """Check that Flash Attention 2 matches the default attention (left padding).

    Each class in ``self.all_model_classes`` is saved once and reloaded twice in
    bfloat16, with ``use_flash_attention_2`` switched on and off. The last hidden
    states of the two copies are compared first without an attention mask, then
    with a mask whose first position is masked out in every row.
    """
    import torch

    for model_class in self.all_model_classes:
        dummy_input = torch.LongTensor(
            [
                [1, 2, 3, 4],
                [1, 2, 8, 9],
                [1, 2, 11, 12],
                [1, 2, 13, 14],
            ]
        ).to(torch_device)
        # Left padding: position 0 of every row is masked out.
        dummy_attention_mask = torch.LongTensor(
            [
                [0, 1, 1, 1],
                [0, 1, 1, 1],
                [0, 1, 1, 1],
                [0, 1, 1, 1],
            ]
        ).to(torch_device)

        # Only the config is needed; the prepared inputs are replaced by the dummy tensors above.
        config, _ = self.model_tester.prepare_config_and_inputs_for_common()
        model = model_class(config)

        with tempfile.TemporaryDirectory() as tmpdirname:
            # Round-trip through disk so both variants start from identical weights.
            model.save_pretrained(tmpdirname)
            model_fa = model_class.from_pretrained(
                tmpdirname, torch_dtype=torch.bfloat16, use_flash_attention_2=True
            )
            model_fa.to(torch_device)

            model = model_class.from_pretrained(
                tmpdirname, torch_dtype=torch.bfloat16, use_flash_attention_2=False
            )
            model.to(torch_device)

            # Without a mask every position must agree.
            logits = model(dummy_input, output_hidden_states=True).hidden_states[-1]
            logits_fa = model_fa(dummy_input, output_hidden_states=True).hidden_states[-1]

            self.assertTrue(torch.allclose(logits_fa, logits, atol=4e-2, rtol=4e-2))

            output_fa = model_fa(dummy_input, attention_mask=dummy_attention_mask, output_hidden_states=True)
            logits_fa = output_fa.hidden_states[-1]

            output = model(dummy_input, attention_mask=dummy_attention_mask, output_hidden_states=True)
            logits = output.hidden_states[-1]

            # Compare only the unmasked positions: skip the padded position 0 along
            # the sequence axis instead of dropping a whole batch row.
            # NOTE(review): assumes hidden states are laid out (batch, seq, hidden) — confirm.
            self.assertTrue(torch.allclose(logits_fa[:, 1:], logits[:, 1:], atol=4e-2, rtol=4e-2))
# Because DistilBertForMultipleChoice requires inputs with different shapes we need to override this test.
@require_flash_attn
@require_torch_accelerator
@mark.flash_attn_test
@slow
def test_flash_attn_2_inference_padding_right(self):
    """Check that Flash Attention 2 matches the default attention (right padding).

    Each class in ``self.all_model_classes`` is saved once and reloaded twice in
    bfloat16, with ``use_flash_attention_2`` switched on and off. The last hidden
    states of the two copies are compared first without an attention mask, then
    with a mask whose last position is masked out in every row.
    """
    import torch

    for model_class in self.all_model_classes:
        dummy_input = torch.LongTensor(
            [
                [1, 2, 3, 4],
                [1, 2, 8, 9],
                [1, 2, 11, 12],
                [1, 2, 13, 14],
            ]
        ).to(torch_device)
        # Right padding: the last position of every row is masked out.
        # (Previously this was the left-padding mask, which made this test a
        # duplicate of test_flash_attn_2_inference.)
        dummy_attention_mask = torch.LongTensor(
            [
                [1, 1, 1, 0],
                [1, 1, 1, 0],
                [1, 1, 1, 0],
                [1, 1, 1, 0],
            ]
        ).to(torch_device)

        # Only the config is needed; the prepared inputs are replaced by the dummy tensors above.
        config, _ = self.model_tester.prepare_config_and_inputs_for_common()
        model = model_class(config)

        with tempfile.TemporaryDirectory() as tmpdirname:
            # Round-trip through disk so both variants start from identical weights.
            model.save_pretrained(tmpdirname)
            model_fa = model_class.from_pretrained(
                tmpdirname, torch_dtype=torch.bfloat16, use_flash_attention_2=True
            )
            model_fa.to(torch_device)

            model = model_class.from_pretrained(
                tmpdirname, torch_dtype=torch.bfloat16, use_flash_attention_2=False
            )
            model.to(torch_device)

            # Without a mask every position must agree.
            logits = model(dummy_input, output_hidden_states=True).hidden_states[-1]
            logits_fa = model_fa(dummy_input, output_hidden_states=True).hidden_states[-1]

            self.assertTrue(torch.allclose(logits_fa, logits, atol=4e-2, rtol=4e-2))

            output_fa = model_fa(dummy_input, attention_mask=dummy_attention_mask, output_hidden_states=True)
            logits_fa = output_fa.hidden_states[-1]

            output = model(dummy_input, attention_mask=dummy_attention_mask, output_hidden_states=True)
            logits = output.hidden_states[-1]

            # Compare only the unmasked positions: skip the padded last position along
            # the sequence axis instead of dropping a whole batch row.
            # NOTE(review): assumes hidden states are laid out (batch, seq, hidden) — confirm.
            self.assertTrue(torch.allclose(logits_fa[:, :-1], logits[:, :-1], atol=4e-2, rtol=4e-2))
@require_torch
class DistilBertModelIntergrationTest(unittest.TestCase):