diff --git a/src/transformers/tokenization_bert.py b/src/transformers/tokenization_bert.py index ac14370a18..cefb917798 100644 --- a/src/transformers/tokenization_bert.py +++ b/src/transformers/tokenization_bert.py @@ -341,7 +341,7 @@ class BasicTokenizer(object): if self.do_lower_case and token not in never_split: token = token.lower() token = self._run_strip_accents(token) - split_tokens.extend(self._run_split_on_punc(token)) + split_tokens.extend(self._run_split_on_punc(token, never_split)) output_tokens = whitespace_tokenize(" ".join(split_tokens)) return output_tokens diff --git a/tests/test_tokenization_bert.py b/tests/test_tokenization_bert.py index 793bb8fa54..49bb073351 100644 --- a/tests/test_tokenization_bert.py +++ b/tests/test_tokenization_bert.py @@ -119,6 +119,13 @@ class BertTokenizationTest(TokenizerTesterMixin, unittest.TestCase): tokenizer.tokenize(" \tHeLLo!how \n Are yoU? "), ["HeLLo", "!", "how", "Are", "yoU", "?"] ) + def test_basic_tokenizer_respects_never_split_tokens(self): + tokenizer = BasicTokenizer(do_lower_case=False, never_split=["[UNK]"]) + + self.assertListEqual( + tokenizer.tokenize(" \tHeLLo!how \n Are yoU? [UNK]"), ["HeLLo", "!", "how", "Are", "yoU", "?", "[UNK]"] + ) + def test_wordpiece_tokenizer(self): vocab_tokens = ["[UNK]", "[CLS]", "[SEP]", "want", "##want", "##ed", "wa", "un", "runn", "##ing"]