fix backend tokenizer args override: key mismatch (#10686)
* fix backend tokenizer args override: key mismatch * no touching the docs * fix mpnet * add mpnet to test * fix test Co-authored-by: theo <theo@matussie.re>
This commit is contained in:
@@ -190,11 +190,11 @@ class BertTokenizerFast(PreTrainedTokenizerFast):
|
|||||||
|
|
||||||
pre_tok_state = json.loads(self.backend_tokenizer.normalizer.__getstate__())
|
pre_tok_state = json.loads(self.backend_tokenizer.normalizer.__getstate__())
|
||||||
if (
|
if (
|
||||||
pre_tok_state.get("do_lower_case", do_lower_case) != do_lower_case
|
pre_tok_state.get("lowercase", do_lower_case) != do_lower_case
|
||||||
or pre_tok_state.get("strip_accents", strip_accents) != strip_accents
|
or pre_tok_state.get("strip_accents", strip_accents) != strip_accents
|
||||||
):
|
):
|
||||||
pre_tok_class = getattr(normalizers, pre_tok_state.pop("type"))
|
pre_tok_class = getattr(normalizers, pre_tok_state.pop("type"))
|
||||||
pre_tok_state["do_lower_case"] = do_lower_case
|
pre_tok_state["lowercase"] = do_lower_case
|
||||||
pre_tok_state["strip_accents"] = strip_accents
|
pre_tok_state["strip_accents"] = strip_accents
|
||||||
self.backend_tokenizer.normalizer = pre_tok_class(**pre_tok_state)
|
self.backend_tokenizer.normalizer = pre_tok_class(**pre_tok_state)
|
||||||
|
|
||||||
|
|||||||
@@ -138,11 +138,11 @@ class MPNetTokenizerFast(PreTrainedTokenizerFast):
|
|||||||
|
|
||||||
pre_tok_state = json.loads(self.backend_tokenizer.normalizer.__getstate__())
|
pre_tok_state = json.loads(self.backend_tokenizer.normalizer.__getstate__())
|
||||||
if (
|
if (
|
||||||
pre_tok_state.get("do_lower_case", do_lower_case) != do_lower_case
|
pre_tok_state.get("lowercase", do_lower_case) != do_lower_case
|
||||||
or pre_tok_state.get("strip_accents", strip_accents) != strip_accents
|
or pre_tok_state.get("strip_accents", strip_accents) != strip_accents
|
||||||
):
|
):
|
||||||
pre_tok_class = getattr(normalizers, pre_tok_state.pop("type"))
|
pre_tok_class = getattr(normalizers, pre_tok_state.pop("type"))
|
||||||
pre_tok_state["do_lower_case"] = do_lower_case
|
pre_tok_state["lowercase"] = do_lower_case
|
||||||
pre_tok_state["strip_accents"] = strip_accents
|
pre_tok_state["strip_accents"] = strip_accents
|
||||||
self.backend_tokenizer.normalizer = pre_tok_class(**pre_tok_state)
|
self.backend_tokenizer.normalizer = pre_tok_class(**pre_tok_state)
|
||||||
|
|
||||||
|
|||||||
@@ -110,3 +110,14 @@ class AutoTokenizerTest(unittest.TestCase):
|
|||||||
def test_from_pretrained_use_fast_toggle(self):
|
def test_from_pretrained_use_fast_toggle(self):
|
||||||
self.assertIsInstance(AutoTokenizer.from_pretrained("bert-base-cased", use_fast=False), BertTokenizer)
|
self.assertIsInstance(AutoTokenizer.from_pretrained("bert-base-cased", use_fast=False), BertTokenizer)
|
||||||
self.assertIsInstance(AutoTokenizer.from_pretrained("bert-base-cased"), BertTokenizerFast)
|
self.assertIsInstance(AutoTokenizer.from_pretrained("bert-base-cased"), BertTokenizerFast)
|
||||||
|
|
||||||
|
@require_tokenizers
|
||||||
|
def test_do_lower_case(self):
|
||||||
|
tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased", do_lower_case=False)
|
||||||
|
sample = "Hello, world. How are you?"
|
||||||
|
tokens = tokenizer.tokenize(sample)
|
||||||
|
self.assertEqual("[UNK]", tokens[0])
|
||||||
|
|
||||||
|
tokenizer = AutoTokenizer.from_pretrained("microsoft/mpnet-base", do_lower_case=False)
|
||||||
|
tokens = tokenizer.tokenize(sample)
|
||||||
|
self.assertEqual("[UNK]", tokens[0])
|
||||||
|
|||||||
Reference in New Issue
Block a user