[Llama2] Add support for Llama 2 (#24891)
* add llama * add other readmes * update padding id in readme * add link to paper * fix paths and tokenizer * more nits * styling * fit operation in 2 lines when possible * nits * Apply suggestions from code review Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> * add form * update reademe * update readme, we don't have a default pad token * update test and tokenization * LLaMA instead of Llama * nits * add expected text * add greeedy output * styling * Update src/transformers/models/llama/modeling_llama.py Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> * sequential device map * skip relevant changes --------- Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
This commit is contained in:
@@ -20,7 +20,7 @@ import unittest
|
||||
from parameterized import parameterized
|
||||
|
||||
from transformers import LlamaConfig, is_torch_available, set_seed
|
||||
from transformers.testing_utils import require_torch, torch_device
|
||||
from transformers.testing_utils import require_torch, slow, torch_device
|
||||
|
||||
from ...generation.test_utils import GenerationTesterMixin
|
||||
from ...test_configuration_common import ConfigTester
|
||||
@@ -31,7 +31,7 @@ from ...test_pipeline_mixin import PipelineTesterMixin
|
||||
if is_torch_available():
|
||||
import torch
|
||||
|
||||
from transformers import LlamaForCausalLM, LlamaForSequenceClassification, LlamaModel
|
||||
from transformers import LlamaForCausalLM, LlamaForSequenceClassification, LlamaModel, LlamaTokenizer
|
||||
|
||||
|
||||
class LlamaModelTester:
|
||||
@@ -365,3 +365,85 @@ class LlamaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixi
|
||||
|
||||
# The output should be different for long inputs
|
||||
self.assertFalse(torch.allclose(original_long_output, scaled_long_output, atol=1e-5))
|
||||
|
||||
|
||||
@require_torch
|
||||
class LlamaIntegrationTest(unittest.TestCase):
|
||||
@unittest.skip("Logits are not exactly the same, once we fix the instabalities somehow, will update!")
|
||||
@slow
|
||||
def test_model_7b_logits(self):
|
||||
input_ids = [1, 306, 4658, 278, 6593, 310, 2834, 338]
|
||||
model = LlamaForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf", device_map="auto")
|
||||
out = model(torch.tensor([input_ids]))
|
||||
# Expected mean on dim = -1
|
||||
EXPECTED_MEAN = torch.tensor([[-6.6550, -4.1227, -4.9859, -3.2406, 0.8262, -3.0033, 1.2964, -3.3699]])
|
||||
torch.testing.assert_close(out.mean(-1), EXPECTED_MEAN, atol=1e-2, rtol=1e-2)
|
||||
# slicing logits[0, 0, 0:30]
|
||||
# fmt: off
|
||||
EXPECTED_SLICE = torch.tensor([-12.8281, -7.4453, -0.4639, -8.0625, -7.2500, -8.0000, -6.4883, -7.7695, -7.8438, -7.0312, -6.2188, -7.1328, -1.8496, 1.9961, -8.6250, -6.7227, -12.8281, -6.9492, -7.0742, -7.7852, -7.5820, -7.9062, -6.9375, -7.9805, -8.3438, -8.1562, -8.0469, -7.6250, -7.7422, -7.3398,])
|
||||
# fmt: on
|
||||
torch.testing.assert_close(out[0, 0, :30], EXPECTED_SLICE, atol=1e-5, rtol=1e-5)
|
||||
|
||||
@unittest.skip("Logits are not exactly the same, once we fix the instabalities somehow, will update!")
|
||||
@slow
|
||||
def test_model_13b_logits(self):
|
||||
input_ids = [1, 306, 4658, 278, 6593, 310, 2834, 338]
|
||||
model = LlamaForCausalLM.from_pretrained("meta-llama/Llama-2-13b-hf", device_map="auto")
|
||||
out = model(torch.tensor(input_ids))
|
||||
# Expected mean on dim = -1
|
||||
EXPECTED_MEAN = torch.tensor([[-2.0622, -1.2794, -1.1638, -0.9788, -1.4603, -1.0238, -1.7893, -1.4411]])
|
||||
torch.testing.assert_close(out.mean(-1), EXPECTED_MEAN, atol=1e-2, rtol=1e-2)
|
||||
# slicing logits[0, 0, 0:30]
|
||||
# fmt: off
|
||||
EXPECTED_SLICE = torch.tensor([-8.1406, -8.0547, 2.7461, -1.2344, -0.1448, -1.8262, -1.0020, -1.8154, -1.6895, -1.8516, -2.3574, -0.9277, 3.7598, 6.5742, -1.2998, -0.1177, -8.1406, -2.9688, -2.9199, -3.1699, -3.5254, -2.3555, -2.7988, -3.4141, -2.8262, -4.5195, -3.3379, -3.3164, -2.7832, -3.0273])
|
||||
# fmt: on
|
||||
torch.testing.assert_close(out[0, 0, :30], EXPECTED_SLICE, atol=1e-5, rtol=1e-5)
|
||||
|
||||
@unittest.skip("Logits are not exactly the same, once we fix the instabalities somehow, will update!")
|
||||
@slow
|
||||
def test_model_13bf_logits(self):
|
||||
input_ids = [1, 306, 4658, 278, 6593, 310, 2834, 338]
|
||||
model = LlamaForCausalLM.from_pretrained("meta-llama/Llama-2-13b-chat-hf", device_map="auto")
|
||||
out = model(torch.tensor(input_ids))
|
||||
# Expected mean on dim = -1
|
||||
EXPECTED_MEAN = torch.tensor([[-0.8562, -1.8520, -0.7551, -0.4162, -1.5161, -1.2038, -2.4823, -2.3254]])
|
||||
torch.testing.assert_close(out.mean(-1), EXPECTED_MEAN, atol=1e-2, rtol=1e-2)
|
||||
# slicing logits[0, 0, 0:30]
|
||||
# fmt: off
|
||||
EXPECTED_SLICE = torch.tensor([-2.2227, 4.8828, 0.9023, -0.4578, -0.7871, -0.1033, -0.6221, -0.5786, -0.7803, -1.0674, -1.2920, -0.1570, 0.8008, 2.0723, -0.9497, 0.2771, -2.2227, -0.7612, -1.4346, -1.2061, -1.6426, -0.3000, -0.7139, -1.1934, -1.8691, -1.6973, -1.5947, -1.2705, -0.3523, -0.5513])
|
||||
# fmt: on
|
||||
torch.testing.assert_close(out.mean(-1), EXPECTED_SLICE, atol=1e-2, rtol=1e-2)
|
||||
|
||||
@unittest.skip(
|
||||
"Logits are not exactly the same, once we fix the instabalities somehow, will update! Also it is gonna be a `too_slow` test"
|
||||
)
|
||||
@slow
|
||||
def test_model_70b_logits(self):
|
||||
input_ids = [1, 306, 4658, 278, 6593, 310, 2834, 338]
|
||||
model = LlamaForCausalLM.from_pretrained("meta-llama/Llama-2-70b-hf", device_map="auto")
|
||||
out = model(torch.tensor(input_ids))
|
||||
|
||||
EXPECTED_MEAN = torch.tensor(
|
||||
[[-4.2327, -3.3360, -4.6665, -4.7631, -1.8180, -3.4170, -1.4211, -3.1810]], dtype=torch.float32
|
||||
)
|
||||
torch.testing.assert_close(out.mean(-1), EXPECTED_MEAN, atol=1e-2, rtol=1e-2)
|
||||
# fmt: off
|
||||
EXPECTED_SLICE = torch.tensor([-9.4922, -3.9551, 1.7998, -5.6758, -5.1055, -5.8984, -4.8320, -6.8086, -6.5391, -5.6172, -5.5820, -5.5352, 1.7881, 3.6289, -6.5117, -3.4785, -9.5000, -6.0352, -6.8125, -6.0195, -6.6836, -5.4727, -6.2812, -6.0391, -7.3398, -7.4297, -7.4844, -6.5820, -5.8789, -5.5312])
|
||||
# fmt: on
|
||||
torch.testing.assert_close(out[0, 0, :30], EXPECTED_SLICE, atol=1e-5, rtol=1e-5)
|
||||
|
||||
@unittest.skip("Model is curently gated")
|
||||
@slow
|
||||
def test_model_13b_greedy_generation(self):
|
||||
EXPECTED_TEXT_COMPLETION = """Simply put, the theory of relativity states that 1) the laws of physics are the same everywhere in the universe and 2) the passage of time and the length of objects can vary depending on the observer\'s frame of reference.\n\nThe first part of the theory, that the laws of physics are the same everywhere, is known as the "princi"""
|
||||
prompt = "Simply put, the theory of relativity states that "
|
||||
tokenizer = LlamaTokenizer.from_pretrained("meta-llama/Llama-2-13b-chat-hf")
|
||||
input_ids = tokenizer.encode(prompt, return_tensors="pt")
|
||||
model = LlamaForCausalLM.from_pretrained(
|
||||
"meta-llama/Llama-2-13b-chat-hf", device_map="sequential", use_safetensors=False
|
||||
)
|
||||
|
||||
# greedy generation outputs
|
||||
generated_ids = model.generate(input_ids, max_new_tokens=64, top_p=None, temperature=1, do_sample=False)
|
||||
text = tokenizer.decode(generated_ids[0], skip_special_tokens=True)
|
||||
self.assertEqual(EXPECTED_TEXT_COMPLETION, text)
|
||||
|
||||
@@ -216,6 +216,41 @@ class ConversationalPipelineTests(unittest.TestCase):
|
||||
inputs["input_ids"].tolist(), [[31373, 50256, 17250, 612, 0, 50256, 4919, 389, 345, 5633, 50256]]
|
||||
)
|
||||
|
||||
@unittest.skip("Model is curently gated")
|
||||
@require_torch
|
||||
@slow
|
||||
def test_integration_torch_conversation_llama2_input_ids(self):
|
||||
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-chat-hf")
|
||||
|
||||
conversation = Conversation(
|
||||
"What is so great about #1?",
|
||||
past_user_inputs=["I am going to Paris, what should I see?"],
|
||||
generated_responses=[
|
||||
"""\
|
||||
Paris, the capital of France, is known for its stunning architecture, art museums, historical landmarks, and romantic atmosphere. Here are some of the top attractions to see in Paris:
|
||||
|
||||
1. The Eiffel Tower: The iconic Eiffel Tower is one of the most recognizable landmarks in the world and offers breathtaking views of the city.
|
||||
2. The Louvre Museum: The Louvre is one of the world's largest and most famous museums, housing an impressive collection of art and artifacts, including the Mona Lisa.
|
||||
3. Notre-Dame Cathedral: This beautiful cathedral is one of the most famous landmarks in Paris and is known for its Gothic architecture and stunning stained glass windows.
|
||||
|
||||
These are just a few of the many attractions that Paris has to offer. With so much to see and do, it's no wonder that Paris is one of the most popular tourist destinations in the world."""
|
||||
],
|
||||
)
|
||||
inputs = tokenizer._build_conversation_input_ids(conversation)
|
||||
# fmt: off
|
||||
EXPECTED_INPUTS_IDS = [ 1, 518, 25580, 29962, 3532, 14816, 29903, 6778, 13, 3492, 526, 263, 8444, 29892, 3390, 1319, 322, 15993, 20255, 29889, 29849, 1234, 408, 1371, 3730, 408, 1950, 29892, 1550, 1641, 9109, 29889, 29871, 3575, 6089, 881, 451, 3160, 738, 10311, 1319, 29892, 443, 621, 936, 29892, 11021, 391, 29892, 7916, 391, 29892, 304, 27375, 29892, 18215, 29892, 470, 27302, 2793, 29889, 3529, 9801, 393, 596, 20890, 526, 5374, 635, 443, 5365, 1463, 322, 6374, 297, 5469, 29889, 13, 13, 3644, 263, 1139, 947, 451, 1207, 738, 4060, 29892, 470, 338, 451, 2114, 1474, 16165, 261, 296, 29892, 5649, 2020, 2012, 310, 22862, 1554, 451, 1959, 29889, 960, 366, 1016, 29915, 29873, 1073, 278, 1234, 304, 263, 1139, 29892, 3113, 1016, 29915, 29873, 6232, 2089, 2472, 29889, 13, 29966, 829, 14816, 29903, 6778, 13, 13, 29902, 626, 2675, 304, 3681, 29892, 825, 881, 306, 1074, 29973, 518, 29914, 25580, 29962, 3681, 29892, 278, 7483, 310, 3444, 29892, 338, 2998, 363, 967, 380, 27389, 11258, 29892, 1616, 19133, 29879, 29892, 15839, 2982, 22848, 29892, 322, 6017, 7716, 25005, 29889, 2266, 526, 777, 310, 278, 2246, 19650, 1953, 304, 1074, 297, 3681, 29901, 13, 13, 29896, 29889, 450, 382, 2593, 295, 23615, 29901, 450, 9849, 293, 382, 2593, 295, 23615, 338, 697, 310, 278, 1556, 5936, 13902, 2982, 22848, 297, 278, 3186, 322, 16688, 2078, 271, 400, 5086, 8386, 310, 278, 4272, 29889, 13, 29906, 29889, 450, 4562, 12675, 6838, 29901, 450, 4562, 12675, 338, 697, 310, 278, 3186, 29915, 29879, 10150, 322, 1556, 13834, 19133, 29879, 29892, 27261, 385, 21210, 573, 4333, 310, 1616, 322, 24238, 29879, 29892, 3704, 278, 2598, 29874, 29420, 29889, 13, 29941, 29889, 24337, 29899, 29928, 420, 315, 21471, 29901, 910, 9560, 274, 21471, 338, 697, 310, 278, 1556, 13834, 2982, 22848, 297, 3681, 322, 338, 2998, 363, 967, 22883, 293, 11258, 322, 380, 27389, 380, 7114, 12917, 5417, 29889, 13, 13, 1349, 968, 526, 925, 263, 2846, 310, 278, 1784, 19650, 1953, 393, 3681, 756, 304, 5957, 29889, 2973, 577, 1568, 304, 1074, 322, 437, 29892, 372, 29915, 29879, 694, 4997, 393, 3681, 338, 697, 310, 278, 1556, 5972, 6282, 391, 15422, 800, 297, 278, 3186, 29889, 29871, 2, 1, 518, 25580, 29962, 1724, 338, 577, 2107, 1048, 396, 29896, 29973, 518, 29914, 25580, 29962]
|
||||
# fmt: on
|
||||
self.assertEqual(inputs, EXPECTED_INPUTS_IDS)
|
||||
|
||||
model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-chat-hf")
|
||||
conversation_agent = ConversationalPipeline(model=model, tokenizer=tokenizer)
|
||||
EXPECTED_TEXT = "what topic you want to focus on and create content around it. This will help you stand out from other creators and attract a specific audience.\n\nStep 2: Set Up Your Channel\nCreate your YouTube account and customize your channel with your branding and logo. Make sure your channel name and profile picture are consistent with your niche.\n\nStep 3: Plan Your Content\nDevelop a content strategy that includes the type of content you want to create, how often you will post, and when you will post. Consider creating a content calendar to help you stay organized.\n\nStep 4: Invest in Quality Equipment\nInvest in good quality camera and microphone equipment to ensure your videos look and sound professional. You don't need to break the bank, but investing in good equipment will make a big difference in the quality of your videos.\n\nStep 5: Optimize Your Videos for Search\nUse keywords in your video titles, descriptions, and tags to help people find your videos when they search for topics related to your niche"
|
||||
conversation = Conversation(
|
||||
"<<SYS>>\n Only answer with emojis, and charades\n<</SYS>>\n\nHow can I build a house in 10 steps?"
|
||||
)
|
||||
result = conversation_agent(conversation)
|
||||
self.assertEqual(result.generated_responses[-1], EXPECTED_TEXT)
|
||||
|
||||
@require_torch
|
||||
@slow
|
||||
def test_integration_torch_conversation_blenderbot_400M_input_ids(self):
|
||||
|
||||
Reference in New Issue
Block a user