From 82b09a84819f02cf90cd228aad498795eea2f099 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Tue, 13 Oct 2020 14:34:22 +0200 Subject: [PATCH] [Rag] Fix loading of pretrained Rag Tokenizer (#7756) * fix rag * Update tokenizer save_pretrained Co-authored-by: Thomas Wolf --- src/transformers/tokenization_utils_base.py | 20 +++++---- tests/test_tokenization_rag.py | 48 ++++++++++++++++++++- 2 files changed, 59 insertions(+), 9 deletions(-) diff --git a/src/transformers/tokenization_utils_base.py b/src/transformers/tokenization_utils_base.py index 379266d8bf..d459befbcf 100644 --- a/src/transformers/tokenization_utils_base.py +++ b/src/transformers/tokenization_utils_base.py @@ -1637,9 +1637,11 @@ class PreTrainedTokenizerBase(SpecialTokensMixin): if special_tokens_map_file is not None: with open(special_tokens_map_file, encoding="utf-8") as special_tokens_map_handle: special_tokens_map = json.load(special_tokens_map_handle) - - special_tokens_map = convert_added_tokens(special_tokens_map) for key, value in special_tokens_map.items(): + if isinstance(value, dict): + value = AddedToken(**value) + elif isinstance(value, list): + value = [AddedToken(**token) if isinstance(token, dict) else token for token in value] setattr(tokenizer, key, value) # Add supplementary tokens. @@ -1706,23 +1708,25 @@ class PreTrainedTokenizerBase(SpecialTokensMixin): tokenizer_config.pop(file_id, None) # Sanitize AddedTokens - def convert_added_tokens(obj: Union[AddedToken, Any]): + def convert_added_tokens(obj: Union[AddedToken, Any], add_type_field=True): if isinstance(obj, AddedToken): out = obj.__getstate__() - out["__type"] = "AddedToken" + if add_type_field: + out["__type"] = "AddedToken" return out elif isinstance(obj, (list, tuple)): - return list(convert_added_tokens(o) for o in obj) + return list(convert_added_tokens(o, add_type_field=add_type_field) for o in obj) elif isinstance(obj, dict): - return {k: convert_added_tokens(v) for k, v in obj.items()} + return {k: convert_added_tokens(v, add_type_field=add_type_field) for k, v in obj.items()} return obj - tokenizer_config = convert_added_tokens(tokenizer_config) + # add_type_field=True to allow dicts in the kwargs / differentiate from AddedToken serialization + tokenizer_config = convert_added_tokens(tokenizer_config, add_type_field=True) with open(tokenizer_config_file, "w", encoding="utf-8") as f: f.write(json.dumps(tokenizer_config, ensure_ascii=False)) # Sanitize AddedTokens in special_tokens_map - write_dict = convert_added_tokens(self.special_tokens_map_extended) + write_dict = convert_added_tokens(self.special_tokens_map_extended, add_type_field=False) with open(special_tokens_map_file, "w", encoding="utf-8") as f: f.write(json.dumps(write_dict, ensure_ascii=False)) diff --git a/tests/test_tokenization_rag.py b/tests/test_tokenization_rag.py index ab18af4902..158aadca69 100644 --- a/tests/test_tokenization_rag.py +++ b/tests/test_tokenization_rag.py @@ -7,7 +7,7 @@ from unittest import TestCase from transformers.configuration_bart import BartConfig from transformers.configuration_dpr import DPRConfig from transformers.file_utils import is_datasets_available, is_faiss_available, is_torch_available -from transformers.testing_utils import require_datasets, require_faiss, require_torch +from transformers.testing_utils import require_datasets, require_faiss, require_torch, slow from transformers.tokenization_bart import BartTokenizer from transformers.tokenization_bert import VOCAB_FILES_NAMES as DPR_VOCAB_FILES_NAMES from transformers.tokenization_dpr import DPRQuestionEncoderTokenizer @@ -108,3 +108,49 @@ class RagTokenizerTest(TestCase): self.assertEqual(new_rag_tokenizer.question_encoder.vocab, rag_tokenizer.question_encoder.vocab) self.assertIsInstance(new_rag_tokenizer.generator, BartTokenizer) self.assertEqual(new_rag_tokenizer.generator.encoder, rag_tokenizer.generator.encoder) + + @slow + def test_pretrained_token_nq_tokenizer(self): + tokenizer = RagTokenizer.from_pretrained("facebook/rag-token-nq") + input_strings = [ + "who got the first nobel prize in physics", + "when is the next deadpool movie being released", + "which mode is used for short wave broadcast service", + "who is the owner of reading football club", + "when is the next scandal episode coming out", + "when is the last time the philadelphia won the superbowl", + "what is the most current adobe flash player version", + "how many episodes are there in dragon ball z", + "what is the first step in the evolution of the eye", + "where is gall bladder situated in human body", + "what is the main mineral in lithium batteries", + "who is the president of usa right now", + "where do the greasers live in the outsiders", + "panda is a national animal of which country", + "what is the name of manchester united stadium", + ] + input_dict = tokenizer(input_strings) + self.assertIsNotNone(input_dict) + + @slow + def test_pretrained_sequence_nq_tokenizer(self): + tokenizer = RagTokenizer.from_pretrained("facebook/rag-sequence-nq") + input_strings = [ + "who got the first nobel prize in physics", + "when is the next deadpool movie being released", + "which mode is used for short wave broadcast service", + "who is the owner of reading football club", + "when is the next scandal episode coming out", + "when is the last time the philadelphia won the superbowl", + "what is the most current adobe flash player version", + "how many episodes are there in dragon ball z", + "what is the first step in the evolution of the eye", + "where is gall bladder situated in human body", + "what is the main mineral in lithium batteries", + "who is the president of usa right now", + "where do the greasers live in the outsiders", + "panda is a national animal of which country", + "what is the name of manchester united stadium", + ] + input_dict = tokenizer(input_strings) + self.assertIsNotNone(input_dict)