From a5bee89c9d5ec2402cb72e819860d51cc2ca35fc Mon Sep 17 00:00:00 2001 From: Yoach Lacombe <52246514+ylacombe@users.noreply.github.com> Date: Wed, 8 Nov 2023 17:06:35 +0000 Subject: [PATCH] Add Flash Attention 2 support to Bark (#27364) * change handmade attention mask to _prepare_4d_attention_mask * add flashattention2 support in Bark * add flashattention2 tests on BarkSemanticModel * make style * fix flashattention and tests + make style * fix memory leak and allow Bark to pass flash attention to sub-models * make style * Apply suggestions from code review Co-authored-by: Sanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com> * remove unecessary code from tests + justify overriding * Update tests/models/bark/test_modeling_bark.py Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> * make style --------- Co-authored-by: Sanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com> Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> --- src/transformers/models/bark/modeling_bark.py | 256 ++++++++++++++++-- tests/models/bark/test_modeling_bark.py | 119 ++++++++ 2 files changed, 355 insertions(+), 20 deletions(-) diff --git a/src/transformers/models/bark/modeling_bark.py b/src/transformers/models/bark/modeling_bark.py index 2c04f15c04..f8b9eab5d3 100644 --- a/src/transformers/models/bark/modeling_bark.py +++ b/src/transformers/models/bark/modeling_bark.py @@ -26,12 +26,14 @@ from ...generation.logits_process import ( BarkEosPrioritizerLogitsProcessor, SuppressTokensLogitsProcessor, ) +from ...modeling_attn_mask_utils import _prepare_4d_attention_mask from ...modeling_outputs import CausalLMOutputWithPast, MaskedLMOutput from ...modeling_utils import PreTrainedModel, get_parameter_device from ...utils import ( add_start_docstrings, add_start_docstrings_to_model_forward, is_accelerate_available, + is_flash_attn_2_available, logging, ) from ..auto import AutoModel @@ -49,6 +51,11 @@ from .generation_configuration_bark import ( ) +if is_flash_attn_2_available(): + from flash_attn import flash_attn_func, flash_attn_varlen_func + from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa + + logger = logging.get_logger(__name__) @@ -62,6 +69,19 @@ BARK_PRETRAINED_MODEL_ARCHIVE_LIST = [ ] +# Copied from transformers.models.llama.modeling_llama._get_unpad_data +def _get_unpad_data(attention_mask): + seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32) + indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten() + max_seqlen_in_batch = seqlens_in_batch.max().item() + cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.torch.int32), (1, 0)) + return ( + indices, + cu_seqlens, + max_seqlen_in_batch, + ) + + class BarkSelfAttention(nn.Module): # adapted from GPTNeoSelfAttention and Bark code # BarkSelfAttention can have two attention type, i.e full attention or causal attention @@ -187,6 +207,177 @@ class BarkSelfAttention(nn.Module): return outputs +class BarkSelfFlashAttention2(BarkSelfAttention): + """ + Bark flash attention module. This module inherits from `BarkSelfAttention` as the weights of the module stays + untouched. The only required change would be on the forward pass where it needs to correctly call the public API of + flash attention and deal with padding tokens in case the input contains any of them. + """ + + def _split_heads(self, tensor, num_heads, attn_head_size): + """ + Splits hidden_size dim into attn_head_size and num_heads + """ + new_shape = tensor.size()[:-1] + (num_heads, attn_head_size) + tensor = tensor.view(new_shape) + # Flash attention requires the input to have the shape + # batch_size x seq_length x head_dim x hidden_dim - (batch, seq_length, head, head_features) + return tensor + + def _merge_heads(self, tensor, num_heads, attn_head_size): + """ + Merges attn_head_size dim and num_attn_heads dim into hidden_size + """ + # re-assemble all head outputs side by side + # (batch, seq_len, num_heads, attn_head_size) -> (batch, seq_len, num_heads*attn_head_size) + tensor = tensor.view(tensor.size()[:-2] + (num_heads * attn_head_size,)) + return tensor + + def forward( + self, + hidden_states, + attention_mask=None, + past_key_values=None, + head_mask=None, + use_cache=False, + output_attentions=False, + ): + batch_size, query_len, _ = hidden_states.size() + + # calculate query, key, values for all heads in batch and move head forward to be the batch dim + query, key, value = self.att_proj(hidden_states).split(self.embed_dim, dim=2) + + query = self._split_heads(query, self.num_heads, self.head_dim) + key = self._split_heads(key, self.num_heads, self.head_dim) + value = self._split_heads(value, self.num_heads, self.head_dim) + + if past_key_values is not None: + # (batch, head, seq_length, head_features) -> (batch, seq_length, head, head_features) + past_key = past_key_values[0].transpose(1, 2) + past_value = past_key_values[1].transpose(1, 2) + # and merge on seq_length + key = torch.cat((past_key, key), dim=1) + value = torch.cat((past_value, value), dim=1) + + if use_cache is True: + # (batch, head, seq_length, head_features) + present = (key.transpose(1, 2), value.transpose(1, 2)) + else: + present = None + + attn_output = self._flash_attention_forward(query, key, value, attention_mask, query_len, dropout=self.dropout) + + attn_output = self._merge_heads(attn_output, self.num_heads, self.head_dim) + attn_output = self.out_proj(attn_output) + attn_output = self.resid_dropout(attn_output) + + outputs = (attn_output, present) + if output_attentions: + attn_weights = None + outputs += (attn_weights,) + + return outputs + + # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2._flash_attention_forward + def _flash_attention_forward( + self, query_states, key_states, value_states, attention_mask, query_length, dropout=0.0, softmax_scale=None + ): + """ + Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token + first unpad the input, then computes the attention scores and pad the final attention scores. + + Args: + query_states (`torch.Tensor`): + Input query states to be passed to Flash Attention API + key_states (`torch.Tensor`): + Input key states to be passed to Flash Attention API + value_states (`torch.Tensor`): + Input value states to be passed to Flash Attention API + attention_mask (`torch.Tensor`): + The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the + position of padding tokens and 1 for the position of non-padding tokens. + dropout (`int`, *optional*): + Attention dropout + softmax_scale (`float`, *optional*): + The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim) + """ + # Contains at least one padding token in the sequence + if attention_mask is not None: + batch_size = query_states.shape[0] + query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = self._upad_input( + query_states, key_states, value_states, attention_mask, query_length + ) + + cu_seqlens_q, cu_seqlens_k = cu_seq_lens + max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens + + attn_output_unpad = flash_attn_varlen_func( + query_states, + key_states, + value_states, + cu_seqlens_q=cu_seqlens_q, + cu_seqlens_k=cu_seqlens_k, + max_seqlen_q=max_seqlen_in_batch_q, + max_seqlen_k=max_seqlen_in_batch_k, + dropout_p=dropout, + softmax_scale=softmax_scale, + causal=self.is_causal, + ) + + attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length) + else: + attn_output = flash_attn_func( + query_states, key_states, value_states, dropout, softmax_scale=softmax_scale, causal=self.is_causal + ) + + return attn_output + + # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2._upad_input + def _upad_input(self, query_layer, key_layer, value_layer, attention_mask, query_length): + indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask) + batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape + + key_layer = index_first_axis( + key_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k + ) + value_layer = index_first_axis( + value_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k + ) + if query_length == kv_seq_len: + query_layer = index_first_axis( + query_layer.reshape(batch_size * kv_seq_len, self.num_heads, head_dim), indices_k + ) + cu_seqlens_q = cu_seqlens_k + max_seqlen_in_batch_q = max_seqlen_in_batch_k + indices_q = indices_k + elif query_length == 1: + max_seqlen_in_batch_q = 1 + cu_seqlens_q = torch.arange( + batch_size + 1, dtype=torch.int32, device=query_layer.device + ) # There is a memcpy here, that is very bad. + indices_q = cu_seqlens_q[:-1] + query_layer = query_layer.squeeze(1) + else: + # The -q_len: slice assumes left padding. + attention_mask = attention_mask[:, -query_length:] + query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(query_layer, attention_mask) + + return ( + query_layer, + key_layer, + value_layer, + indices_q, + (cu_seqlens_q, cu_seqlens_k), + (max_seqlen_in_batch_q, max_seqlen_in_batch_k), + ) + + +BARK_ATTENTION_CLASSES = { + "default": BarkSelfAttention, + "flash_attention_2": BarkSelfFlashAttention2, +} + + class BarkLayerNorm(nn.Module): """LayerNorm but with an optional bias. PyTorch doesn't support simply bias=False.""" @@ -229,7 +420,8 @@ class BarkBlock(nn.Module): self.layernorm_1 = nn.LayerNorm(config.hidden_size) self.layernorm_2 = nn.LayerNorm(config.hidden_size) - self.attn = BarkSelfAttention(config, is_causal=is_causal) + attn_type = "flash_attention_2" if getattr(config, "_flash_attn_2_enabled", False) else "default" + self.attn = BARK_ATTENTION_CLASSES[attn_type](config, is_causal=is_causal) self.mlp = BarkMLP(config) @@ -277,6 +469,7 @@ class BarkPreTrainedModel(PreTrainedModel): config_class = BarkConfig supports_gradient_checkpointing = False + _supports_flash_attn_2 = True def _init_weights(self, module): """Initialize the weights.""" @@ -596,21 +789,13 @@ class BarkCausalModel(BarkPreTrainedModel): if attention_mask is not None: if batch_size <= 0: raise ValueError("batch_size has to be defined and > 0") - attention_mask = attention_mask.view(batch_size, -1) - # We create a 3D attention mask from a 2D tensor mask. - # Sizes are [batch_size, 1, 1, to_seq_length] - # So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length] - # this attention mask is more simple than the triangular masking of causal attention - # used in OpenAI GPT, we just need to prepare the broadcast dimension here. - attention_mask = attention_mask[:, None, None, :] - - # Since attention_mask is 1.0 for positions we want to attend and 0.0 for - # masked positions, this operation will create a tensor which is 0.0 for - # positions we want to attend and the dtype's smallest value for masked positions. - # Since we are adding it to the raw scores before the softmax, this is - # effectively the same as removing these entirely. - attention_mask = attention_mask.to(dtype=self.dtype) # fp16 compatibility - attention_mask = (1.0 - attention_mask) * torch.finfo(self.dtype).min + if getattr(self.config, "_flash_attn_2_enabled", False): + attention_mask = attention_mask if 0 in attention_mask else None + else: + attention_mask = attention_mask.view(batch_size, -1) + # [bsz, to_seq_length] -> [bsz, 1, 1, to_seq_length] + # from_seq_length is 1 to easily broadcast + attention_mask = _prepare_4d_attention_mask(attention_mask, input_embeds.dtype, tgt_len=1) # Prepare head mask if needed # 1.0 in head_mask indicate we keep the head @@ -1233,10 +1418,12 @@ class BarkFineModel(BarkPreTrainedModel): if attention_mask is not None: if batch_size <= 0: raise ValueError("batch_size has to be defined and > 0") - attention_mask = attention_mask.view(batch_size, -1) - attention_mask = attention_mask[:, None, None, :] - attention_mask = attention_mask.to(dtype=self.dtype) # fp16 compatibility - attention_mask = (1.0 - attention_mask) * torch.finfo(self.dtype).min + if getattr(self.config, "_flash_attn_2_enabled", False): + attention_mask = attention_mask if 0 in attention_mask else None + else: + # [bsz, to_seq_length] -> [bsz, 1, 1, to_seq_length] + # from_seq_length is 1 to easily broadcast + attention_mask = _prepare_4d_attention_mask(attention_mask, input_embeds.dtype, tgt_len=1) head_mask = self.get_head_mask(head_mask, self.config.num_layers) @@ -1669,3 +1856,32 @@ class BarkModel(BarkPreTrainedModel): return audio, output_lengths return audio + + @classmethod + def _check_and_enable_flash_attn_2( + cls, config, torch_dtype: Optional[torch.dtype] = None, device_map: Optional[Union[str, Dict[str, int]]] = None + ): + """ + `_check_and_enable_flash_attn_2` originally don't expand flash attention enabling to the model + sub-configurations. We override the original method to make sure that Bark sub-models are using Flash Attention + if necessary. + + If you don't know about Flash Attention, check out the official repository of flash attention: + https://github.com/Dao-AILab/flash-attention + + For using Flash Attention 1.0 you can do it directly via the `BetterTransformer` API, have a look at this + specific section of the documentation to learn more about it: + https://huggingface.co/docs/transformers/main/en/perf_infer_gpu_one#decoder-models + + The method checks if the current setup is compatible with Flash Attention as it requires the model to be in + half precision and not ran on CPU. + + If all checks pass, the method will create an attribute in the config `_flash_attn_2_enabled` so that the model + can initialize the correct attention module + """ + config = super()._check_and_enable_flash_attn_2(config, torch_dtype, device_map) + + config.semantic_config._flash_attn_2_enabled = getattr(config, "_flash_attn_2_enabled", False) + config.coarse_acoustics_config._flash_attn_2_enabled = getattr(config, "_flash_attn_2_enabled", False) + config.fine_acoustics_config._flash_attn_2_enabled = getattr(config, "_flash_attn_2_enabled", False) + return config diff --git a/tests/models/bark/test_modeling_bark.py b/tests/models/bark/test_modeling_bark.py index bf13203ecd..a8545ad3c0 100644 --- a/tests/models/bark/test_modeling_bark.py +++ b/tests/models/bark/test_modeling_bark.py @@ -20,6 +20,8 @@ import inspect import tempfile import unittest +from pytest import mark + from transformers import ( BarkCoarseConfig, BarkConfig, @@ -33,6 +35,7 @@ from transformers.models.bark.generation_configuration_bark import ( BarkSemanticGenerationConfig, ) from transformers.testing_utils import ( + require_flash_attn, require_torch, require_torch_fp16, require_torch_gpu, @@ -872,6 +875,122 @@ class BarkFineModelTest(ModelTesterMixin, unittest.TestCase): # Check that the model can still do a forward pass successfully (every parameter should be resized) model(**self._prepare_for_class(inputs_dict, model_class)) + @require_flash_attn + @require_torch_gpu + @mark.flash_attn_test + @slow + def test_flash_attn_2_inference(self): + for model_class in self.all_model_classes: + if not model_class._supports_flash_attn_2: + return + + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + model = model_class(config) + + with tempfile.TemporaryDirectory() as tmpdirname: + model.save_pretrained(tmpdirname) + model_fa = model_class.from_pretrained( + tmpdirname, torch_dtype=torch.bfloat16, use_flash_attention_2=True + ) + model_fa.to(torch_device) + + model = model_class.from_pretrained( + tmpdirname, torch_dtype=torch.bfloat16, use_flash_attention_2=False + ) + model.to(torch_device) + + dummy_input = inputs_dict["input_ids"][:1] + if dummy_input.dtype in [torch.float32, torch.float16]: + dummy_input = dummy_input.to(torch.bfloat16) + + dummy_attention_mask = inputs_dict.get("attention_mask", None) + + if dummy_attention_mask is not None: + dummy_attention_mask = dummy_attention_mask[:1] + dummy_attention_mask[:, 1:] = 1 + dummy_attention_mask[:, :1] = 0 + + outputs = model(inputs_dict["codebook_idx"], dummy_input, output_hidden_states=True) + outputs_fa = model_fa(inputs_dict["codebook_idx"], dummy_input, output_hidden_states=True) + + logits = outputs.hidden_states[-1] + logits_fa = outputs_fa.hidden_states[-1] + + assert torch.allclose(logits_fa, logits, atol=4e-2, rtol=4e-2) + + other_inputs = {"output_hidden_states": True} + if dummy_attention_mask is not None: + other_inputs["attention_mask"] = dummy_attention_mask + + outputs = model(inputs_dict["codebook_idx"], dummy_input, **other_inputs) + outputs_fa = model_fa(inputs_dict["codebook_idx"], dummy_input, **other_inputs) + + logits = outputs.hidden_states[-1] + logits_fa = outputs_fa.hidden_states[-1] + + assert torch.allclose(logits_fa[1:], logits[1:], atol=4e-2, rtol=4e-2) + + # check with inference + dropout + model.train() + _ = model_fa(inputs_dict["codebook_idx"], dummy_input, **other_inputs) + + @require_flash_attn + @require_torch_gpu + @mark.flash_attn_test + @slow + def test_flash_attn_2_inference_padding_right(self): + for model_class in self.all_model_classes: + if not model_class._supports_flash_attn_2: + return + + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + model = model_class(config) + + with tempfile.TemporaryDirectory() as tmpdirname: + model.save_pretrained(tmpdirname) + model_fa = model_class.from_pretrained( + tmpdirname, torch_dtype=torch.bfloat16, use_flash_attention_2=True + ) + model_fa.to(torch_device) + + model = model_class.from_pretrained( + tmpdirname, torch_dtype=torch.bfloat16, use_flash_attention_2=False + ) + model.to(torch_device) + + dummy_input = inputs_dict["input_ids"][:1] + if dummy_input.dtype in [torch.float32, torch.float16]: + dummy_input = dummy_input.to(torch.bfloat16) + + dummy_attention_mask = inputs_dict.get("attention_mask", None) + + if dummy_attention_mask is not None: + dummy_attention_mask = dummy_attention_mask[:1] + dummy_attention_mask[:, :-1] = 1 + dummy_attention_mask[:, -1:] = 0 + + outputs = model(inputs_dict["codebook_idx"], dummy_input, output_hidden_states=True) + outputs_fa = model_fa(inputs_dict["codebook_idx"], dummy_input, output_hidden_states=True) + + logits = outputs.hidden_states[-1] + logits_fa = outputs_fa.hidden_states[-1] + + assert torch.allclose(logits_fa, logits, atol=4e-2, rtol=4e-2) + + other_inputs = { + "output_hidden_states": True, + } + if dummy_attention_mask is not None: + other_inputs["attention_mask"] = dummy_attention_mask + + outputs = model(inputs_dict["codebook_idx"], dummy_input, **other_inputs) + outputs_fa = model_fa(inputs_dict["codebook_idx"], dummy_input, **other_inputs) + + logits = outputs.hidden_states[-1] + logits_fa = outputs_fa.hidden_states[-1] + + assert torch.allclose(logits_fa[:-1], logits[:-1], atol=4e-2, rtol=4e-2) + @require_torch class BarkModelIntegrationTests(unittest.TestCase):