[LlamaTokenizerFast] Refactor default llama (#28881)
* push legacy to fast as well * super strange * Update src/transformers/convert_slow_tokenizer.py * make sure we are BC * fix Llama test * nit * revert * more test * style * update * small update w.r.t tokenizers * nit * don't split * lol * add a test for `add_prefix_space=False` * fix gemma tokenizer as well * update * fix gemma * nicer failures * fixup * update * fix the example for legacy = False * use `huggyllama/llama-7b` for the PR doctest * nit * use from_slow * fix llama
This commit is contained in:
@@ -543,8 +543,15 @@ class LlamaIntegrationTest(unittest.TestCase):
|
||||
|
||||
def test_special_token_special_word(self):
|
||||
# the word inform should be split as ['in', 'form']
|
||||
tokenizer = LlamaTokenizer.from_pretrained("huggyllama/llama-7b", legacy=False)
|
||||
tokenizer = LlamaTokenizerFast.from_pretrained("huggyllama/llama-7b", legacy=False, from_slow=True)
|
||||
tokenizer.add_tokens([AddedToken("<REPR_END>", rstrip=True, lstrip=True)], special_tokens=False)
|
||||
|
||||
example_inputs = tokenizer.tokenize("<REPR_END>inform<s>. Hey. .")
|
||||
self.assertEqual(example_inputs, ["<REPR_END>", "in", "form", "<s>", ".", "▁Hey", ".", "▁▁▁▁▁▁", "▁."])
|
||||
|
||||
# Make sure dummy space is added if it is indeed the first word
|
||||
example_inputs = tokenizer.tokenize("inform<s>. Hey. .")
|
||||
self.assertEqual(example_inputs, ["▁inform", "<s>", ".", "▁Hey", ".", "▁▁▁▁▁▁", "▁."])
|
||||
out1 = tokenizer.decode(
|
||||
tokenizer.encode("<REPR_END>inform", add_special_tokens=False), spaces_between_special_tokens=False
|
||||
)
|
||||
@@ -553,12 +560,12 @@ class LlamaIntegrationTest(unittest.TestCase):
|
||||
tokenizer.encode("<REPR_END>inform", add_special_tokens=False), spaces_between_special_tokens=True
|
||||
)
|
||||
# decoding strips the added prefix space.
|
||||
self.assertEqual(out2, "<REPR_END> inform")
|
||||
self.assertEqual(out2, "<REPR_END>inform")
|
||||
input_ids = tokenizer.encode("<REPR_END>inform", add_special_tokens=False)
|
||||
self.assertEqual(input_ids, [29871, 32000, 262, 689]) # 29871 is the spiece underline, '▁' added as it should
|
||||
self.assertEqual(input_ids, [32000, 262, 689]) # 29871 is the spiece underline, '▁' added as it should
|
||||
|
||||
out2 = tokenizer.decode(
|
||||
tokenizer.encode(" <REPR_END> inform", add_special_tokens=False), spaces_between_special_tokens=False
|
||||
tokenizer.encode(" <REPR_END>inform", add_special_tokens=False), spaces_between_special_tokens=False
|
||||
)
|
||||
# TODO @ArthurZ currently we strip left and right, so this will not keep the spaces
|
||||
self.assertEqual(out2, "<REPR_END>inform")
|
||||
@@ -575,11 +582,11 @@ class LlamaIntegrationTest(unittest.TestCase):
|
||||
|
||||
# Let's make sure that if there are any spaces, we don't remove them!
|
||||
input_ids = tokenizer.encode(" <s> Hello<s> how", add_special_tokens=False)
|
||||
self.assertEqual(input_ids, [259, 1, 15043, 1, 920])
|
||||
self.assertEqual(input_ids, [29871, 1, 15043, 1, 920])
|
||||
tokens = tokenizer.tokenize(" <s> Hello<s> how", add_special_tokens=False)
|
||||
self.assertEqual(tokens, ["▁▁", "<s>", "▁Hello", "<s>", "▁how"])
|
||||
self.assertEqual(tokens, ["▁", "<s>", "▁Hello", "<s>", "▁how"])
|
||||
decoded_tokens = tokenizer.decode(input_ids)
|
||||
self.assertEqual(decoded_tokens, " <s> Hello<s> how")
|
||||
self.assertEqual(decoded_tokens, "<s> Hello<s> how")
|
||||
|
||||
# Let's make sure the space is preserved
|
||||
input_ids = tokenizer.encode("hello", add_special_tokens=True)
|
||||
@@ -594,6 +601,63 @@ class LlamaIntegrationTest(unittest.TestCase):
|
||||
decoded_tokens = tokenizer.decode(input_ids)
|
||||
self.assertEqual(decoded_tokens, "hello")
|
||||
|
||||
def test_no_prefix_space(self):
|
||||
tokenizer = LlamaTokenizerFast.from_pretrained(
|
||||
"huggyllama/llama-7b", legacy=False, from_slow=True, add_prefix_space=False
|
||||
)
|
||||
tokenizer.add_tokens([AddedToken("<REPR_END>", rstrip=True, lstrip=True)], special_tokens=False)
|
||||
|
||||
example_inputs = tokenizer.tokenize("<REPR_END>inform<s>. Hey. .")
|
||||
self.assertEqual(example_inputs, ["<REPR_END>", "in", "form", "<s>", ".", "▁Hey", ".", "▁▁▁▁▁▁", "▁."])
|
||||
|
||||
# Make sure dummy space is added if it is indeed the first word
|
||||
example_inputs = tokenizer.tokenize("inform<s>. Hey. .")
|
||||
self.assertEqual(example_inputs, ["in", "form", "<s>", ".", "▁Hey", ".", "▁▁▁▁▁▁", "▁."])
|
||||
out1 = tokenizer.decode(
|
||||
tokenizer.encode("<REPR_END>inform", add_special_tokens=False), spaces_between_special_tokens=False
|
||||
)
|
||||
self.assertEqual(out1, "<REPR_END>inform")
|
||||
out2 = tokenizer.decode(
|
||||
tokenizer.encode("<REPR_END>inform", add_special_tokens=False), spaces_between_special_tokens=True
|
||||
)
|
||||
# decoding strips the added prefix space.
|
||||
self.assertEqual(out2, "<REPR_END>inform")
|
||||
input_ids = tokenizer.encode("<REPR_END>inform", add_special_tokens=False)
|
||||
self.assertEqual(input_ids, [32000, 262, 689]) # 29871 is the spiece underline, '▁' added as it should
|
||||
|
||||
out2 = tokenizer.decode(
|
||||
tokenizer.encode(" <REPR_END>inform", add_special_tokens=False), spaces_between_special_tokens=False
|
||||
)
|
||||
self.assertEqual(out2, "<REPR_END>inform")
|
||||
|
||||
input_ids = tokenizer.encode("<s> Hello<s>how", add_special_tokens=False)
|
||||
self.assertEqual(input_ids, [1, 15043, 1, 3525])
|
||||
tokens = tokenizer.tokenize("<s> Hello<s>how", add_special_tokens=False)
|
||||
self.assertEqual(tokens, ["<s>", "▁Hello", "<s>", "how"])
|
||||
decoded_tokens = tokenizer.decode(input_ids)
|
||||
self.assertEqual(decoded_tokens, "<s> Hello<s>how")
|
||||
|
||||
# Let's make sure that if there are any spaces, we don't remove them!
|
||||
input_ids = tokenizer.encode(" <s> Hello<s> how", add_special_tokens=False)
|
||||
self.assertEqual(input_ids, [29871, 1, 15043, 1, 920])
|
||||
tokens = tokenizer.tokenize(" <s> Hello<s> how", add_special_tokens=False)
|
||||
self.assertEqual(tokens, ["▁", "<s>", "▁Hello", "<s>", "▁how"])
|
||||
decoded_tokens = tokenizer.decode(input_ids)
|
||||
self.assertEqual(decoded_tokens, " <s> Hello<s> how")
|
||||
|
||||
# Let's make sure the space is preserved
|
||||
input_ids = tokenizer.encode("hello", add_special_tokens=True)
|
||||
self.assertEqual(input_ids, [1, 12199])
|
||||
tokens = tokenizer.tokenize("hello")
|
||||
self.assertEqual(tokens, ["hello"])
|
||||
decoded_tokens = tokenizer.decode(input_ids)
|
||||
self.assertEqual(decoded_tokens, "<s>hello")
|
||||
|
||||
input_ids = tokenizer.encode("hello", add_special_tokens=False)
|
||||
self.assertEqual(input_ids, [12199])
|
||||
decoded_tokens = tokenizer.decode(input_ids)
|
||||
self.assertEqual(decoded_tokens, "hello")
|
||||
|
||||
def test_some_edge_cases(self):
|
||||
tokenizer = LlamaTokenizer.from_pretrained("huggyllama/llama-7b", legacy=False)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user