Remove size check between attn_weights and kv_seq_len for phi3 (#32339)
* Remove size check between attn_weights and kv_seq_len * add unit tests
This commit is contained in:
@@ -453,12 +453,6 @@ class Phi3Attention(nn.Module):
|
|||||||
|
|
||||||
attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
|
attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
|
||||||
|
|
||||||
if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
|
|
||||||
raise ValueError(
|
|
||||||
f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is"
|
|
||||||
f" {attn_weights.size()}"
|
|
||||||
)
|
|
||||||
|
|
||||||
if attention_mask is not None:
|
if attention_mask is not None:
|
||||||
causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
|
causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
|
||||||
attn_weights += causal_mask
|
attn_weights += causal_mask
|
||||||
|
|||||||
@@ -19,7 +19,7 @@ import unittest
|
|||||||
|
|
||||||
from parameterized import parameterized
|
from parameterized import parameterized
|
||||||
|
|
||||||
from transformers import Phi3Config, is_torch_available, set_seed
|
from transformers import Phi3Config, StaticCache, is_torch_available, set_seed
|
||||||
from transformers.testing_utils import (
|
from transformers.testing_utils import (
|
||||||
require_torch,
|
require_torch,
|
||||||
slow,
|
slow,
|
||||||
@@ -43,6 +43,55 @@ if is_torch_available():
|
|||||||
Phi3Model,
|
Phi3Model,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
end_of_text_token = 32000
|
||||||
|
|
||||||
|
class Phi3MiniWithStaticCache(torch.nn.Module):
|
||||||
|
def __init__(self, model: Phi3ForCausalLM, max_batch_size: int, max_seq_len: int):
|
||||||
|
super().__init__()
|
||||||
|
self.model = model
|
||||||
|
self.cache = StaticCache(
|
||||||
|
config=model.config,
|
||||||
|
max_batch_size=max_batch_size,
|
||||||
|
max_cache_len=max_seq_len,
|
||||||
|
device=self.model.device,
|
||||||
|
dtype=self.model.dtype,
|
||||||
|
)
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
input_ids: torch.LongTensor = None,
|
||||||
|
) -> torch.FloatTensor:
|
||||||
|
return self.model.forward(
|
||||||
|
input_ids=input_ids,
|
||||||
|
use_cache=True,
|
||||||
|
return_dict=True,
|
||||||
|
past_key_values=self.cache,
|
||||||
|
).logits
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def generate(model: Phi3ForCausalLM, prompt_tokens: torch.LongTensor, max_seq_len: int) -> list[int]:
|
||||||
|
model = Phi3MiniWithStaticCache(model, 1, max_seq_len + prompt_tokens.shape[-1])
|
||||||
|
|
||||||
|
response_tokens = []
|
||||||
|
|
||||||
|
for input_pos in range(prompt_tokens.shape[-1]):
|
||||||
|
result = model.forward(
|
||||||
|
input_ids=prompt_tokens[:, input_pos : input_pos + 1],
|
||||||
|
)
|
||||||
|
response_tokens.append(prompt_tokens[0][input_pos].item())
|
||||||
|
|
||||||
|
current_token = torch.argmax(result[:, -1, :], dim=-1).item()
|
||||||
|
response_tokens.append(current_token)
|
||||||
|
|
||||||
|
while current_token != end_of_text_token and len(response_tokens) < max_seq_len:
|
||||||
|
result = model.forward(
|
||||||
|
input_ids=torch.tensor([[current_token]], dtype=torch.long),
|
||||||
|
)
|
||||||
|
current_token = torch.argmax(result[:, -1, :], dim=-1).item()
|
||||||
|
response_tokens.append(current_token)
|
||||||
|
|
||||||
|
return response_tokens
|
||||||
|
|
||||||
|
|
||||||
class Phi3ModelTester:
|
class Phi3ModelTester:
|
||||||
def __init__(
|
def __init__(
|
||||||
@@ -429,7 +478,30 @@ class Phi3IntegrationTest(unittest.TestCase):
|
|||||||
output_text = tokenizer.batch_decode(outputs)
|
output_text = tokenizer.batch_decode(outputs)
|
||||||
|
|
||||||
EXPECTED_OUTPUT = [
|
EXPECTED_OUTPUT = [
|
||||||
"<s><|system|> You are a helpful digital assistant. Please provide safe, ethical and accurate information to the user.<|end|><|user|> Can you provide ways to eat combinations of bananas and dragonfruits?<|end|><|assistant|> Absolutely! Bananas and dragonfruits are both delicious fruits that can be combined in various ways to create tasty and nutrit"
|
"<|system|> You are a helpful digital assistant. Please provide safe, ethical and accurate information to the user.<|end|><|user|> Can you provide ways to eat combinations of bananas and dragonfruits?<|end|><|assistant|> Certainly! Bananas and dragonfruits can be combined in various delicious ways. Here are some ideas for incorporating these fruits into your"
|
||||||
|
]
|
||||||
|
|
||||||
|
self.assertListEqual(output_text, EXPECTED_OUTPUT)
|
||||||
|
|
||||||
|
def test_phi3_mini_4k_instruct_with_static_cache(self):
|
||||||
|
model = Phi3ForCausalLM.from_pretrained("microsoft/phi-3-mini-4k-instruct")
|
||||||
|
tokenizer = AutoTokenizer.from_pretrained("microsoft/phi-3-mini-4k-instruct")
|
||||||
|
|
||||||
|
messages = [
|
||||||
|
{
|
||||||
|
"role": "system",
|
||||||
|
"content": "You are a helpful digital assistant. Please provide safe, ethical and accurate information to the user.",
|
||||||
|
},
|
||||||
|
{"role": "user", "content": "Can you provide ways to eat combinations of bananas and dragonfruits?"},
|
||||||
|
]
|
||||||
|
inputs = tokenizer.apply_chat_template(messages, add_generation_prompt=True, return_tensors="pt")
|
||||||
|
|
||||||
|
response_tokens = Phi3MiniWithStaticCache.generate(model, inputs, 64)
|
||||||
|
|
||||||
|
output_text = tokenizer.batch_decode(torch.tensor([response_tokens], dtype=torch.long, device=torch_device))
|
||||||
|
|
||||||
|
EXPECTED_OUTPUT = [
|
||||||
|
"<|system|> You are a helpful digital assistant. Please provide safe, ethical and accurate information to the user.<|end|><|user|> Can you provide ways to eat combinations of bananas and dragonfruits?<|end|><|assistant|> Certainly! Bananas and dragonfruits can be combined in various delicious ways. Here are some"
|
||||||
]
|
]
|
||||||
|
|
||||||
self.assertListEqual(output_text, EXPECTED_OUTPUT)
|
self.assertListEqual(output_text, EXPECTED_OUTPUT)
|
||||||
@@ -467,7 +539,30 @@ class Phi3IntegrationTest(unittest.TestCase):
|
|||||||
output_text = tokenizer.batch_decode(outputs)
|
output_text = tokenizer.batch_decode(outputs)
|
||||||
|
|
||||||
EXPECTED_OUTPUT = [
|
EXPECTED_OUTPUT = [
|
||||||
"<s><|system|> You are a helpful digital assistant. Please provide safe, ethical and accurate information to the user.<|end|><|user|> Can you provide ways to eat combinations of bananas and dragonfruits?<|end|><|assistant|> Certainly! Bananas and dragonfruits can be combined in various delicious and healthy ways. Here are some ideas:\n\n1."
|
"<|system|> You are a helpful digital assistant. Please provide safe, ethical and accurate information to the user.<|end|><|user|> Can you provide ways to eat combinations of bananas and dragonfruits?<|end|><|assistant|> Certainly! Bananas and dragonfruits can be combined in various delicious and nutritious ways. Here are some creative and healthy"
|
||||||
|
]
|
||||||
|
|
||||||
|
self.assertListEqual(output_text, EXPECTED_OUTPUT)
|
||||||
|
|
||||||
|
def test_phi3_mini_128k_instruct_with_static_cache(self):
|
||||||
|
model = Phi3ForCausalLM.from_pretrained("microsoft/phi-3-mini-128k-instruct")
|
||||||
|
tokenizer = AutoTokenizer.from_pretrained("microsoft/phi-3-mini-128k-instruct")
|
||||||
|
|
||||||
|
messages = [
|
||||||
|
{
|
||||||
|
"role": "system",
|
||||||
|
"content": "You are a helpful digital assistant. Please provide safe, ethical and accurate information to the user.",
|
||||||
|
},
|
||||||
|
{"role": "user", "content": "Can you provide ways to eat combinations of bananas and dragonfruits?"},
|
||||||
|
]
|
||||||
|
inputs = tokenizer.apply_chat_template(messages, add_generation_prompt=True, return_tensors="pt")
|
||||||
|
|
||||||
|
response_tokens = Phi3MiniWithStaticCache.generate(model, inputs, 64)
|
||||||
|
|
||||||
|
output_text = tokenizer.batch_decode(torch.tensor([response_tokens], dtype=torch.long, device=torch_device))
|
||||||
|
|
||||||
|
EXPECTED_OUTPUT = [
|
||||||
|
"<|system|> You are a helpful digital assistant. Please provide safe, ethical and accurate information to the user.<|end|><|user|> Can you provide ways to eat combinations of bananas and dragonfruits?<|end|><|assistant|> Certainly! Bananas and dragonfruits can be combined in various delicious and nutritious ways"
|
||||||
]
|
]
|
||||||
|
|
||||||
self.assertListEqual(output_text, EXPECTED_OUTPUT)
|
self.assertListEqual(output_text, EXPECTED_OUTPUT)
|
||||||
|
|||||||
Reference in New Issue
Block a user