From 117dba99489b9d87467ee787fa53d86415d4eab1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Th=C3=A9o=20Matussi=C3=A8re?= Date: Fri, 19 Mar 2021 03:13:45 +0100 Subject: [PATCH] fix backend tokenizer args override: key mismatch (#10686) * fix backend tokenizer args override: key mismatch * no touching the docs * fix mpnet * add mpnet to test * fix test Co-authored-by: theo --- .../models/bert/tokenization_bert_fast.py | 4 ++-- .../models/mpnet/tokenization_mpnet_fast.py | 4 ++-- tests/test_tokenization_auto.py | 11 +++++++++++ 3 files changed, 15 insertions(+), 4 deletions(-) diff --git a/src/transformers/models/bert/tokenization_bert_fast.py b/src/transformers/models/bert/tokenization_bert_fast.py index f93446c35f..e477cf7af4 100644 --- a/src/transformers/models/bert/tokenization_bert_fast.py +++ b/src/transformers/models/bert/tokenization_bert_fast.py @@ -190,11 +190,11 @@ class BertTokenizerFast(PreTrainedTokenizerFast): pre_tok_state = json.loads(self.backend_tokenizer.normalizer.__getstate__()) if ( - pre_tok_state.get("do_lower_case", do_lower_case) != do_lower_case + pre_tok_state.get("lowercase", do_lower_case) != do_lower_case or pre_tok_state.get("strip_accents", strip_accents) != strip_accents ): pre_tok_class = getattr(normalizers, pre_tok_state.pop("type")) - pre_tok_state["do_lower_case"] = do_lower_case + pre_tok_state["lowercase"] = do_lower_case pre_tok_state["strip_accents"] = strip_accents self.backend_tokenizer.normalizer = pre_tok_class(**pre_tok_state) diff --git a/src/transformers/models/mpnet/tokenization_mpnet_fast.py b/src/transformers/models/mpnet/tokenization_mpnet_fast.py index 8f35528b96..07547fce57 100644 --- a/src/transformers/models/mpnet/tokenization_mpnet_fast.py +++ b/src/transformers/models/mpnet/tokenization_mpnet_fast.py @@ -138,11 +138,11 @@ class MPNetTokenizerFast(PreTrainedTokenizerFast): pre_tok_state = json.loads(self.backend_tokenizer.normalizer.__getstate__()) if ( - pre_tok_state.get("do_lower_case", do_lower_case) != do_lower_case + pre_tok_state.get("lowercase", do_lower_case) != do_lower_case or pre_tok_state.get("strip_accents", strip_accents) != strip_accents ): pre_tok_class = getattr(normalizers, pre_tok_state.pop("type")) - pre_tok_state["do_lower_case"] = do_lower_case + pre_tok_state["lowercase"] = do_lower_case pre_tok_state["strip_accents"] = strip_accents self.backend_tokenizer.normalizer = pre_tok_class(**pre_tok_state) diff --git a/tests/test_tokenization_auto.py b/tests/test_tokenization_auto.py index 71c5f29f4e..d632cbc558 100644 --- a/tests/test_tokenization_auto.py +++ b/tests/test_tokenization_auto.py @@ -110,3 +110,14 @@ class AutoTokenizerTest(unittest.TestCase): def test_from_pretrained_use_fast_toggle(self): self.assertIsInstance(AutoTokenizer.from_pretrained("bert-base-cased", use_fast=False), BertTokenizer) self.assertIsInstance(AutoTokenizer.from_pretrained("bert-base-cased"), BertTokenizerFast) + + @require_tokenizers + def test_do_lower_case(self): + tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased", do_lower_case=False) + sample = "Hello, world. How are you?" + tokens = tokenizer.tokenize(sample) + self.assertEqual("[UNK]", tokens[0]) + + tokenizer = AutoTokenizer.from_pretrained("microsoft/mpnet-base", do_lower_case=False) + tokens = tokenizer.tokenize(sample) + self.assertEqual("[UNK]", tokens[0])