Disallow pickle.load unless TRUST_REMOTE_CODE=True (#27776)
* fix
* fix
* Use TRUST_REMOTE_CODE
* fix doc
* fix

---------

Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
This commit is contained in:
@@ -22,11 +22,17 @@ This model is in maintenance mode only, so we won't accept any new PRs changing
|
|||||||
|
|
||||||
We recommend switching to more recent models for improved security.
|
We recommend switching to more recent models for improved security.
|
||||||
|
|
||||||
In case you would still like to use `TransfoXL` in your experiments, we recommend using the [Hub checkpoint](https://huggingface.co/transfo-xl-wt103) with a specific revision to ensure you are downloading safe files from the Hub:
|
In case you would still like to use `TransfoXL` in your experiments, we recommend using the [Hub checkpoint](https://huggingface.co/transfo-xl-wt103) with a specific revision to ensure you are downloading safe files from the Hub.
|
||||||
|
|
||||||
```
|
You will need to set the environment variable `TRUST_REMOTE_CODE` to `True` in order to allow the
|
||||||
|
usage of `pickle.load()`:
|
||||||
|
|
||||||
|
```python
|
||||||
|
import os
|
||||||
from transformers import TransfoXLTokenizer, TransfoXLLMHeadModel
|
from transformers import TransfoXLTokenizer, TransfoXLLMHeadModel
|
||||||
|
|
||||||
|
os.environ["TRUST_REMOTE_CODE"] = "True"
|
||||||
|
|
||||||
checkpoint = 'transfo-xl-wt103'
|
checkpoint = 'transfo-xl-wt103'
|
||||||
revision = '40a186da79458c9f9de846edfaea79c412137f97'
|
revision = '40a186da79458c9f9de846edfaea79c412137f97'
|
||||||
|
|
||||||
|
|||||||
@@ -34,6 +34,7 @@ from ....utils import (
|
|||||||
is_torch_available,
|
is_torch_available,
|
||||||
logging,
|
logging,
|
||||||
requires_backends,
|
requires_backends,
|
||||||
|
strtobool,
|
||||||
torch_only_method,
|
torch_only_method,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -212,6 +213,14 @@ class TransfoXLTokenizer(PreTrainedTokenizer):
|
|||||||
vocab_dict = None
|
vocab_dict = None
|
||||||
if pretrained_vocab_file is not None:
|
if pretrained_vocab_file is not None:
|
||||||
# Priority on pickle files (support PyTorch and TF)
|
# Priority on pickle files (support PyTorch and TF)
|
||||||
|
if not strtobool(os.environ.get("TRUST_REMOTE_CODE", "False")):
|
||||||
|
raise ValueError(
|
||||||
|
"This part uses `pickle.load` which is insecure and will execute arbitrary code that is "
|
||||||
|
"potentially malicious. It's recommended to never unpickle data that could have come from an "
|
||||||
|
"untrusted source, or that could have been tampered with. If you already verified the pickle "
|
||||||
|
"data and decided to use it, you can set the environment variable "
|
||||||
|
"`TRUST_REMOTE_CODE` to `True` to allow it."
|
||||||
|
)
|
||||||
with open(pretrained_vocab_file, "rb") as f:
|
with open(pretrained_vocab_file, "rb") as f:
|
||||||
vocab_dict = pickle.load(f)
|
vocab_dict = pickle.load(f)
|
||||||
|
|
||||||
@@ -790,6 +799,13 @@ def get_lm_corpus(datadir, dataset):
|
|||||||
corpus = torch.load(fn_pickle)
|
corpus = torch.load(fn_pickle)
|
||||||
elif os.path.exists(fn):
|
elif os.path.exists(fn):
|
||||||
logger.info("Loading cached dataset from pickle...")
|
logger.info("Loading cached dataset from pickle...")
|
||||||
|
if not strtobool(os.environ.get("TRUST_REMOTE_CODE", "False")):
|
||||||
|
raise ValueError(
|
||||||
|
"This part uses `pickle.load` which is insecure and will execute arbitrary code that is potentially "
|
||||||
|
"malicious. It's recommended to never unpickle data that could have come from an untrusted source, or "
|
||||||
|
"that could have been tampered with. If you already verified the pickle data and decided to use it, "
|
||||||
|
"you can set the environment variable `TRUST_REMOTE_CODE` to `True` to allow it."
|
||||||
|
)
|
||||||
with open(fn, "rb") as fp:
|
with open(fn, "rb") as fp:
|
||||||
corpus = pickle.load(fp)
|
corpus = pickle.load(fp)
|
||||||
else:
|
else:
|
||||||
|
|||||||
@@ -23,7 +23,7 @@ import numpy as np
|
|||||||
|
|
||||||
from ...tokenization_utils import PreTrainedTokenizer
|
from ...tokenization_utils import PreTrainedTokenizer
|
||||||
from ...tokenization_utils_base import BatchEncoding
|
from ...tokenization_utils_base import BatchEncoding
|
||||||
from ...utils import cached_file, is_datasets_available, is_faiss_available, logging, requires_backends
|
from ...utils import cached_file, is_datasets_available, is_faiss_available, logging, requires_backends, strtobool
|
||||||
from .configuration_rag import RagConfig
|
from .configuration_rag import RagConfig
|
||||||
from .tokenization_rag import RagTokenizer
|
from .tokenization_rag import RagTokenizer
|
||||||
|
|
||||||
@@ -131,6 +131,13 @@ class LegacyIndex(Index):
|
|||||||
def _load_passages(self):
|
def _load_passages(self):
|
||||||
logger.info(f"Loading passages from {self.index_path}")
|
logger.info(f"Loading passages from {self.index_path}")
|
||||||
passages_path = self._resolve_path(self.index_path, self.PASSAGE_FILENAME)
|
passages_path = self._resolve_path(self.index_path, self.PASSAGE_FILENAME)
|
||||||
|
if not strtobool(os.environ.get("TRUST_REMOTE_CODE", "False")):
|
||||||
|
raise ValueError(
|
||||||
|
"This part uses `pickle.load` which is insecure and will execute arbitrary code that is potentially "
|
||||||
|
"malicious. It's recommended to never unpickle data that could have come from an untrusted source, or "
|
||||||
|
"that could have been tampered with. If you already verified the pickle data and decided to use it, "
|
||||||
|
"you can set the environment variable `TRUST_REMOTE_CODE` to `True` to allow it."
|
||||||
|
)
|
||||||
with open(passages_path, "rb") as passages_file:
|
with open(passages_path, "rb") as passages_file:
|
||||||
passages = pickle.load(passages_file)
|
passages = pickle.load(passages_file)
|
||||||
return passages
|
return passages
|
||||||
@@ -140,6 +147,13 @@ class LegacyIndex(Index):
|
|||||||
resolved_index_path = self._resolve_path(self.index_path, self.INDEX_FILENAME + ".index.dpr")
|
resolved_index_path = self._resolve_path(self.index_path, self.INDEX_FILENAME + ".index.dpr")
|
||||||
self.index = faiss.read_index(resolved_index_path)
|
self.index = faiss.read_index(resolved_index_path)
|
||||||
resolved_meta_path = self._resolve_path(self.index_path, self.INDEX_FILENAME + ".index_meta.dpr")
|
resolved_meta_path = self._resolve_path(self.index_path, self.INDEX_FILENAME + ".index_meta.dpr")
|
||||||
|
if not strtobool(os.environ.get("TRUST_REMOTE_CODE", "False")):
|
||||||
|
raise ValueError(
|
||||||
|
"This part uses `pickle.load` which is insecure and will execute arbitrary code that is potentially "
|
||||||
|
"malicious. It's recommended to never unpickle data that could have come from an untrusted source, or "
|
||||||
|
"that could have been tampered with. If you already verified the pickle data and decided to use it, "
|
||||||
|
"you can set the environment variable `TRUST_REMOTE_CODE` to `True` to allow it."
|
||||||
|
)
|
||||||
with open(resolved_meta_path, "rb") as metadata_file:
|
with open(resolved_meta_path, "rb") as metadata_file:
|
||||||
self.index_id_to_db_id = pickle.load(metadata_file)
|
self.index_id_to_db_id = pickle.load(metadata_file)
|
||||||
assert (
|
assert (
|
||||||
|
|||||||
@@ -14,7 +14,6 @@
|
|||||||
|
|
||||||
import json
|
import json
|
||||||
import os
|
import os
|
||||||
import pickle
|
|
||||||
import shutil
|
import shutil
|
||||||
import tempfile
|
import tempfile
|
||||||
from unittest import TestCase
|
from unittest import TestCase
|
||||||
@@ -174,37 +173,6 @@ class RagRetrieverTest(TestCase):
|
|||||||
)
|
)
|
||||||
return retriever
|
return retriever
|
||||||
|
|
||||||
def get_dummy_legacy_index_retriever(self):
|
|
||||||
dataset = Dataset.from_dict(
|
|
||||||
{
|
|
||||||
"id": ["0", "1"],
|
|
||||||
"text": ["foo", "bar"],
|
|
||||||
"title": ["Foo", "Bar"],
|
|
||||||
"embeddings": [np.ones(self.retrieval_vector_size + 1), 2 * np.ones(self.retrieval_vector_size + 1)],
|
|
||||||
}
|
|
||||||
)
|
|
||||||
dataset.add_faiss_index("embeddings", string_factory="Flat", metric_type=faiss.METRIC_INNER_PRODUCT)
|
|
||||||
|
|
||||||
index_file_name = os.path.join(self.tmpdirname, "hf_bert_base.hnswSQ8_correct_phi_128.c_index")
|
|
||||||
dataset.save_faiss_index("embeddings", index_file_name + ".index.dpr")
|
|
||||||
pickle.dump(dataset["id"], open(index_file_name + ".index_meta.dpr", "wb"))
|
|
||||||
|
|
||||||
passages_file_name = os.path.join(self.tmpdirname, "psgs_w100.tsv.pkl")
|
|
||||||
passages = {sample["id"]: [sample["text"], sample["title"]] for sample in dataset}
|
|
||||||
pickle.dump(passages, open(passages_file_name, "wb"))
|
|
||||||
|
|
||||||
config = RagConfig(
|
|
||||||
retrieval_vector_size=self.retrieval_vector_size,
|
|
||||||
question_encoder=DPRConfig().to_dict(),
|
|
||||||
generator=BartConfig().to_dict(),
|
|
||||||
index_name="legacy",
|
|
||||||
index_path=self.tmpdirname,
|
|
||||||
)
|
|
||||||
retriever = RagRetriever(
|
|
||||||
config, question_encoder_tokenizer=self.get_dpr_tokenizer(), generator_tokenizer=self.get_bart_tokenizer()
|
|
||||||
)
|
|
||||||
return retriever
|
|
||||||
|
|
||||||
def test_canonical_hf_index_retriever_retrieve(self):
|
def test_canonical_hf_index_retriever_retrieve(self):
|
||||||
n_docs = 1
|
n_docs = 1
|
||||||
retriever = self.get_dummy_canonical_hf_index_retriever()
|
retriever = self.get_dummy_canonical_hf_index_retriever()
|
||||||
@@ -288,33 +256,6 @@ class RagRetrieverTest(TestCase):
|
|||||||
out = retriever.retrieve(hidden_states, n_docs=1)
|
out = retriever.retrieve(hidden_states, n_docs=1)
|
||||||
self.assertTrue(out is not None)
|
self.assertTrue(out is not None)
|
||||||
|
|
||||||
def test_legacy_index_retriever_retrieve(self):
|
|
||||||
n_docs = 1
|
|
||||||
retriever = self.get_dummy_legacy_index_retriever()
|
|
||||||
hidden_states = np.array(
|
|
||||||
[np.ones(self.retrieval_vector_size), -np.ones(self.retrieval_vector_size)], dtype=np.float32
|
|
||||||
)
|
|
||||||
retrieved_doc_embeds, doc_ids, doc_dicts = retriever.retrieve(hidden_states, n_docs=n_docs)
|
|
||||||
self.assertEqual(retrieved_doc_embeds.shape, (2, n_docs, self.retrieval_vector_size))
|
|
||||||
self.assertEqual(len(doc_dicts), 2)
|
|
||||||
self.assertEqual(sorted(doc_dicts[0]), ["text", "title"])
|
|
||||||
self.assertEqual(len(doc_dicts[0]["text"]), n_docs)
|
|
||||||
self.assertEqual(doc_dicts[0]["text"][0], "bar") # max inner product is reached with second doc
|
|
||||||
self.assertEqual(doc_dicts[1]["text"][0], "foo") # max inner product is reached with first doc
|
|
||||||
self.assertListEqual(doc_ids.tolist(), [[1], [0]])
|
|
||||||
|
|
||||||
def test_legacy_hf_index_retriever_save_and_from_pretrained(self):
|
|
||||||
retriever = self.get_dummy_legacy_index_retriever()
|
|
||||||
with tempfile.TemporaryDirectory() as tmp_dirname:
|
|
||||||
retriever.save_pretrained(tmp_dirname)
|
|
||||||
retriever = RagRetriever.from_pretrained(tmp_dirname)
|
|
||||||
self.assertIsInstance(retriever, RagRetriever)
|
|
||||||
hidden_states = np.array(
|
|
||||||
[np.ones(self.retrieval_vector_size), -np.ones(self.retrieval_vector_size)], dtype=np.float32
|
|
||||||
)
|
|
||||||
out = retriever.retrieve(hidden_states, n_docs=1)
|
|
||||||
self.assertTrue(out is not None)
|
|
||||||
|
|
||||||
@require_torch
|
@require_torch
|
||||||
@require_tokenizers
|
@require_tokenizers
|
||||||
@require_sentencepiece
|
@require_sentencepiece
|
||||||
|
|||||||
Reference in New Issue
Block a user