fix tokenizers serialization
This commit is contained in:
@@ -27,8 +27,8 @@ class DistilBertTokenizationTest(BertTokenizationTest):
|
|||||||
|
|
||||||
tokenizer_class = DistilBertTokenizer
|
tokenizer_class = DistilBertTokenizer
|
||||||
|
|
||||||
def get_tokenizer(self):
|
def get_tokenizer(self, **kwargs):
|
||||||
return DistilBertTokenizer.from_pretrained(self.tmpdirname)
|
return DistilBertTokenizer.from_pretrained(self.tmpdirname, **kwargs)
|
||||||
|
|
||||||
def test_sequence_builders(self):
|
def test_sequence_builders(self):
|
||||||
tokenizer = DistilBertTokenizer.from_pretrained("distilbert-base-uncased")
|
tokenizer = DistilBertTokenizer.from_pretrained("distilbert-base-uncased")
|
||||||
|
|||||||
@@ -67,13 +67,13 @@ class CommonTestCases:
|
|||||||
|
|
||||||
with TemporaryDirectory() as tmpdirname:
|
with TemporaryDirectory() as tmpdirname:
|
||||||
tokenizer.save_pretrained(tmpdirname)
|
tokenizer.save_pretrained(tmpdirname)
|
||||||
tokenizer = tokenizer.from_pretrained(tmpdirname)
|
tokenizer = self.tokenizer_class.from_pretrained(tmpdirname)
|
||||||
|
|
||||||
after_tokens = tokenizer.encode(u"He is very happy, UNwant\u00E9d,running")
|
after_tokens = tokenizer.encode(u"He is very happy, UNwant\u00E9d,running")
|
||||||
self.assertListEqual(before_tokens, after_tokens)
|
self.assertListEqual(before_tokens, after_tokens)
|
||||||
|
|
||||||
self.assertEqual(tokenizer.max_len, 42)
|
self.assertEqual(tokenizer.max_len, 42)
|
||||||
tokenizer = tokenizer.from_pretrained(tmpdirname, max_len=43)
|
tokenizer = self.tokenizer_class.from_pretrained(tmpdirname, max_len=43)
|
||||||
self.assertEqual(tokenizer.max_len, 43)
|
self.assertEqual(tokenizer.max_len, 43)
|
||||||
|
|
||||||
def test_pickle_tokenizer(self):
|
def test_pickle_tokenizer(self):
|
||||||
|
|||||||
@@ -95,6 +95,7 @@ class TransfoXLTokenizer(PreTrainedTokenizer):
|
|||||||
# in a library like ours, at all.
|
# in a library like ours, at all.
|
||||||
vocab_dict = torch.load(pretrained_vocab_file)
|
vocab_dict = torch.load(pretrained_vocab_file)
|
||||||
for key, value in vocab_dict.items():
|
for key, value in vocab_dict.items():
|
||||||
|
if key not in self.__dict__:
|
||||||
self.__dict__[key] = value
|
self.__dict__[key] = value
|
||||||
|
|
||||||
if vocab_file is not None:
|
if vocab_file is not None:
|
||||||
|
|||||||
@@ -61,7 +61,7 @@ class XLNetTokenizer(PreTrainedTokenizer):
|
|||||||
pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP
|
pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP
|
||||||
max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES
|
max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES
|
||||||
|
|
||||||
def __init__(self, vocab_file, max_len=None,
|
def __init__(self, vocab_file,
|
||||||
do_lower_case=False, remove_space=True, keep_accents=False,
|
do_lower_case=False, remove_space=True, keep_accents=False,
|
||||||
bos_token="<s>", eos_token="</s>", unk_token="<unk>", sep_token="<sep>",
|
bos_token="<s>", eos_token="</s>", unk_token="<unk>", sep_token="<sep>",
|
||||||
pad_token="<pad>", cls_token="<cls>", mask_token="<mask>",
|
pad_token="<pad>", cls_token="<cls>", mask_token="<mask>",
|
||||||
|
|||||||
Reference in New Issue
Block a user