fix BertTokenizerFast tokenize_chinese_chars arg (#15158)
* add new test * fix in init * more relevant test
This commit is contained in:
@@ -188,15 +188,17 @@ class BertTokenizerFast(PreTrainedTokenizerFast):
|
|||||||
**kwargs,
|
**kwargs,
|
||||||
)
|
)
|
||||||
|
|
||||||
pre_tok_state = json.loads(self.backend_tokenizer.normalizer.__getstate__())
|
normalizer_state = json.loads(self.backend_tokenizer.normalizer.__getstate__())
|
||||||
if (
|
if (
|
||||||
pre_tok_state.get("lowercase", do_lower_case) != do_lower_case
|
normalizer_state.get("lowercase", do_lower_case) != do_lower_case
|
||||||
or pre_tok_state.get("strip_accents", strip_accents) != strip_accents
|
or normalizer_state.get("strip_accents", strip_accents) != strip_accents
|
||||||
|
or normalizer_state.get("handle_chinese_chars", tokenize_chinese_chars) != tokenize_chinese_chars
|
||||||
):
|
):
|
||||||
pre_tok_class = getattr(normalizers, pre_tok_state.pop("type"))
|
normalizer_class = getattr(normalizers, normalizer_state.pop("type"))
|
||||||
pre_tok_state["lowercase"] = do_lower_case
|
normalizer_state["lowercase"] = do_lower_case
|
||||||
pre_tok_state["strip_accents"] = strip_accents
|
normalizer_state["strip_accents"] = strip_accents
|
||||||
self.backend_tokenizer.normalizer = pre_tok_class(**pre_tok_state)
|
normalizer_state["handle_chinese_chars"] = tokenize_chinese_chars
|
||||||
|
self.backend_tokenizer.normalizer = normalizer_class(**normalizer_state)
|
||||||
|
|
||||||
self.do_lower_case = do_lower_case
|
self.do_lower_case = do_lower_case
|
||||||
|
|
||||||
|
|||||||
@@ -299,3 +299,40 @@ class BertTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
|
|||||||
[e[1] for e in expected_results], tokenizer_r.convert_ids_to_tokens(tokens["input_ids"])
|
[e[1] for e in expected_results], tokenizer_r.convert_ids_to_tokens(tokens["input_ids"])
|
||||||
)
|
)
|
||||||
self.assertEqual([e[0] for e in expected_results], tokens["offset_mapping"])
|
self.assertEqual([e[0] for e in expected_results], tokens["offset_mapping"])
|
||||||
|
|
||||||
|
def test_change_tokenize_chinese_chars(self):
|
||||||
|
list_of_commun_chinese_char = ["的", "人", "有"]
|
||||||
|
text_with_chinese_char = "".join(list_of_commun_chinese_char)
|
||||||
|
for tokenizer, pretrained_name, kwargs in self.tokenizers_list:
|
||||||
|
with self.subTest(f"{tokenizer.__class__.__name__} ({pretrained_name})"):
|
||||||
|
|
||||||
|
kwargs["tokenize_chinese_chars"] = True
|
||||||
|
tokenizer_p = self.tokenizer_class.from_pretrained(pretrained_name, **kwargs)
|
||||||
|
tokenizer_r = self.rust_tokenizer_class.from_pretrained(pretrained_name, **kwargs)
|
||||||
|
|
||||||
|
ids_without_spe_char_p = tokenizer_p.encode(text_with_chinese_char, add_special_tokens=False)
|
||||||
|
ids_without_spe_char_r = tokenizer_r.encode(text_with_chinese_char, add_special_tokens=False)
|
||||||
|
|
||||||
|
tokens_without_spe_char_r = tokenizer_r.convert_ids_to_tokens(ids_without_spe_char_r)
|
||||||
|
tokens_without_spe_char_p = tokenizer_p.convert_ids_to_tokens(ids_without_spe_char_p)
|
||||||
|
|
||||||
|
# it is expected that each Chinese character is not preceded by "##"
|
||||||
|
self.assertListEqual(tokens_without_spe_char_p, list_of_commun_chinese_char)
|
||||||
|
self.assertListEqual(tokens_without_spe_char_r, list_of_commun_chinese_char)
|
||||||
|
|
||||||
|
kwargs["tokenize_chinese_chars"] = False
|
||||||
|
tokenizer_r = self.rust_tokenizer_class.from_pretrained(pretrained_name, **kwargs)
|
||||||
|
tokenizer_p = self.tokenizer_class.from_pretrained(pretrained_name, **kwargs)
|
||||||
|
|
||||||
|
ids_without_spe_char_r = tokenizer_r.encode(text_with_chinese_char, add_special_tokens=False)
|
||||||
|
ids_without_spe_char_p = tokenizer_p.encode(text_with_chinese_char, add_special_tokens=False)
|
||||||
|
|
||||||
|
tokens_without_spe_char_r = tokenizer_r.convert_ids_to_tokens(ids_without_spe_char_r)
|
||||||
|
tokens_without_spe_char_p = tokenizer_p.convert_ids_to_tokens(ids_without_spe_char_p)
|
||||||
|
|
||||||
|
# it is expected that only the first Chinese character is not preceded by "##".
|
||||||
|
expected_tokens = [
|
||||||
|
f"##{token}" if idx != 0 else token for idx, token in enumerate(list_of_commun_chinese_char)
|
||||||
|
]
|
||||||
|
self.assertListEqual(tokens_without_spe_char_p, expected_tokens)
|
||||||
|
self.assertListEqual(tokens_without_spe_char_r, expected_tokens)
|
||||||
|
|||||||
Reference in New Issue
Block a user