[FA-2] Add Flash Attention to Phi (#27661)

* add FA and modify doc file

* test_flash_attn_2_generate_padding_right test overwritten

* comment

* modify persimmon modeling file

* added speedup graph

* more changes
This commit is contained in:
Susnato Dhar
2023-12-07 12:27:48 +05:30
committed by GitHub
parent 06f561687c
commit f84d85ba67
4 changed files with 345 additions and 10 deletions

View File

@@ -18,8 +18,17 @@
import unittest
import pytest
from transformers import PhiConfig, is_torch_available
from transformers.testing_utils import require_torch, slow, torch_device
from transformers.testing_utils import (
require_bitsandbytes,
require_flash_attn,
require_torch,
require_torch_gpu,
slow,
torch_device,
)
from ...generation.test_utils import GenerationTesterMixin
from ...test_configuration_common import ConfigTester
@@ -31,6 +40,7 @@ if is_torch_available():
import torch
from transformers import (
AutoTokenizer,
PhiForCausalLM,
PhiForSequenceClassification,
PhiForTokenClassification,
@@ -350,6 +360,43 @@ class PhiModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin,
result = model(input_ids, attention_mask=attention_mask, labels=sequence_labels)
self.assertEqual(result.logits.shape, (self.model_tester.batch_size, self.model_tester.num_labels))
@require_flash_attn
@require_torch_gpu
@require_bitsandbytes
@pytest.mark.flash_attn_test
@slow
# Copied from tests.models.llama.test_modeling_llama.LlamaModelTest.test_flash_attn_2_generate_padding_right with LlamaForCausalLM->PhiForCausalLM,LlamaTokenizer->AutoTokenizer,meta-llama/Llama-2-7b-hf->susnato/phi-1_5_dev
def test_flash_attn_2_generate_padding_right(self):
"""
Overwritting the common test as the test is flaky on tiny models
"""
model = PhiForCausalLM.from_pretrained(
"susnato/phi-1_5_dev",
load_in_4bit=True,
device_map={"": 0},
)
tokenizer = AutoTokenizer.from_pretrained("susnato/phi-1_5_dev")
texts = ["hi", "Hello this is a very long sentence"]
tokenizer.padding_side = "right"
tokenizer.pad_token = tokenizer.eos_token
inputs = tokenizer(texts, return_tensors="pt", padding=True).to(0)
output_native = model.generate(**inputs, max_new_tokens=20, do_sample=False)
output_native = tokenizer.batch_decode(output_native)
model = PhiForCausalLM.from_pretrained(
"susnato/phi-1_5_dev", load_in_4bit=True, device_map={"": 0}, use_flash_attention_2=True
)
output_fa_2 = model.generate(**inputs, max_new_tokens=20, do_sample=False)
output_fa_2 = tokenizer.batch_decode(output_fa_2)
self.assertListEqual(output_native, output_fa_2)
@slow
@require_torch