fix rag retriever save pretrained (#7399)
This commit is contained in:
committed by
GitHub
parent
1a14687e6f
commit
2c8ecdf8a8
@@ -168,6 +168,11 @@ class RagRetrieverTest(TestCase):
|
||||
self.assertEqual(doc_dicts[1]["id"][0], "0") # max inner product is reached with first doc
|
||||
self.assertListEqual(doc_ids.tolist(), [[1], [0]])
|
||||
|
||||
def test_save_and_from_pretrained(self):
|
||||
retriever = self.get_dummy_hf_index_retriever()
|
||||
with tempfile.TemporaryDirectory() as tmp_dirname:
|
||||
retriever.save_pretrained(tmp_dirname)
|
||||
|
||||
def test_legacy_index_retriever_retrieve(self):
|
||||
n_docs = 1
|
||||
retriever = self.get_dummy_legacy_index_retriever()
|
||||
|
||||
Reference in New Issue
Block a user