[TokenizationLlama] Fix the way we convert tokens to strings to keep leading spaces 🚨 breaking fix (#29453)
* nit * update test and fix test * fixup
This commit is contained in:
@@ -295,6 +295,8 @@ class LlamaTokenizer(PreTrainedTokenizer):
|
|||||||
prev_is_special = True
|
prev_is_special = True
|
||||||
current_sub_tokens = []
|
current_sub_tokens = []
|
||||||
else:
|
else:
|
||||||
|
if prev_is_special and i == 1 and self.add_prefix_space and not token.startswith(SPIECE_UNDERLINE):
|
||||||
|
out_string += " "
|
||||||
current_sub_tokens.append(token)
|
current_sub_tokens.append(token)
|
||||||
prev_is_special = False
|
prev_is_special = False
|
||||||
out_string += self.sp_model.decode(current_sub_tokens)
|
out_string += self.sp_model.decode(current_sub_tokens)
|
||||||
|
|||||||
@@ -581,6 +581,19 @@ class LlamaIntegrationTest(unittest.TestCase):
|
|||||||
decoded_tokens = tokenizer.decode(input_ids)
|
decoded_tokens = tokenizer.decode(input_ids)
|
||||||
self.assertEqual(decoded_tokens, " <s> Hello<s> how")
|
self.assertEqual(decoded_tokens, " <s> Hello<s> how")
|
||||||
|
|
||||||
|
# Let's make sure the space is preserved
|
||||||
|
input_ids = tokenizer.encode("hello", add_special_tokens=True)
|
||||||
|
self.assertEqual(input_ids, [1, 22172])
|
||||||
|
tokens = tokenizer.tokenize("hello")
|
||||||
|
self.assertEqual(tokens, ["▁hello"])
|
||||||
|
decoded_tokens = tokenizer.decode(input_ids)
|
||||||
|
self.assertEqual(decoded_tokens, "<s> hello")
|
||||||
|
|
||||||
|
input_ids = tokenizer.encode("hello", add_special_tokens=False)
|
||||||
|
self.assertEqual(input_ids, [22172])
|
||||||
|
decoded_tokens = tokenizer.decode(input_ids)
|
||||||
|
self.assertEqual(decoded_tokens, "hello")
|
||||||
|
|
||||||
def test_some_edge_cases(self):
|
def test_some_edge_cases(self):
|
||||||
tokenizer = LlamaTokenizer.from_pretrained("huggyllama/llama-7b", legacy=False)
|
tokenizer = LlamaTokenizer.from_pretrained("huggyllama/llama-7b", legacy=False)
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user