Update expected values (after switching to A10) - part 7 (#39218)
* fix * fix * fix --------- Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
This commit is contained in:
@@ -165,7 +165,8 @@ class Cohere2IntegrationTest(unittest.TestCase):
|
||||
EXPECTED_TEXTS = Expectations(
|
||||
{
|
||||
("xpu", 3): ["<BOS_TOKEN>Hello I am doing a project for my school and I need to create a website for a fictional company. I have the", "<PAD><PAD><BOS_TOKEN>Hi today I'm going to show you how to make a simple and easy to make a chocolate cake.\n"],
|
||||
("cuda", 7): ["<BOS_TOKEN>Hello I am doing a project for a school assignment and I need to create a website for a fictional company. I have", "<PAD><PAD><BOS_TOKEN>Hi today I'm going to show you how to make a simple and easy to make a chocolate cake.\n",],
|
||||
(None, None): ["<BOS_TOKEN>Hello I am doing a project for a school assignment and I need to create a website for a fictional company. I have", "<PAD><PAD><BOS_TOKEN>Hi today I'm going to show you how to make a simple and easy to make a chocolate cake.\n"],
|
||||
("cuda", 8): ['<BOS_TOKEN>Hello I am doing a project for my school and I need to create a website for a fictional company. I have the', "<PAD><PAD><BOS_TOKEN>Hi today I'm going to show you how to make a simple and easy to make a chocolate cake.\n"],
|
||||
}
|
||||
)
|
||||
EXPECTED_TEXT = EXPECTED_TEXTS.get_expectation()
|
||||
@@ -238,7 +239,8 @@ class Cohere2IntegrationTest(unittest.TestCase):
|
||||
EXPECTED_TEXT_COMPLETIONS = Expectations(
|
||||
{
|
||||
("xpu", 3): ["Hello I am doing a project for a friend and I am stuck on a few things. I have a 2004 Ford F-"],
|
||||
("cuda", 7): ["Hello I am doing a project on the effects of social media on mental health. I have a few questions. 1. What is the relationship",],
|
||||
(None, None): ["Hello I am doing a project on the effects of social media on mental health. I have a few questions. 1. What is the relationship"],
|
||||
("cuda", 8): ['Hello I am doing a project for a friend and I am stuck on a few things. I have a 2004 Ford F-'],
|
||||
}
|
||||
)
|
||||
EXPECTED_TEXT_COMPLETION = EXPECTED_TEXT_COMPLETIONS.get_expectation()
|
||||
@@ -290,24 +292,31 @@ class Cohere2IntegrationTest(unittest.TestCase):
|
||||
if attn_implementation == "flash_attention_2" and not is_flash_attn_2_available():
|
||||
self.skipTest("FlashAttention2 is required for this test.")
|
||||
|
||||
# TODO: check whether we can specify not to compile when `flex` attention is used.
|
||||
if attn_implementation == "flex_attention":
|
||||
self.skipTest(
|
||||
"Flex attention will compile (see `compile_friendly_flex_attention`) which causes triton issue."
|
||||
)
|
||||
|
||||
if torch_device == "xpu" and attn_implementation == "flash_attention_2":
|
||||
self.skipTest(reason="Intel XPU doesn't support falsh_attention_2 as of now.")
|
||||
|
||||
model_id = "CohereForAI/c4ai-command-r7b-12-2024"
|
||||
EXPECTED_COMPLETIONS = [
|
||||
" the mountains, the lakes, the rivers, the waterfalls, the waterfalls, the waterfalls, the waterfalls",
|
||||
" the mountains, the lakes, the rivers, the forests, the trees, the birds, the animals",
|
||||
", green, yellow, orange, purple, pink, brown, black, white, grey, silver",
|
||||
]
|
||||
|
||||
input_text = [
|
||||
"This is a nice place. " * 800 + "I really enjoy the scenery,", # This is larger than 4096 tokens
|
||||
"This is a nice place. " * 200 + "I really enjoy the scenery,", # This is larger than 1024 tokens
|
||||
"A list of colors: red, blue", # This will almost all be padding tokens
|
||||
]
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_id, padding="left")
|
||||
inputs = tokenizer(input_text, padding=True, return_tensors="pt").to(torch_device)
|
||||
|
||||
# We use `sliding_window=1024` instead of the original value `4096` in the config to avoid GPU OOM
|
||||
model = AutoModelForCausalLM.from_pretrained(
|
||||
model_id, attn_implementation=attn_implementation, torch_dtype=torch.float16
|
||||
model_id, attn_implementation=attn_implementation, torch_dtype=torch.float16, sliding_window=1024
|
||||
).to(torch_device)
|
||||
|
||||
# Make sure prefill is larger than sliding window
|
||||
|
||||
Reference in New Issue
Block a user