From dbf7bfafa7d9a0e5d7963c5d15350ea6b34060ab Mon Sep 17 00:00:00 2001 From: Joel Tang <44188317+jtang98@users.noreply.github.com> Date: Mon, 20 Nov 2023 07:56:18 +0100 Subject: [PATCH] Fix idx2sym not loaded from pretrained vocab file in Transformer XL (#27589) * Load idx2sym from pretrained vocab file in Transformer XL When loading vocab file from a pretrained tokenizer for Transformer XL, although the pickled vocabulary file contains a idx2sym key, it isn't loaded, because it is discarded as the empty list already exists as an attribute. Solution is to explicitly take it into account, just like for sym2idx. * ran make style --- .../transfo_xl/tokenization_transfo_xl.py | 2 +- .../test_tokenization_transfo_xl.py | 26 +++++++++++++++++++ 2 files changed, 27 insertions(+), 1 deletion(-) diff --git a/src/transformers/models/transfo_xl/tokenization_transfo_xl.py b/src/transformers/models/transfo_xl/tokenization_transfo_xl.py index 8a2aba92f7..eaa5ecee4b 100644 --- a/src/transformers/models/transfo_xl/tokenization_transfo_xl.py +++ b/src/transformers/models/transfo_xl/tokenization_transfo_xl.py @@ -223,7 +223,7 @@ class TransfoXLTokenizer(PreTrainedTokenizer): if vocab_dict is not None: for key, value in vocab_dict.items(): - if key not in self.__dict__ or key == "sym2idx": + if key not in self.__dict__ or key in ["sym2idx", "idx2sym"]: self.__dict__[key] = value elif vocab_file is not None: self.build_vocab() diff --git a/tests/models/transfo_xl/test_tokenization_transfo_xl.py b/tests/models/transfo_xl/test_tokenization_transfo_xl.py index 15b712ff37..d8835a164c 100644 --- a/tests/models/transfo_xl/test_tokenization_transfo_xl.py +++ b/tests/models/transfo_xl/test_tokenization_transfo_xl.py @@ -15,7 +15,9 @@ import os +import pickle import unittest +from collections import Counter, OrderedDict from transformers.models.transfo_xl.tokenization_transfo_xl import VOCAB_FILES_NAMES, TransfoXLTokenizer @@ -47,6 +49,25 @@ class TransfoXLTokenizationTest(TokenizerTesterMixin, unittest.TestCase): with open(self.vocab_file, "w", encoding="utf-8") as vocab_writer: vocab_writer.write("".join([x + "\n" for x in vocab_tokens])) + saved_dict = { + "eos_idx": 0, + "min_freq": 0, + "vocab_file": None, + "counter": Counter(["welcome home"]), + "sym2idx": OrderedDict([("", 0), ("welcome", 1), ("home", 2)]), + "delimiter": None, + "idx2sym": ["", "welcome", "home"], + "max_size": None, + "lower_case": False, + "special": [""], + } + self.pretrained_vocab_file = os.path.join( + self.tmpdirname, "mock_folder", VOCAB_FILES_NAMES["pretrained_vocab_file"] + ) + os.makedirs(os.path.dirname(self.pretrained_vocab_file), exist_ok=True) + with open(self.pretrained_vocab_file, "wb") as f: + pickle.dump(saved_dict, f) + def get_tokenizer(self, **kwargs): kwargs["lower_case"] = True return TransfoXLTokenizer.from_pretrained(self.tmpdirname, **kwargs) @@ -128,3 +149,8 @@ class TransfoXLTokenizationTest(TokenizerTesterMixin, unittest.TestCase): # Check that token is moved to specified id self.assertEqual(tokenizer.encode("new1"), [1]) self.assertEqual(tokenizer.decode([1]), "new1") + + def test_from_pretrained_vocab_file(self): + tokenizer = TransfoXLTokenizer.from_pretrained(os.path.join(self.tmpdirname, "mock_folder")) + sentence = "welcome home" + self.assertEqual(tokenizer.decode(tokenizer.encode(sentence)), sentence)