[CodeLlama] Add support for CodeLlama (#25740)
* add all * Revert "Delete .github directory" This reverts commit 9b0ff7b052e2b20b629a26fb13606b78a42944d1. * make conversion script backward compatible * fixup * more styling * copy to llama changes * fix repo consistency * nits * document correct classes * updates * more fixes * nits * update auto mappings * add readmes * small updates * llama-code replace with llama_code * make fixup * updates to the testing suite * fix fast nits * more small fixes * fix decode * fix template processing * properly reset the normalizer * nits processor * tokenization tests pass * styling * last tests * additional nits * one test is left * nits Co-authored-by faabian <faabian@users.noreply.github.com> * update failing test * fixup * remove decode infilling; users should handle it on their own after generation, padding can be a problem * update * make test slow and more meaningful * fixup * doc update * fixup * Apply suggestions from code review * add kwargs doc * tokenizer requires `requires_backend` * type requires_backends * CodeLlama instead of LlamaCode * more name changes * nits * make doctests happy * small pipeline nits * last nit * Apply suggestions from code review Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> * update * add codellama to toctree --------- Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
This commit is contained in:
@@ -20,7 +20,7 @@ import unittest
|
||||
from parameterized import parameterized
|
||||
|
||||
from transformers import LlamaConfig, is_torch_available, set_seed
|
||||
from transformers.testing_utils import require_torch, slow, torch_device
|
||||
from transformers.testing_utils import require_torch, require_torch_gpu, slow, torch_device
|
||||
|
||||
from ...generation.test_utils import GenerationTesterMixin
|
||||
from ...test_configuration_common import ConfigTester
|
||||
@@ -31,7 +31,13 @@ from ...test_pipeline_mixin import PipelineTesterMixin
|
||||
if is_torch_available():
|
||||
import torch
|
||||
|
||||
from transformers import LlamaForCausalLM, LlamaForSequenceClassification, LlamaModel, LlamaTokenizer
|
||||
from transformers import (
|
||||
CodeLlamaTokenizer,
|
||||
LlamaForCausalLM,
|
||||
LlamaForSequenceClassification,
|
||||
LlamaModel,
|
||||
LlamaTokenizer,
|
||||
)
|
||||
|
||||
|
||||
class LlamaModelTester:
|
||||
@@ -450,3 +456,85 @@ class LlamaIntegrationTest(unittest.TestCase):
|
||||
generated_ids = model.generate(input_ids, max_new_tokens=64, top_p=None, temperature=1, do_sample=False)
|
||||
text = tokenizer.decode(generated_ids[0], skip_special_tokens=True)
|
||||
self.assertEqual(EXPECTED_TEXT_COMPLETION, text)
|
||||
|
||||
|
||||
@require_torch
|
||||
class CodeLlamaIntegrationTest(unittest.TestCase):
|
||||
PROMPTS = [
|
||||
'''def remove_non_ascii(s: str) -> str:
|
||||
""" <FILL_ME>
|
||||
return result
|
||||
''',
|
||||
"""# Installation instructions:
|
||||
```bash
|
||||
<FILL_ME>
|
||||
```
|
||||
This downloads the LLaMA inference code and installs the repository as a local pip package.
|
||||
""",
|
||||
"""class InterfaceManagerFactory(AbstractManagerFactory):
|
||||
def __init__(<FILL_ME>
|
||||
def main():
|
||||
factory = InterfaceManagerFactory(start=datetime.now())
|
||||
managers = []
|
||||
for i in range(10):
|
||||
managers.append(factory.build(id=i))
|
||||
""",
|
||||
"""/-- A quasi-prefunctoid is 1-connected iff all its etalisations are 1-connected. -/
|
||||
theorem connected_iff_etalisation [C D : precategoroid] (P : quasi_prefunctoid C D) :
|
||||
π₁ P = 0 ↔ <FILL_ME> = 0 :=
|
||||
begin
|
||||
split,
|
||||
{ intros h f,
|
||||
rw pi_1_etalisation at h,
|
||||
simp [h],
|
||||
refl
|
||||
},
|
||||
{ intro h,
|
||||
have := @quasi_adjoint C D P,
|
||||
simp [←pi_1_etalisation, this, h],
|
||||
refl
|
||||
}
|
||||
end
|
||||
""",
|
||||
]
|
||||
|
||||
@require_torch_gpu
@slow
def test_model_7b_logits(self):
    """Check infilling prompt formatting and generation for CodeLlama-7b.

    Despite its name, this test does not inspect logits. It verifies that:

    1. tokenizing ``self.PROMPTS`` and decoding the ids again yields the
       ``<PRE> ... <SUF> ... <MID>`` infilling layout;
    2. the same holds for the ``suffix_first=True`` layout
       (``<PRE> <SUF> ... <MID> ...``) without special tokens;
    3. generating from the first prompt reproduces the reference token ids
       and the reference decoded text.
    """
    model = LlamaForCausalLM.from_pretrained("codellama/CodeLlama-7b-hf").to(torch_device)
    tokenizer = CodeLlamaTokenizer.from_pretrained("codellama/CodeLlama-7b-hf")

    # The tokenizer is called on the whole list of prompts at once. Special tokens are
    # added at encoding time (the expected strings start with `<s>`); `add_special_tokens`
    # is an encoding option, so it is not passed to `batch_decode`.
    processed_text = tokenizer.batch_decode(tokenizer(self.PROMPTS)["input_ids"])
    # fmt: off
    EXPECTED_TEXT = [
        '<s> <PRE> def remove_non_ascii(s: str) -> str:\n """ <SUF>\n return result\n <MID>',
        '<s> <PRE> # Installation instructions:\n ```bash\n <SUF>\n ```\nThis downloads the LLaMA inference code and installs the repository as a local pip package.\n <MID>',
        '<s> <PRE> class InterfaceManagerFactory(AbstractManagerFactory):\n def __init__( <SUF>\ndef main():\n factory = InterfaceManagerFactory(start=datetime.now())\n managers = []\n for i in range(10):\n managers.append(factory.build(id=i))\n <MID>',
        '<s> <PRE> /-- A quasi-prefunctoid is 1-connected iff all its etalisations are 1-connected. -/\ntheorem connected_iff_etalisation [C D : precategoroid] (P : quasi_prefunctoid C D) :\nπ₁ P = 0 ↔ <SUF> = 0 :=\nbegin\nsplit,\n{ intros h f,\n rw pi_1_etalisation at h,\n simp [h],\n refl\n},\n{ intro h,\n have := @quasi_adjoint C D P,\n simp [←pi_1_etalisation, this, h],\n refl\n}\nend\n <MID>'
    ]
    # fmt: on
    self.assertEqual(processed_text, EXPECTED_TEXT)

    # Same prompts, but with the suffix placed before the prefix and no special tokens.
    processed_text_suffix_first = tokenizer.batch_decode(
        tokenizer(self.PROMPTS, suffix_first=True, add_special_tokens=False)["input_ids"]
    )

    # fmt: off
    EXPECTED_TEXT_SUFFIX_FIRST = [
        '<PRE> <SUF>\n return result\n <MID> def remove_non_ascii(s: str) -> str:\n """ ',
        '<PRE> <SUF>\n ```\nThis downloads the LLaMA inference code and installs the repository as a local pip package.\n <MID> # Installation instructions:\n ```bash\n',
        '<PRE> <SUF>\ndef main():\n factory = InterfaceManagerFactory(start=datetime.now())\n managers = []\n for i in range(10):\n managers.append(factory.build(id=i))\n <MID> class InterfaceManagerFactory(AbstractManagerFactory):\n def __init__(',
        '<PRE> <SUF> = 0 :=\nbegin\nsplit,\n{ intros h f,\n rw pi_1_etalisation at h,\n simp [h],\n refl\n},\n{ intro h,\n have := @quasi_adjoint C D P,\n simp [←pi_1_etalisation, this, h],\n refl\n}\nend\n <MID> /-- A quasi-prefunctoid is 1-connected iff all its etalisations are 1-connected. -/\ntheorem connected_iff_etalisation [C D : precategoroid] (P : quasi_prefunctoid C D) :\nπ₁ P = 0 ↔ '
    ]
    EXPECTED_IDS = torch.tensor([[ 1, 32007, 822, 3349, 29918, 5464, 29918, 294, 18869, 29898,29879, 29901, 851, 29897, 1599, 851, 29901, 13, 1678, 9995, 29871, 32008, 13, 1678, 736, 1121, 13, 32009, 15941, 1661, 29899, 28599, 2687, 4890, 515, 263, 1347, 29889, 13, 13, 1678, 826, 3174, 29901, 13, 4706, 269, 29901, 450, 1347, 304, 3349, 1661, 29899, 28599, 2687, 4890, 515, 29889, 13, 13, 1678, 16969, 29901, 13, 4706, 450, 1347, 411, 1661, 29899, 28599, 2687, 4890, 6206, 29889, 13, 1678, 9995, 13, 1678, 1121, 353, 5124, 13, 1678, 363, 274, 297, 269, 29901, 13, 4706, 565, 4356, 29898, 29883, 29897, 529, 29871, 29896, 29906, 29947, 29901, 13, 9651, 1121, 4619, 274, 32010, 2]])
    # fmt: on
    self.assertEqual(processed_text_suffix_first, EXPECTED_TEXT_SUFFIX_FIRST)

    input_ids = tokenizer(self.PROMPTS[0], return_tensors="pt")["input_ids"]
    generated_ids = model.generate(input_ids.to(torch_device), max_new_tokens=128)
    # `generated_ids` lives on `torch_device` while `EXPECTED_IDS` is built on CPU;
    # `assert_close` also compares devices by default, so bring the result to CPU first.
    torch.testing.assert_close(generated_ids.cpu(), EXPECTED_IDS)

    EXPECTED_INFILLING = [
        '<s> <PRE> def remove_non_ascii(s: str) -> str:\n """ <SUF>\n return result\n <MID>Remove non-ASCII characters from a string.\n\n Args:\n s: The string to remove non-ASCII characters from.\n\n Returns:\n The string with non-ASCII characters removed.\n """\n result = ""\n for c in s:\n if ord(c) < 128:\n result += c <EOT></s>'
    ]
    infilling = tokenizer.batch_decode(generated_ids)
    self.assertEqual(infilling, EXPECTED_INFILLING)
|
||||
|
||||
Reference in New Issue
Block a user