Llama: fix batched generation (#29109)
This commit is contained in:
@@ -293,7 +293,7 @@ class CacheIntegrationTest(unittest.TestCase):
|
||||
@parameterized.expand(["eager", "sdpa", "flash_attention_2"])
|
||||
def test_static_cache_greedy_sampling_pad_left(self, attn_implementation):
|
||||
EXPECTED_GENERATION = [
|
||||
"The best color is the one that complements the subject you are photograph",
|
||||
"The best color is the one that complements the skin tone of the",
|
||||
"We should not undermind the issues at hand.\nWe should not undermind the issues",
|
||||
]
|
||||
|
||||
@@ -333,18 +333,18 @@ class CacheIntegrationTest(unittest.TestCase):
|
||||
@parameterized.expand(["eager", "sdpa", "flash_attention_2"])
|
||||
def test_static_cache_greedy_sampling_pad_right(self, attn_implementation):
|
||||
EXPECTED_GENERATION = [
|
||||
"The best color is\n\n\n\n\n\n\n\n\n\n",
|
||||
"We should not undermind the issues at hand, but address them head on.\nI think",
|
||||
"The best color isЋ the one that complements the skin tone of",
|
||||
"We should not undermind the issues at hand.\nWe should not undermind the issues",
|
||||
]
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained(
|
||||
"NousResearch/Llama-2-7b-chat-hf", padding_side="left", pad_token="<s>"
|
||||
"NousResearch/Llama-2-7b-chat-hf", padding_side="right", pad_token="<s>"
|
||||
)
|
||||
model = AutoModelForCausalLM.from_pretrained(
|
||||
"NousResearch/Llama-2-7b-chat-hf",
|
||||
torch_dtype=torch.bfloat16,
|
||||
attn_implementation=attn_implementation,
|
||||
).to("cuda:1")
|
||||
).to(torch_device)
|
||||
inputs = tokenizer(
|
||||
["The best color is", "We should not undermind the issues at hand"], padding=True, return_tensors="pt"
|
||||
).to(model.device)
|
||||
|
||||
Reference in New Issue
Block a user