Add Flash Attention 2 support to Bark (#27364)

* change handmade attention mask to _prepare_4d_attention_mask

* add flashattention2 support in Bark

* add flashattention2 tests on BarkSemanticModel

* make style

* fix flashattention and tests + make style

* fix memory leak and allow Bark to pass flash attention to sub-models

* make style

* Apply suggestions from code review

Co-authored-by: Sanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com>

* remove unecessary code from tests + justify overriding

* Update tests/models/bark/test_modeling_bark.py

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* make style

---------

Co-authored-by: Sanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com>
Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
This commit is contained in:
Yoach Lacombe
2023-11-08 17:06:35 +00:00
committed by GitHub
parent ef71673616
commit a5bee89c9d
2 changed files with 355 additions and 20 deletions

View File

@@ -20,6 +20,8 @@ import inspect
import tempfile
import unittest
from pytest import mark
from transformers import (
BarkCoarseConfig,
BarkConfig,
@@ -33,6 +35,7 @@ from transformers.models.bark.generation_configuration_bark import (
BarkSemanticGenerationConfig,
)
from transformers.testing_utils import (
require_flash_attn,
require_torch,
require_torch_fp16,
require_torch_gpu,
@@ -872,6 +875,122 @@ class BarkFineModelTest(ModelTesterMixin, unittest.TestCase):
# Check that the model can still do a forward pass successfully (every parameter should be resized)
model(**self._prepare_for_class(inputs_dict, model_class))
@require_flash_attn
@require_torch_gpu
@mark.flash_attn_test
@slow
def test_flash_attn_2_inference(self):
for model_class in self.all_model_classes:
if not model_class._supports_flash_attn_2:
return
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
model = model_class(config)
with tempfile.TemporaryDirectory() as tmpdirname:
model.save_pretrained(tmpdirname)
model_fa = model_class.from_pretrained(
tmpdirname, torch_dtype=torch.bfloat16, use_flash_attention_2=True
)
model_fa.to(torch_device)
model = model_class.from_pretrained(
tmpdirname, torch_dtype=torch.bfloat16, use_flash_attention_2=False
)
model.to(torch_device)
dummy_input = inputs_dict["input_ids"][:1]
if dummy_input.dtype in [torch.float32, torch.float16]:
dummy_input = dummy_input.to(torch.bfloat16)
dummy_attention_mask = inputs_dict.get("attention_mask", None)
if dummy_attention_mask is not None:
dummy_attention_mask = dummy_attention_mask[:1]
dummy_attention_mask[:, 1:] = 1
dummy_attention_mask[:, :1] = 0
outputs = model(inputs_dict["codebook_idx"], dummy_input, output_hidden_states=True)
outputs_fa = model_fa(inputs_dict["codebook_idx"], dummy_input, output_hidden_states=True)
logits = outputs.hidden_states[-1]
logits_fa = outputs_fa.hidden_states[-1]
assert torch.allclose(logits_fa, logits, atol=4e-2, rtol=4e-2)
other_inputs = {"output_hidden_states": True}
if dummy_attention_mask is not None:
other_inputs["attention_mask"] = dummy_attention_mask
outputs = model(inputs_dict["codebook_idx"], dummy_input, **other_inputs)
outputs_fa = model_fa(inputs_dict["codebook_idx"], dummy_input, **other_inputs)
logits = outputs.hidden_states[-1]
logits_fa = outputs_fa.hidden_states[-1]
assert torch.allclose(logits_fa[1:], logits[1:], atol=4e-2, rtol=4e-2)
# check with inference + dropout
model.train()
_ = model_fa(inputs_dict["codebook_idx"], dummy_input, **other_inputs)
@require_flash_attn
@require_torch_gpu
@mark.flash_attn_test
@slow
def test_flash_attn_2_inference_padding_right(self):
for model_class in self.all_model_classes:
if not model_class._supports_flash_attn_2:
return
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
model = model_class(config)
with tempfile.TemporaryDirectory() as tmpdirname:
model.save_pretrained(tmpdirname)
model_fa = model_class.from_pretrained(
tmpdirname, torch_dtype=torch.bfloat16, use_flash_attention_2=True
)
model_fa.to(torch_device)
model = model_class.from_pretrained(
tmpdirname, torch_dtype=torch.bfloat16, use_flash_attention_2=False
)
model.to(torch_device)
dummy_input = inputs_dict["input_ids"][:1]
if dummy_input.dtype in [torch.float32, torch.float16]:
dummy_input = dummy_input.to(torch.bfloat16)
dummy_attention_mask = inputs_dict.get("attention_mask", None)
if dummy_attention_mask is not None:
dummy_attention_mask = dummy_attention_mask[:1]
dummy_attention_mask[:, :-1] = 1
dummy_attention_mask[:, -1:] = 0
outputs = model(inputs_dict["codebook_idx"], dummy_input, output_hidden_states=True)
outputs_fa = model_fa(inputs_dict["codebook_idx"], dummy_input, output_hidden_states=True)
logits = outputs.hidden_states[-1]
logits_fa = outputs_fa.hidden_states[-1]
assert torch.allclose(logits_fa, logits, atol=4e-2, rtol=4e-2)
other_inputs = {
"output_hidden_states": True,
}
if dummy_attention_mask is not None:
other_inputs["attention_mask"] = dummy_attention_mask
outputs = model(inputs_dict["codebook_idx"], dummy_input, **other_inputs)
outputs_fa = model_fa(inputs_dict["codebook_idx"], dummy_input, **other_inputs)
logits = outputs.hidden_states[-1]
logits_fa = outputs_fa.hidden_states[-1]
assert torch.allclose(logits_fa[:-1], logits[:-1], atol=4e-2, rtol=4e-2)
@require_torch
class BarkModelIntegrationTests(unittest.TestCase):