From 51d7ebf260104cdd10f4c8fee295f9dad53775a5 Mon Sep 17 00:00:00 2001 From: SaulLu <55560583+SaulLu@users.noreply.github.com> Date: Fri, 14 Jan 2022 14:22:03 +0100 Subject: [PATCH] fix BertTokenizerFast `tokenize_chinese_chars` arg (#15158) * add new test * fix in init * more relevant test --- .../models/bert/tokenization_bert_fast.py | 16 ++++---- tests/test_tokenization_bert.py | 37 +++++++++++++++++++ 2 files changed, 46 insertions(+), 7 deletions(-) diff --git a/src/transformers/models/bert/tokenization_bert_fast.py b/src/transformers/models/bert/tokenization_bert_fast.py index 4fd53be98d..111deb27af 100644 --- a/src/transformers/models/bert/tokenization_bert_fast.py +++ b/src/transformers/models/bert/tokenization_bert_fast.py @@ -188,15 +188,17 @@ class BertTokenizerFast(PreTrainedTokenizerFast): **kwargs, ) - pre_tok_state = json.loads(self.backend_tokenizer.normalizer.__getstate__()) + normalizer_state = json.loads(self.backend_tokenizer.normalizer.__getstate__()) if ( - pre_tok_state.get("lowercase", do_lower_case) != do_lower_case - or pre_tok_state.get("strip_accents", strip_accents) != strip_accents + normalizer_state.get("lowercase", do_lower_case) != do_lower_case + or normalizer_state.get("strip_accents", strip_accents) != strip_accents + or normalizer_state.get("handle_chinese_chars", tokenize_chinese_chars) != tokenize_chinese_chars ): - pre_tok_class = getattr(normalizers, pre_tok_state.pop("type")) - pre_tok_state["lowercase"] = do_lower_case - pre_tok_state["strip_accents"] = strip_accents - self.backend_tokenizer.normalizer = pre_tok_class(**pre_tok_state) + normalizer_class = getattr(normalizers, normalizer_state.pop("type")) + normalizer_state["lowercase"] = do_lower_case + normalizer_state["strip_accents"] = strip_accents + normalizer_state["handle_chinese_chars"] = tokenize_chinese_chars + self.backend_tokenizer.normalizer = normalizer_class(**normalizer_state) self.do_lower_case = do_lower_case diff --git a/tests/test_tokenization_bert.py b/tests/test_tokenization_bert.py index 3b8dced0ab..53f8ef6eb0 100644 --- a/tests/test_tokenization_bert.py +++ b/tests/test_tokenization_bert.py @@ -299,3 +299,40 @@ class BertTokenizationTest(TokenizerTesterMixin, unittest.TestCase): [e[1] for e in expected_results], tokenizer_r.convert_ids_to_tokens(tokens["input_ids"]) ) self.assertEqual([e[0] for e in expected_results], tokens["offset_mapping"]) + + def test_change_tokenize_chinese_chars(self): + list_of_commun_chinese_char = ["的", "人", "有"] + text_with_chinese_char = "".join(list_of_commun_chinese_char) + for tokenizer, pretrained_name, kwargs in self.tokenizers_list: + with self.subTest(f"{tokenizer.__class__.__name__} ({pretrained_name})"): + + kwargs["tokenize_chinese_chars"] = True + tokenizer_p = self.tokenizer_class.from_pretrained(pretrained_name, **kwargs) + tokenizer_r = self.rust_tokenizer_class.from_pretrained(pretrained_name, **kwargs) + + ids_without_spe_char_p = tokenizer_p.encode(text_with_chinese_char, add_special_tokens=False) + ids_without_spe_char_r = tokenizer_r.encode(text_with_chinese_char, add_special_tokens=False) + + tokens_without_spe_char_r = tokenizer_r.convert_ids_to_tokens(ids_without_spe_char_r) + tokens_without_spe_char_p = tokenizer_p.convert_ids_to_tokens(ids_without_spe_char_p) + + # it is expected that each Chinese character is not preceded by "##" + self.assertListEqual(tokens_without_spe_char_p, list_of_commun_chinese_char) + self.assertListEqual(tokens_without_spe_char_r, list_of_commun_chinese_char) + + kwargs["tokenize_chinese_chars"] = False + tokenizer_r = self.rust_tokenizer_class.from_pretrained(pretrained_name, **kwargs) + tokenizer_p = self.tokenizer_class.from_pretrained(pretrained_name, **kwargs) + + ids_without_spe_char_r = tokenizer_r.encode(text_with_chinese_char, add_special_tokens=False) + ids_without_spe_char_p = tokenizer_p.encode(text_with_chinese_char, add_special_tokens=False) + + tokens_without_spe_char_r = tokenizer_r.convert_ids_to_tokens(ids_without_spe_char_r) + tokens_without_spe_char_p = tokenizer_p.convert_ids_to_tokens(ids_without_spe_char_p) + + # it is expected that only the first Chinese character is not preceded by "##". + expected_tokens = [ + f"##{token}" if idx != 0 else token for idx, token in enumerate(list_of_commun_chinese_char) + ] + self.assertListEqual(tokens_without_spe_char_p, expected_tokens) + self.assertListEqual(tokens_without_spe_char_r, expected_tokens)