Fix idx2sym not loaded from pretrained vocab file in Transformer XL (#27589)
* Load idx2sym from pretrained vocab file in Transformer XL When loading vocab file from a pretrained tokenizer for Transformer XL, although the pickled vocabulary file contains a idx2sym key, it isn't loaded, because it is discarded as the empty list already exists as an attribute. Solution is to explicitly take it into account, just like for sym2idx. * ran make style
This commit is contained in:
@@ -15,7 +15,9 @@
|
||||
|
||||
|
||||
import os
|
||||
import pickle
|
||||
import unittest
|
||||
from collections import Counter, OrderedDict
|
||||
|
||||
from transformers.models.transfo_xl.tokenization_transfo_xl import VOCAB_FILES_NAMES, TransfoXLTokenizer
|
||||
|
||||
@@ -47,6 +49,25 @@ class TransfoXLTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
|
||||
with open(self.vocab_file, "w", encoding="utf-8") as vocab_writer:
|
||||
vocab_writer.write("".join([x + "\n" for x in vocab_tokens]))
|
||||
|
||||
saved_dict = {
|
||||
"eos_idx": 0,
|
||||
"min_freq": 0,
|
||||
"vocab_file": None,
|
||||
"counter": Counter(["welcome home"]),
|
||||
"sym2idx": OrderedDict([("<eos>", 0), ("welcome", 1), ("home", 2)]),
|
||||
"delimiter": None,
|
||||
"idx2sym": ["<eos>", "welcome", "home"],
|
||||
"max_size": None,
|
||||
"lower_case": False,
|
||||
"special": ["<eos>"],
|
||||
}
|
||||
self.pretrained_vocab_file = os.path.join(
|
||||
self.tmpdirname, "mock_folder", VOCAB_FILES_NAMES["pretrained_vocab_file"]
|
||||
)
|
||||
os.makedirs(os.path.dirname(self.pretrained_vocab_file), exist_ok=True)
|
||||
with open(self.pretrained_vocab_file, "wb") as f:
|
||||
pickle.dump(saved_dict, f)
|
||||
|
||||
def get_tokenizer(self, **kwargs):
|
||||
kwargs["lower_case"] = True
|
||||
return TransfoXLTokenizer.from_pretrained(self.tmpdirname, **kwargs)
|
||||
@@ -128,3 +149,8 @@ class TransfoXLTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
|
||||
# Check that token is moved to specified id
|
||||
self.assertEqual(tokenizer.encode("new1"), [1])
|
||||
self.assertEqual(tokenizer.decode([1]), "new1")
|
||||
|
||||
def test_from_pretrained_vocab_file(self):
|
||||
tokenizer = TransfoXLTokenizer.from_pretrained(os.path.join(self.tmpdirname, "mock_folder"))
|
||||
sentence = "welcome home"
|
||||
self.assertEqual(tokenizer.decode(tokenizer.encode(sentence)), sentence)
|
||||
|
||||
Reference in New Issue
Block a user