fix rag retriever save pretrained (#7399)
This commit is contained in:
committed by
GitHub
parent
1a14687e6f
commit
2c8ecdf8a8
@@ -312,8 +312,8 @@ class RagRetriever:
|
|||||||
def save_pretrained(self, save_directory):
|
def save_pretrained(self, save_directory):
|
||||||
self.config.save_pretrained(save_directory)
|
self.config.save_pretrained(save_directory)
|
||||||
rag_tokenizer = RagTokenizer(
|
rag_tokenizer = RagTokenizer(
|
||||||
question_encoder_tokenizer=self.question_encoder_tokenizer,
|
question_encoder=self.question_encoder_tokenizer,
|
||||||
generator_tokenizer=self.generator_tokenizer,
|
generator=self.generator_tokenizer,
|
||||||
)
|
)
|
||||||
rag_tokenizer.save_pretrained(save_directory)
|
rag_tokenizer.save_pretrained(save_directory)
|
||||||
|
|
||||||
|
|||||||
@@ -168,6 +168,11 @@ class RagRetrieverTest(TestCase):
|
|||||||
self.assertEqual(doc_dicts[1]["id"][0], "0") # max inner product is reached with first doc
|
self.assertEqual(doc_dicts[1]["id"][0], "0") # max inner product is reached with first doc
|
||||||
self.assertListEqual(doc_ids.tolist(), [[1], [0]])
|
self.assertListEqual(doc_ids.tolist(), [[1], [0]])
|
||||||
|
|
||||||
|
def test_save_and_from_pretrained(self):
|
||||||
|
retriever = self.get_dummy_hf_index_retriever()
|
||||||
|
with tempfile.TemporaryDirectory() as tmp_dirname:
|
||||||
|
retriever.save_pretrained(tmp_dirname)
|
||||||
|
|
||||||
def test_legacy_index_retriever_retrieve(self):
|
def test_legacy_index_retriever_retrieve(self):
|
||||||
n_docs = 1
|
n_docs = 1
|
||||||
retriever = self.get_dummy_legacy_index_retriever()
|
retriever = self.get_dummy_legacy_index_retriever()
|
||||||
|
|||||||
Reference in New Issue
Block a user