[Phi] Extend implementation to use GQA/MQA. (#28163)
* chore(phi): Updates configuration_phi with missing keys.
* chore(phi): Adds first draft of combined modeling_phi.
* fix(phi): Fixes according to latest review.
* fix(phi): Removes pad_vocab_size_multiple to prevent inconsistencies.
* fix(phi): Fixes unit and integration tests.
* fix(phi): Ensures that everything works with microsoft/phi-1 for first integration.
* fix(phi): Fixes output of docstring generation.
* fix(phi): Fixes according to latest review.
* fix(phi): Fixes according to latest review.
* fix(tests): Re-enables Phi-1.5 test.
* fix(phi): Fixes attention overflow on PhiAttention (for Phi-2).
* fix(phi): Improves how queries and keys are upcast.
* fix(phi): Small updates on latest changes.
This commit is contained in:
@@ -365,18 +365,18 @@ class PhiModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin,
|
||||
@require_bitsandbytes
|
||||
@pytest.mark.flash_attn_test
|
||||
@slow
|
||||
# Copied from tests.models.llama.test_modeling_llama.LlamaModelTest.test_flash_attn_2_generate_padding_right with LlamaForCausalLM->PhiForCausalLM,LlamaTokenizer->AutoTokenizer,meta-llama/Llama-2-7b-hf->susnato/phi-1_5_dev
|
||||
# Copied from tests.models.llama.test_modeling_llama.LlamaModelTest.test_flash_attn_2_generate_padding_right with LlamaForCausalLM->PhiForCausalLM,LlamaTokenizer->AutoTokenizer,meta-llama/Llama-2-7b-hf->microsoft/phi-1
|
||||
def test_flash_attn_2_generate_padding_right(self):
|
||||
"""
|
||||
Overwritting the common test as the test is flaky on tiny models
|
||||
"""
|
||||
model = PhiForCausalLM.from_pretrained(
|
||||
"susnato/phi-1_5_dev",
|
||||
"microsoft/phi-1",
|
||||
load_in_4bit=True,
|
||||
device_map={"": 0},
|
||||
)
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained("susnato/phi-1_5_dev")
|
||||
tokenizer = AutoTokenizer.from_pretrained("microsoft/phi-1")
|
||||
|
||||
texts = ["hi", "Hello this is a very long sentence"]
|
||||
|
||||
@@ -389,7 +389,7 @@ class PhiModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin,
|
||||
output_native = tokenizer.batch_decode(output_native)
|
||||
|
||||
model = PhiForCausalLM.from_pretrained(
|
||||
"susnato/phi-1_5_dev", load_in_4bit=True, device_map={"": 0}, attn_implementation="flash_attention_2"
|
||||
"microsoft/phi-1", load_in_4bit=True, device_map={"": 0}, attn_implementation="flash_attention_2"
|
||||
)
|
||||
|
||||
output_fa_2 = model.generate(**inputs, max_new_tokens=20, do_sample=False)
|
||||
@@ -408,7 +408,7 @@ class PhiIntegrationTest(unittest.TestCase):
|
||||
)
|
||||
}
|
||||
|
||||
model = PhiForCausalLM.from_pretrained("susnato/phi-1_dev").to(torch_device)
|
||||
model = PhiForCausalLM.from_pretrained("microsoft/phi-1").to(torch_device)
|
||||
model.eval()
|
||||
|
||||
output = model(**input_ids).logits
|
||||
@@ -424,7 +424,7 @@ class PhiIntegrationTest(unittest.TestCase):
|
||||
)
|
||||
}
|
||||
|
||||
model = PhiForCausalLM.from_pretrained("susnato/phi-1_5_dev").to(torch_device)
|
||||
model = PhiForCausalLM.from_pretrained("microsoft/phi-1_5").to(torch_device)
|
||||
model.eval()
|
||||
|
||||
output = model(**input_ids).logits
|
||||
@@ -440,7 +440,7 @@ class PhiIntegrationTest(unittest.TestCase):
|
||||
)
|
||||
}
|
||||
|
||||
model = PhiForCausalLM.from_pretrained("susnato/phi-2").to(torch_device)
|
||||
model = PhiForCausalLM.from_pretrained("microsoft/phi-2").to(torch_device)
|
||||
model.eval()
|
||||
|
||||
output = model(**input_ids).logits
|
||||
@@ -450,8 +450,8 @@ class PhiIntegrationTest(unittest.TestCase):
|
||||
self.assertTrue(torch.allclose(EXPECTED_OUTPUT, output[0, :2, :30], atol=1e-3, rtol=1e-3))
|
||||
|
||||
def test_phi_2_generation(self):
|
||||
model = PhiForCausalLM.from_pretrained("susnato/phi-2")
|
||||
tokenizer = AutoTokenizer.from_pretrained("susnato/phi-2")
|
||||
model = PhiForCausalLM.from_pretrained("microsoft/phi-2")
|
||||
tokenizer = AutoTokenizer.from_pretrained("microsoft/phi-2")
|
||||
|
||||
inputs = tokenizer(
|
||||
"Can you help me write a formal email to a potential business partner proposing a joint venture?",
|
||||
|
||||
Reference in New Issue
Block a user