Fix idx2sym not loaded from pretrained vocab file in Transformer XL (#27589)
* Load idx2sym from pretrained vocab file in Transformer XL

  When loading the vocab file from a pretrained tokenizer for Transformer XL, the pickled vocabulary file contains an `idx2sym` key, but it isn't loaded: the value is discarded because an empty `idx2sym` list already exists as an attribute. The solution is to explicitly take it into account, just like for `sym2idx`.

* Ran `make style`
This commit is contained in:
@@ -223,7 +223,7 @@ class TransfoXLTokenizer(PreTrainedTokenizer):
|
|||||||
|
|
||||||
if vocab_dict is not None:
|
if vocab_dict is not None:
|
||||||
for key, value in vocab_dict.items():
|
for key, value in vocab_dict.items():
|
||||||
if key not in self.__dict__ or key == "sym2idx":
|
if key not in self.__dict__ or key in ["sym2idx", "idx2sym"]:
|
||||||
self.__dict__[key] = value
|
self.__dict__[key] = value
|
||||||
elif vocab_file is not None:
|
elif vocab_file is not None:
|
||||||
self.build_vocab()
|
self.build_vocab()
|
||||||
|
|||||||
@@ -15,7 +15,9 @@
|
|||||||
|
|
||||||
|
|
||||||
import os
|
import os
|
||||||
|
import pickle
|
||||||
import unittest
|
import unittest
|
||||||
|
from collections import Counter, OrderedDict
|
||||||
|
|
||||||
from transformers.models.transfo_xl.tokenization_transfo_xl import VOCAB_FILES_NAMES, TransfoXLTokenizer
|
from transformers.models.transfo_xl.tokenization_transfo_xl import VOCAB_FILES_NAMES, TransfoXLTokenizer
|
||||||
|
|
||||||
@@ -47,6 +49,25 @@ class TransfoXLTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
|
|||||||
with open(self.vocab_file, "w", encoding="utf-8") as vocab_writer:
|
with open(self.vocab_file, "w", encoding="utf-8") as vocab_writer:
|
||||||
vocab_writer.write("".join([x + "\n" for x in vocab_tokens]))
|
vocab_writer.write("".join([x + "\n" for x in vocab_tokens]))
|
||||||
|
|
||||||
|
saved_dict = {
|
||||||
|
"eos_idx": 0,
|
||||||
|
"min_freq": 0,
|
||||||
|
"vocab_file": None,
|
||||||
|
"counter": Counter(["welcome home"]),
|
||||||
|
"sym2idx": OrderedDict([("<eos>", 0), ("welcome", 1), ("home", 2)]),
|
||||||
|
"delimiter": None,
|
||||||
|
"idx2sym": ["<eos>", "welcome", "home"],
|
||||||
|
"max_size": None,
|
||||||
|
"lower_case": False,
|
||||||
|
"special": ["<eos>"],
|
||||||
|
}
|
||||||
|
self.pretrained_vocab_file = os.path.join(
|
||||||
|
self.tmpdirname, "mock_folder", VOCAB_FILES_NAMES["pretrained_vocab_file"]
|
||||||
|
)
|
||||||
|
os.makedirs(os.path.dirname(self.pretrained_vocab_file), exist_ok=True)
|
||||||
|
with open(self.pretrained_vocab_file, "wb") as f:
|
||||||
|
pickle.dump(saved_dict, f)
|
||||||
|
|
||||||
def get_tokenizer(self, **kwargs):
|
def get_tokenizer(self, **kwargs):
|
||||||
kwargs["lower_case"] = True
|
kwargs["lower_case"] = True
|
||||||
return TransfoXLTokenizer.from_pretrained(self.tmpdirname, **kwargs)
|
return TransfoXLTokenizer.from_pretrained(self.tmpdirname, **kwargs)
|
||||||
@@ -128,3 +149,8 @@ class TransfoXLTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
|
|||||||
# Check that token is moved to specified id
|
# Check that token is moved to specified id
|
||||||
self.assertEqual(tokenizer.encode("new1"), [1])
|
self.assertEqual(tokenizer.encode("new1"), [1])
|
||||||
self.assertEqual(tokenizer.decode([1]), "new1")
|
self.assertEqual(tokenizer.decode([1]), "new1")
|
||||||
|
|
||||||
|
def test_from_pretrained_vocab_file(self):
|
||||||
|
tokenizer = TransfoXLTokenizer.from_pretrained(os.path.join(self.tmpdirname, "mock_folder"))
|
||||||
|
sentence = "welcome home"
|
||||||
|
self.assertEqual(tokenizer.decode(tokenizer.encode(sentence)), sentence)
|
||||||
|
|||||||
Reference in New Issue
Block a user