Remove size check between attn_weights and kv_seq_len for phi3 (#32339)
* Remove size check between attn_weights and kv_seq_len * add unit tests
This commit is contained in:
@@ -19,7 +19,7 @@ import unittest
|
||||
|
||||
from parameterized import parameterized
|
||||
|
||||
from transformers import Phi3Config, is_torch_available, set_seed
|
||||
from transformers import Phi3Config, StaticCache, is_torch_available, set_seed
|
||||
from transformers.testing_utils import (
|
||||
require_torch,
|
||||
slow,
|
||||
@@ -43,6 +43,55 @@ if is_torch_available():
|
||||
Phi3Model,
|
||||
)
|
||||
|
||||
# Token id that stops the decoding loop in Phi3MiniWithStaticCache.generate once it is produced.
end_of_text_token = 32000
|
||||
|
||||
class Phi3MiniWithStaticCache(torch.nn.Module):
    """Wrap a `Phi3ForCausalLM` so that every forward pass reuses one `StaticCache`.

    The cache is created once in `__init__` from the model's config, device and
    dtype, and is handed to the model as `past_key_values` on each call.
    """

    def __init__(self, model: Phi3ForCausalLM, max_batch_size: int, max_seq_len: int):
        """Store `model` and build the `StaticCache` it will be run with.

        Args:
            model: The causal LM to wrap.
            max_batch_size: Passed to `StaticCache` as `max_batch_size`.
            max_seq_len: Passed to `StaticCache` as `max_cache_len`.
        """
        super().__init__()
        self.model = model
        self.cache = StaticCache(
            config=model.config,
            max_batch_size=max_batch_size,
            max_cache_len=max_seq_len,
            device=self.model.device,
            dtype=self.model.dtype,
        )

    def forward(
        self,
        input_ids: torch.LongTensor = None,
    ) -> torch.FloatTensor:
        """Run the wrapped model on `input_ids` with the shared cache and return its logits."""
        return self.model.forward(
            input_ids=input_ids,
            use_cache=True,
            return_dict=True,
            past_key_values=self.cache,
        ).logits

    @staticmethod
    def generate(model: Phi3ForCausalLM, prompt_tokens: torch.LongTensor, max_seq_len: int) -> list[int]:
        """Greedily decode from `model`, feeding one token per forward pass.

        Args:
            model: The causal LM to decode with; it is wrapped in a
                `Phi3MiniWithStaticCache` with batch size 1.
            prompt_tokens: Prompt token ids indexed as `[batch, position]`; only
                row 0 is copied into the returned list.
            max_seq_len: Upper bound on the length of the returned list, which
                includes the prompt tokens. One token is always generated after
                the prompt, even if the prompt alone already reaches this bound.

        Returns:
            The prompt token ids of row 0 followed by the generated token ids.

        Raises:
            ValueError: If `prompt_tokens` holds no tokens, since there would be
                no logits to pick the first generated token from.
        """
        prompt_len = prompt_tokens.shape[-1]
        if prompt_len == 0:
            raise ValueError("prompt_tokens must contain at least one token")

        # The cache is sized for the prompt plus `max_seq_len` further positions.
        cached_model = Phi3MiniWithStaticCache(model, 1, max_seq_len + prompt_len)

        response_tokens = []

        # Prefill: push the prompt through the model one position at a time.
        for input_pos in range(prompt_len):
            result = cached_model.forward(
                input_ids=prompt_tokens[:, input_pos : input_pos + 1],
            )
            response_tokens.append(prompt_tokens[0][input_pos].item())

        current_token = torch.argmax(result[:, -1, :], dim=-1).item()
        response_tokens.append(current_token)

        while current_token != end_of_text_token and len(response_tokens) < max_seq_len:
            result = cached_model.forward(
                # Create the next input on the same device as the prompt slices fed
                # during prefill, instead of the default device.
                input_ids=torch.tensor([[current_token]], dtype=torch.long, device=prompt_tokens.device),
            )
            current_token = torch.argmax(result[:, -1, :], dim=-1).item()
            response_tokens.append(current_token)

        return response_tokens
|
||||
|
||||
|
||||
class Phi3ModelTester:
|
||||
def __init__(
|
||||
@@ -429,7 +478,30 @@ class Phi3IntegrationTest(unittest.TestCase):
|
||||
output_text = tokenizer.batch_decode(outputs)
|
||||
|
||||
EXPECTED_OUTPUT = [
|
||||
"<s><|system|> You are a helpful digital assistant. Please provide safe, ethical and accurate information to the user.<|end|><|user|> Can you provide ways to eat combinations of bananas and dragonfruits?<|end|><|assistant|> Absolutely! Bananas and dragonfruits are both delicious fruits that can be combined in various ways to create tasty and nutrit"
|
||||
"<|system|> You are a helpful digital assistant. Please provide safe, ethical and accurate information to the user.<|end|><|user|> Can you provide ways to eat combinations of bananas and dragonfruits?<|end|><|assistant|> Certainly! Bananas and dragonfruits can be combined in various delicious ways. Here are some ideas for incorporating these fruits into your"
|
||||
]
|
||||
|
||||
self.assertListEqual(output_text, EXPECTED_OUTPUT)
|
||||
|
||||
def test_phi3_mini_4k_instruct_with_static_cache(self):
    """Decode with the 4k-instruct checkpoint through `Phi3MiniWithStaticCache` and compare the text.

    The chat prompt is rendered with the tokenizer's chat template, decoded
    token by token with `Phi3MiniWithStaticCache.generate` (limit 64), and the
    decoded string is checked against a fixed expected output.
    """
    checkpoint = "microsoft/phi-3-mini-4k-instruct"
    model = Phi3ForCausalLM.from_pretrained(checkpoint)
    tokenizer = AutoTokenizer.from_pretrained(checkpoint)

    system_turn = {
        "role": "system",
        "content": "You are a helpful digital assistant. Please provide safe, ethical and accurate information to the user.",
    }
    user_turn = {
        "role": "user",
        "content": "Can you provide ways to eat combinations of bananas and dragonfruits?",
    }
    prompt_ids = tokenizer.apply_chat_template(
        [system_turn, user_turn], add_generation_prompt=True, return_tensors="pt"
    )

    generated_ids = Phi3MiniWithStaticCache.generate(model, prompt_ids, 64)

    # batch_decode expects a batch, so wrap the single sequence in an outer list.
    generated_batch = torch.tensor([generated_ids], dtype=torch.long, device=torch_device)
    output_text = tokenizer.batch_decode(generated_batch)

    self.assertListEqual(
        output_text,
        [
            "<|system|> You are a helpful digital assistant. Please provide safe, ethical and accurate information to the user.<|end|><|user|> Can you provide ways to eat combinations of bananas and dragonfruits?<|end|><|assistant|> Certainly! Bananas and dragonfruits can be combined in various delicious ways. Here are some"
        ],
    )
|
||||
@@ -467,7 +539,30 @@ class Phi3IntegrationTest(unittest.TestCase):
|
||||
output_text = tokenizer.batch_decode(outputs)
|
||||
|
||||
EXPECTED_OUTPUT = [
|
||||
"<s><|system|> You are a helpful digital assistant. Please provide safe, ethical and accurate information to the user.<|end|><|user|> Can you provide ways to eat combinations of bananas and dragonfruits?<|end|><|assistant|> Certainly! Bananas and dragonfruits can be combined in various delicious and healthy ways. Here are some ideas:\n\n1."
|
||||
"<|system|> You are a helpful digital assistant. Please provide safe, ethical and accurate information to the user.<|end|><|user|> Can you provide ways to eat combinations of bananas and dragonfruits?<|end|><|assistant|> Certainly! Bananas and dragonfruits can be combined in various delicious and nutritious ways. Here are some creative and healthy"
|
||||
]
|
||||
|
||||
self.assertListEqual(output_text, EXPECTED_OUTPUT)
|
||||
|
||||
def test_phi3_mini_128k_instruct_with_static_cache(self):
    """Decode with the 128k-instruct checkpoint through `Phi3MiniWithStaticCache` and compare the text.

    The chat prompt is rendered with the tokenizer's chat template, decoded
    token by token with `Phi3MiniWithStaticCache.generate` (limit 64), and the
    decoded string is checked against a fixed expected output.
    """
    checkpoint = "microsoft/phi-3-mini-128k-instruct"
    model = Phi3ForCausalLM.from_pretrained(checkpoint)
    tokenizer = AutoTokenizer.from_pretrained(checkpoint)

    system_turn = {
        "role": "system",
        "content": "You are a helpful digital assistant. Please provide safe, ethical and accurate information to the user.",
    }
    user_turn = {
        "role": "user",
        "content": "Can you provide ways to eat combinations of bananas and dragonfruits?",
    }
    prompt_ids = tokenizer.apply_chat_template(
        [system_turn, user_turn], add_generation_prompt=True, return_tensors="pt"
    )

    generated_ids = Phi3MiniWithStaticCache.generate(model, prompt_ids, 64)

    # batch_decode expects a batch, so wrap the single sequence in an outer list.
    generated_batch = torch.tensor([generated_ids], dtype=torch.long, device=torch_device)
    output_text = tokenizer.batch_decode(generated_batch)

    self.assertListEqual(
        output_text,
        [
            "<|system|> You are a helpful digital assistant. Please provide safe, ethical and accurate information to the user.<|end|><|user|> Can you provide ways to eat combinations of bananas and dragonfruits?<|end|><|assistant|> Certainly! Bananas and dragonfruits can be combined in various delicious and nutritious ways"
        ],
    )
|
||||
|
||||
Reference in New Issue
Block a user