From 2c8ecdf8a87019c438262d8c692e1bdffe05149f Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Fri, 25 Sep 2020 19:47:12 +0200 Subject: [PATCH] fix rag retriever save pretrained (#7399) --- src/transformers/retrieval_rag.py | 4 ++-- tests/test_retrieval_rag.py | 5 +++++ 2 files changed, 7 insertions(+), 2 deletions(-) diff --git a/src/transformers/retrieval_rag.py b/src/transformers/retrieval_rag.py index bd102ee463..6ed0639e3c 100644 --- a/src/transformers/retrieval_rag.py +++ b/src/transformers/retrieval_rag.py @@ -312,8 +312,8 @@ class RagRetriever: def save_pretrained(self, save_directory): self.config.save_pretrained(save_directory) rag_tokenizer = RagTokenizer( - question_encoder_tokenizer=self.question_encoder_tokenizer, - generator_tokenizer=self.generator_tokenizer, + question_encoder=self.question_encoder_tokenizer, + generator=self.generator_tokenizer, ) rag_tokenizer.save_pretrained(save_directory) diff --git a/tests/test_retrieval_rag.py b/tests/test_retrieval_rag.py index 4f45ac2df8..1fa1cadb8f 100644 --- a/tests/test_retrieval_rag.py +++ b/tests/test_retrieval_rag.py @@ -168,6 +168,11 @@ class RagRetrieverTest(TestCase): self.assertEqual(doc_dicts[1]["id"][0], "0") # max inner product is reached with first doc self.assertListEqual(doc_ids.tolist(), [[1], [0]]) + def test_save_and_from_pretrained(self): + retriever = self.get_dummy_hf_index_retriever() + with tempfile.TemporaryDirectory() as tmp_dirname: + retriever.save_pretrained(tmp_dirname) + def test_legacy_index_retriever_retrieve(self): n_docs = 1 retriever = self.get_dummy_legacy_index_retriever()