Add Flash Attention 2 support to Bark (#27364)
* change handmade attention mask to _prepare_4d_attention_mask * add flashattention2 support in Bark * add flashattention2 tests on BarkSemanticModel * make style * fix flashattention and tests + make style * fix memory leak and allow Bark to pass flash attention to sub-models * make style * Apply suggestions from code review Co-authored-by: Sanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com> * remove unecessary code from tests + justify overriding * Update tests/models/bark/test_modeling_bark.py Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> * make style --------- Co-authored-by: Sanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com> Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
This commit is contained in:
@@ -20,6 +20,8 @@ import inspect
|
||||
import tempfile
|
||||
import unittest
|
||||
|
||||
from pytest import mark
|
||||
|
||||
from transformers import (
|
||||
BarkCoarseConfig,
|
||||
BarkConfig,
|
||||
@@ -33,6 +35,7 @@ from transformers.models.bark.generation_configuration_bark import (
|
||||
BarkSemanticGenerationConfig,
|
||||
)
|
||||
from transformers.testing_utils import (
|
||||
require_flash_attn,
|
||||
require_torch,
|
||||
require_torch_fp16,
|
||||
require_torch_gpu,
|
||||
@@ -872,6 +875,122 @@ class BarkFineModelTest(ModelTesterMixin, unittest.TestCase):
|
||||
# Check that the model can still do a forward pass successfully (every parameter should be resized)
|
||||
model(**self._prepare_for_class(inputs_dict, model_class))
|
||||
|
||||
@require_flash_attn
|
||||
@require_torch_gpu
|
||||
@mark.flash_attn_test
|
||||
@slow
|
||||
def test_flash_attn_2_inference(self):
|
||||
for model_class in self.all_model_classes:
|
||||
if not model_class._supports_flash_attn_2:
|
||||
return
|
||||
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
model = model_class(config)
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
model.save_pretrained(tmpdirname)
|
||||
model_fa = model_class.from_pretrained(
|
||||
tmpdirname, torch_dtype=torch.bfloat16, use_flash_attention_2=True
|
||||
)
|
||||
model_fa.to(torch_device)
|
||||
|
||||
model = model_class.from_pretrained(
|
||||
tmpdirname, torch_dtype=torch.bfloat16, use_flash_attention_2=False
|
||||
)
|
||||
model.to(torch_device)
|
||||
|
||||
dummy_input = inputs_dict["input_ids"][:1]
|
||||
if dummy_input.dtype in [torch.float32, torch.float16]:
|
||||
dummy_input = dummy_input.to(torch.bfloat16)
|
||||
|
||||
dummy_attention_mask = inputs_dict.get("attention_mask", None)
|
||||
|
||||
if dummy_attention_mask is not None:
|
||||
dummy_attention_mask = dummy_attention_mask[:1]
|
||||
dummy_attention_mask[:, 1:] = 1
|
||||
dummy_attention_mask[:, :1] = 0
|
||||
|
||||
outputs = model(inputs_dict["codebook_idx"], dummy_input, output_hidden_states=True)
|
||||
outputs_fa = model_fa(inputs_dict["codebook_idx"], dummy_input, output_hidden_states=True)
|
||||
|
||||
logits = outputs.hidden_states[-1]
|
||||
logits_fa = outputs_fa.hidden_states[-1]
|
||||
|
||||
assert torch.allclose(logits_fa, logits, atol=4e-2, rtol=4e-2)
|
||||
|
||||
other_inputs = {"output_hidden_states": True}
|
||||
if dummy_attention_mask is not None:
|
||||
other_inputs["attention_mask"] = dummy_attention_mask
|
||||
|
||||
outputs = model(inputs_dict["codebook_idx"], dummy_input, **other_inputs)
|
||||
outputs_fa = model_fa(inputs_dict["codebook_idx"], dummy_input, **other_inputs)
|
||||
|
||||
logits = outputs.hidden_states[-1]
|
||||
logits_fa = outputs_fa.hidden_states[-1]
|
||||
|
||||
assert torch.allclose(logits_fa[1:], logits[1:], atol=4e-2, rtol=4e-2)
|
||||
|
||||
# check with inference + dropout
|
||||
model.train()
|
||||
_ = model_fa(inputs_dict["codebook_idx"], dummy_input, **other_inputs)
|
||||
|
||||
@require_flash_attn
|
||||
@require_torch_gpu
|
||||
@mark.flash_attn_test
|
||||
@slow
|
||||
def test_flash_attn_2_inference_padding_right(self):
|
||||
for model_class in self.all_model_classes:
|
||||
if not model_class._supports_flash_attn_2:
|
||||
return
|
||||
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
model = model_class(config)
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
model.save_pretrained(tmpdirname)
|
||||
model_fa = model_class.from_pretrained(
|
||||
tmpdirname, torch_dtype=torch.bfloat16, use_flash_attention_2=True
|
||||
)
|
||||
model_fa.to(torch_device)
|
||||
|
||||
model = model_class.from_pretrained(
|
||||
tmpdirname, torch_dtype=torch.bfloat16, use_flash_attention_2=False
|
||||
)
|
||||
model.to(torch_device)
|
||||
|
||||
dummy_input = inputs_dict["input_ids"][:1]
|
||||
if dummy_input.dtype in [torch.float32, torch.float16]:
|
||||
dummy_input = dummy_input.to(torch.bfloat16)
|
||||
|
||||
dummy_attention_mask = inputs_dict.get("attention_mask", None)
|
||||
|
||||
if dummy_attention_mask is not None:
|
||||
dummy_attention_mask = dummy_attention_mask[:1]
|
||||
dummy_attention_mask[:, :-1] = 1
|
||||
dummy_attention_mask[:, -1:] = 0
|
||||
|
||||
outputs = model(inputs_dict["codebook_idx"], dummy_input, output_hidden_states=True)
|
||||
outputs_fa = model_fa(inputs_dict["codebook_idx"], dummy_input, output_hidden_states=True)
|
||||
|
||||
logits = outputs.hidden_states[-1]
|
||||
logits_fa = outputs_fa.hidden_states[-1]
|
||||
|
||||
assert torch.allclose(logits_fa, logits, atol=4e-2, rtol=4e-2)
|
||||
|
||||
other_inputs = {
|
||||
"output_hidden_states": True,
|
||||
}
|
||||
if dummy_attention_mask is not None:
|
||||
other_inputs["attention_mask"] = dummy_attention_mask
|
||||
|
||||
outputs = model(inputs_dict["codebook_idx"], dummy_input, **other_inputs)
|
||||
outputs_fa = model_fa(inputs_dict["codebook_idx"], dummy_input, **other_inputs)
|
||||
|
||||
logits = outputs.hidden_states[-1]
|
||||
logits_fa = outputs_fa.hidden_states[-1]
|
||||
|
||||
assert torch.allclose(logits_fa[:-1], logits[:-1], atol=4e-2, rtol=4e-2)
|
||||
|
||||
|
||||
@require_torch
|
||||
class BarkModelIntegrationTests(unittest.TestCase):
|
||||
|
||||
Reference in New Issue
Block a user