[Phi] Extend implementation to use GQA/MQA. (#28163)

* chore(phi): Updates configuration_phi with missing keys.

* chore(phi): Adds first draft of combined modeling_phi.

* fix(phi): Fixes according to latest review.

* fix(phi): Removes pad_vocab_size_multiple to prevent inconsistencies.

* fix(phi): Fixes unit and integration tests.

* fix(phi): Ensures that everything works with microsoft/phi-1 for first integration.

* fix(phi): Fixes output of docstring generation.

* fix(phi): Fixes according to latest review.

* fix(phi): Fixes according to latest review.

* fix(tests): Re-enables Phi-1.5 test.

* fix(phi): Fixes attention overflow on PhiAttention (for Phi-2).

* fix(phi): Improves how queries and keys are upcast.

* fix(phi): Small updates on latest changes.
This commit is contained in:
Gustavo de Rosa
2024-01-11 11:58:02 -03:00
committed by GitHub
parent d560637885
commit 5509058561
3 changed files with 101 additions and 82 deletions

View File

@@ -365,18 +365,18 @@ class PhiModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin,
@require_bitsandbytes
@pytest.mark.flash_attn_test
@slow
# Copied from tests.models.llama.test_modeling_llama.LlamaModelTest.test_flash_attn_2_generate_padding_right with LlamaForCausalLM->PhiForCausalLM,LlamaTokenizer->AutoTokenizer,meta-llama/Llama-2-7b-hf->susnato/phi-1_5_dev
# Copied from tests.models.llama.test_modeling_llama.LlamaModelTest.test_flash_attn_2_generate_padding_right with LlamaForCausalLM->PhiForCausalLM,LlamaTokenizer->AutoTokenizer,meta-llama/Llama-2-7b-hf->microsoft/phi-1
def test_flash_attn_2_generate_padding_right(self):
"""
Overwritting the common test as the test is flaky on tiny models
"""
model = PhiForCausalLM.from_pretrained(
"susnato/phi-1_5_dev",
"microsoft/phi-1",
load_in_4bit=True,
device_map={"": 0},
)
tokenizer = AutoTokenizer.from_pretrained("susnato/phi-1_5_dev")
tokenizer = AutoTokenizer.from_pretrained("microsoft/phi-1")
texts = ["hi", "Hello this is a very long sentence"]
@@ -389,7 +389,7 @@ class PhiModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin,
output_native = tokenizer.batch_decode(output_native)
model = PhiForCausalLM.from_pretrained(
"susnato/phi-1_5_dev", load_in_4bit=True, device_map={"": 0}, attn_implementation="flash_attention_2"
"microsoft/phi-1", load_in_4bit=True, device_map={"": 0}, attn_implementation="flash_attention_2"
)
output_fa_2 = model.generate(**inputs, max_new_tokens=20, do_sample=False)
@@ -408,7 +408,7 @@ class PhiIntegrationTest(unittest.TestCase):
)
}
model = PhiForCausalLM.from_pretrained("susnato/phi-1_dev").to(torch_device)
model = PhiForCausalLM.from_pretrained("microsoft/phi-1").to(torch_device)
model.eval()
output = model(**input_ids).logits
@@ -424,7 +424,7 @@ class PhiIntegrationTest(unittest.TestCase):
)
}
model = PhiForCausalLM.from_pretrained("susnato/phi-1_5_dev").to(torch_device)
model = PhiForCausalLM.from_pretrained("microsoft/phi-1_5").to(torch_device)
model.eval()
output = model(**input_ids).logits
@@ -440,7 +440,7 @@ class PhiIntegrationTest(unittest.TestCase):
)
}
model = PhiForCausalLM.from_pretrained("susnato/phi-2").to(torch_device)
model = PhiForCausalLM.from_pretrained("microsoft/phi-2").to(torch_device)
model.eval()
output = model(**input_ids).logits
@@ -450,8 +450,8 @@ class PhiIntegrationTest(unittest.TestCase):
self.assertTrue(torch.allclose(EXPECTED_OUTPUT, output[0, :2, :30], atol=1e-3, rtol=1e-3))
def test_phi_2_generation(self):
model = PhiForCausalLM.from_pretrained("susnato/phi-2")
tokenizer = AutoTokenizer.from_pretrained("susnato/phi-2")
model = PhiForCausalLM.from_pretrained("microsoft/phi-2")
tokenizer = AutoTokenizer.from_pretrained("microsoft/phi-2")
inputs = tokenizer(
"Can you help me write a formal email to a potential business partner proposing a joint venture?",