Disallow pickle.load unless TRUST_REMOTE_CODE=True (#27776)
* fix * fix * Use TRUST_REMOTE_CODE * fix doc * fix --------- Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
This commit is contained in:
@@ -14,7 +14,6 @@
|
||||
|
||||
import json
|
||||
import os
|
||||
import pickle
|
||||
import shutil
|
||||
import tempfile
|
||||
from unittest import TestCase
|
||||
@@ -174,37 +173,6 @@ class RagRetrieverTest(TestCase):
|
||||
)
|
||||
return retriever
|
||||
|
||||
def get_dummy_legacy_index_retriever(self):
|
||||
dataset = Dataset.from_dict(
|
||||
{
|
||||
"id": ["0", "1"],
|
||||
"text": ["foo", "bar"],
|
||||
"title": ["Foo", "Bar"],
|
||||
"embeddings": [np.ones(self.retrieval_vector_size + 1), 2 * np.ones(self.retrieval_vector_size + 1)],
|
||||
}
|
||||
)
|
||||
dataset.add_faiss_index("embeddings", string_factory="Flat", metric_type=faiss.METRIC_INNER_PRODUCT)
|
||||
|
||||
index_file_name = os.path.join(self.tmpdirname, "hf_bert_base.hnswSQ8_correct_phi_128.c_index")
|
||||
dataset.save_faiss_index("embeddings", index_file_name + ".index.dpr")
|
||||
pickle.dump(dataset["id"], open(index_file_name + ".index_meta.dpr", "wb"))
|
||||
|
||||
passages_file_name = os.path.join(self.tmpdirname, "psgs_w100.tsv.pkl")
|
||||
passages = {sample["id"]: [sample["text"], sample["title"]] for sample in dataset}
|
||||
pickle.dump(passages, open(passages_file_name, "wb"))
|
||||
|
||||
config = RagConfig(
|
||||
retrieval_vector_size=self.retrieval_vector_size,
|
||||
question_encoder=DPRConfig().to_dict(),
|
||||
generator=BartConfig().to_dict(),
|
||||
index_name="legacy",
|
||||
index_path=self.tmpdirname,
|
||||
)
|
||||
retriever = RagRetriever(
|
||||
config, question_encoder_tokenizer=self.get_dpr_tokenizer(), generator_tokenizer=self.get_bart_tokenizer()
|
||||
)
|
||||
return retriever
|
||||
|
||||
def test_canonical_hf_index_retriever_retrieve(self):
|
||||
n_docs = 1
|
||||
retriever = self.get_dummy_canonical_hf_index_retriever()
|
||||
@@ -288,33 +256,6 @@ class RagRetrieverTest(TestCase):
|
||||
out = retriever.retrieve(hidden_states, n_docs=1)
|
||||
self.assertTrue(out is not None)
|
||||
|
||||
def test_legacy_index_retriever_retrieve(self):
|
||||
n_docs = 1
|
||||
retriever = self.get_dummy_legacy_index_retriever()
|
||||
hidden_states = np.array(
|
||||
[np.ones(self.retrieval_vector_size), -np.ones(self.retrieval_vector_size)], dtype=np.float32
|
||||
)
|
||||
retrieved_doc_embeds, doc_ids, doc_dicts = retriever.retrieve(hidden_states, n_docs=n_docs)
|
||||
self.assertEqual(retrieved_doc_embeds.shape, (2, n_docs, self.retrieval_vector_size))
|
||||
self.assertEqual(len(doc_dicts), 2)
|
||||
self.assertEqual(sorted(doc_dicts[0]), ["text", "title"])
|
||||
self.assertEqual(len(doc_dicts[0]["text"]), n_docs)
|
||||
self.assertEqual(doc_dicts[0]["text"][0], "bar") # max inner product is reached with second doc
|
||||
self.assertEqual(doc_dicts[1]["text"][0], "foo") # max inner product is reached with first doc
|
||||
self.assertListEqual(doc_ids.tolist(), [[1], [0]])
|
||||
|
||||
def test_legacy_hf_index_retriever_save_and_from_pretrained(self):
|
||||
retriever = self.get_dummy_legacy_index_retriever()
|
||||
with tempfile.TemporaryDirectory() as tmp_dirname:
|
||||
retriever.save_pretrained(tmp_dirname)
|
||||
retriever = RagRetriever.from_pretrained(tmp_dirname)
|
||||
self.assertIsInstance(retriever, RagRetriever)
|
||||
hidden_states = np.array(
|
||||
[np.ones(self.retrieval_vector_size), -np.ones(self.retrieval_vector_size)], dtype=np.float32
|
||||
)
|
||||
out = retriever.retrieve(hidden_states, n_docs=1)
|
||||
self.assertTrue(out is not None)
|
||||
|
||||
@require_torch
|
||||
@require_tokenizers
|
||||
@require_sentencepiece
|
||||
|
||||
Reference in New Issue
Block a user