Fix idx2sym not loaded from pretrained vocab file in Transformer XL (#27589)

* Load idx2sym from pretrained vocab file in Transformer XL

When loading the vocab file from a pretrained tokenizer for Transformer XL,
the pickled vocabulary file contains an idx2sym key, but its value isn't
loaded: it is discarded because an empty idx2sym list already exists as
an attribute.

The solution is to explicitly take it into account, just as is done for sym2idx.

* Ran make style
This commit is contained in:
Joel Tang
2023-11-20 07:56:18 +01:00
committed by GitHub
parent dc68a39c81
commit dbf7bfafa7
2 changed files with 27 additions and 1 deletions

View File

@@ -15,7 +15,9 @@
import os
import pickle
import unittest
from collections import Counter, OrderedDict
from transformers.models.transfo_xl.tokenization_transfo_xl import VOCAB_FILES_NAMES, TransfoXLTokenizer
@@ -47,6 +49,25 @@ class TransfoXLTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
with open(self.vocab_file, "w", encoding="utf-8") as vocab_writer:
vocab_writer.write("".join([x + "\n" for x in vocab_tokens]))
saved_dict = {
"eos_idx": 0,
"min_freq": 0,
"vocab_file": None,
"counter": Counter(["welcome home"]),
"sym2idx": OrderedDict([("<eos>", 0), ("welcome", 1), ("home", 2)]),
"delimiter": None,
"idx2sym": ["<eos>", "welcome", "home"],
"max_size": None,
"lower_case": False,
"special": ["<eos>"],
}
self.pretrained_vocab_file = os.path.join(
self.tmpdirname, "mock_folder", VOCAB_FILES_NAMES["pretrained_vocab_file"]
)
os.makedirs(os.path.dirname(self.pretrained_vocab_file), exist_ok=True)
with open(self.pretrained_vocab_file, "wb") as f:
pickle.dump(saved_dict, f)
def get_tokenizer(self, **kwargs):
    """Return a ``TransfoXLTokenizer`` loaded from ``self.tmpdirname``.

    ``lower_case`` is always forced to ``True`` (overriding any value the
    caller passed); every other keyword argument is forwarded unchanged
    to ``TransfoXLTokenizer.from_pretrained``.
    """
    kwargs.update(lower_case=True)
    return TransfoXLTokenizer.from_pretrained(self.tmpdirname, **kwargs)
@@ -128,3 +149,8 @@ class TransfoXLTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
# Check that token is moved to specified id
self.assertEqual(tokenizer.encode("new1"), [1])
self.assertEqual(tokenizer.decode([1]), "new1")
def test_from_pretrained_vocab_file(self):
    """Check that encoding then decoding a sentence returns it unchanged.

    The tokenizer is loaded from the ``mock_folder`` directory under
    ``self.tmpdirname``.
    """
    mock_dir = os.path.join(self.tmpdirname, "mock_folder")
    tokenizer = TransfoXLTokenizer.from_pretrained(mock_dir)

    sentence = "welcome home"
    token_ids = tokenizer.encode(sentence)
    round_tripped = tokenizer.decode(token_ids)

    self.assertEqual(round_tripped, sentence)