Code-llama-nit (#26300)
* fix encoding when the fill token is None * add tests and edge cases * fiuxp * Update tests/models/code_llama/test_tokenization_code_llama.py
This commit is contained in:
@@ -303,7 +303,7 @@ class CodeLlamaTokenizerFast(PreTrainedTokenizerFast):
|
|||||||
def encode_plus(self, text, text_pair=None, suffix_first=False, add_special_tokens=True, **kwargs):
|
def encode_plus(self, text, text_pair=None, suffix_first=False, add_special_tokens=True, **kwargs):
|
||||||
# hack to make sure the input is pre-process but outside rust
|
# hack to make sure the input is pre-process but outside rust
|
||||||
text_pair = kwargs.pop("suffix", text_pair)
|
text_pair = kwargs.pop("suffix", text_pair)
|
||||||
if self.fill_token in text and text_pair is None:
|
if self.fill_token is not None and self.fill_token in text and text_pair is None:
|
||||||
text, text_pair = text.split(self.fill_token)
|
text, text_pair = text.split(self.fill_token)
|
||||||
|
|
||||||
if text_pair is None or len(text_pair) < 1:
|
if text_pair is None or len(text_pair) < 1:
|
||||||
|
|||||||
@@ -559,6 +559,24 @@ class LlamaIntegrationTest(unittest.TestCase):
|
|||||||
decoded_tokens = tokenizer.decode(input_ids)
|
decoded_tokens = tokenizer.decode(input_ids)
|
||||||
self.assertEqual(decoded_tokens, " <s> Hello<s> how")
|
self.assertEqual(decoded_tokens, " <s> Hello<s> how")
|
||||||
|
|
||||||
|
def test_fill_token(self):
|
||||||
|
tokenizer = CodeLlamaTokenizerFast.from_pretrained(
|
||||||
|
"codellama/CodeLlama-7b-hf", fill_token=None, prefix_token=None, suffix_token=None, middle_token=None
|
||||||
|
)
|
||||||
|
tokenizer.encode_plus("Hey how are you").input_ids
|
||||||
|
tokenizer.fill_token = "<FILL_ME>"
|
||||||
|
with self.assertRaises(ValueError):
|
||||||
|
tokenizer.encode("Hey how <FILL_ME> are you")
|
||||||
|
tokenizer.encode_plus("Hey how <FILL_ME> are you", "mne too")
|
||||||
|
tokenizer.tokenize("Hey how are you", "mne too")
|
||||||
|
|
||||||
|
tokenizer = CodeLlamaTokenizerFast.from_pretrained(
|
||||||
|
"codellama/CodeLlama-7b-hf", revision="3773f63b4511b9e47a9a7ffc765eed7eb0169486"
|
||||||
|
)
|
||||||
|
tokenizer.encode("Hey how <FILL_ME> are you")
|
||||||
|
tokenizer.encode_plus("Hey how <FILL_ME> are you", "mne too")
|
||||||
|
tokenizer.tokenize("Hey how are you", "mne too")
|
||||||
|
|
||||||
def test_spm_edge_cases(self):
|
def test_spm_edge_cases(self):
|
||||||
# the word inform should be split as ['in', 'form']
|
# the word inform should be split as ['in', 'form']
|
||||||
tokenizer = CodeLlamaTokenizer.from_pretrained("codellama/CodeLlama-7b-hf", legacy=False)
|
tokenizer = CodeLlamaTokenizer.from_pretrained("codellama/CodeLlama-7b-hf", legacy=False)
|
||||||
|
|||||||
Reference in New Issue
Block a user