[Rag] Fix loading of pretrained Rag Tokenizer (#7756)
* fix rag * Update tokenizer save_pretrained Co-authored-by: Thomas Wolf <thomwolf@users.noreply.github.com>
This commit is contained in:
committed by
GitHub
parent
2d4e928d97
commit
82b09a8481
@@ -7,7 +7,7 @@ from unittest import TestCase
|
||||
from transformers.configuration_bart import BartConfig
|
||||
from transformers.configuration_dpr import DPRConfig
|
||||
from transformers.file_utils import is_datasets_available, is_faiss_available, is_torch_available
|
||||
from transformers.testing_utils import require_datasets, require_faiss, require_torch
|
||||
from transformers.testing_utils import require_datasets, require_faiss, require_torch, slow
|
||||
from transformers.tokenization_bart import BartTokenizer
|
||||
from transformers.tokenization_bert import VOCAB_FILES_NAMES as DPR_VOCAB_FILES_NAMES
|
||||
from transformers.tokenization_dpr import DPRQuestionEncoderTokenizer
|
||||
@@ -108,3 +108,49 @@ class RagTokenizerTest(TestCase):
|
||||
self.assertEqual(new_rag_tokenizer.question_encoder.vocab, rag_tokenizer.question_encoder.vocab)
|
||||
self.assertIsInstance(new_rag_tokenizer.generator, BartTokenizer)
|
||||
self.assertEqual(new_rag_tokenizer.generator.encoder, rag_tokenizer.generator.encoder)
|
||||
|
||||
@slow
|
||||
def test_pretrained_token_nq_tokenizer(self):
|
||||
tokenizer = RagTokenizer.from_pretrained("facebook/rag-token-nq")
|
||||
input_strings = [
|
||||
"who got the first nobel prize in physics",
|
||||
"when is the next deadpool movie being released",
|
||||
"which mode is used for short wave broadcast service",
|
||||
"who is the owner of reading football club",
|
||||
"when is the next scandal episode coming out",
|
||||
"when is the last time the philadelphia won the superbowl",
|
||||
"what is the most current adobe flash player version",
|
||||
"how many episodes are there in dragon ball z",
|
||||
"what is the first step in the evolution of the eye",
|
||||
"where is gall bladder situated in human body",
|
||||
"what is the main mineral in lithium batteries",
|
||||
"who is the president of usa right now",
|
||||
"where do the greasers live in the outsiders",
|
||||
"panda is a national animal of which country",
|
||||
"what is the name of manchester united stadium",
|
||||
]
|
||||
input_dict = tokenizer(input_strings)
|
||||
self.assertIsNotNone(input_dict)
|
||||
|
||||
@slow
|
||||
def test_pretrained_sequence_nq_tokenizer(self):
|
||||
tokenizer = RagTokenizer.from_pretrained("facebook/rag-sequence-nq")
|
||||
input_strings = [
|
||||
"who got the first nobel prize in physics",
|
||||
"when is the next deadpool movie being released",
|
||||
"which mode is used for short wave broadcast service",
|
||||
"who is the owner of reading football club",
|
||||
"when is the next scandal episode coming out",
|
||||
"when is the last time the philadelphia won the superbowl",
|
||||
"what is the most current adobe flash player version",
|
||||
"how many episodes are there in dragon ball z",
|
||||
"what is the first step in the evolution of the eye",
|
||||
"where is gall bladder situated in human body",
|
||||
"what is the main mineral in lithium batteries",
|
||||
"who is the president of usa right now",
|
||||
"where do the greasers live in the outsiders",
|
||||
"panda is a national animal of which country",
|
||||
"what is the name of manchester united stadium",
|
||||
]
|
||||
input_dict = tokenizer(input_strings)
|
||||
self.assertIsNotNone(input_dict)
|
||||
|
||||
Reference in New Issue
Block a user