RAG-2nd2end-revamp (#11893)
* initial * code quality test * code quality * added test functions in test_modeling_rag.py and test_retrieval_rag.py to test end2end retreiver * minor change in test_modeling_rag * fixed tests * Update examples/research_projects/rag-end2end-retriever/README.md typo corrected as suggested by lhoestq Co-authored-by: Quentin Lhoest <42851186+lhoestq@users.noreply.github.com> * Update examples/research_projects/rag-end2end-retriever/finetune_rag.py type change suggested by lhoestq Co-authored-by: Quentin Lhoest <42851186+lhoestq@users.noreply.github.com> * Update src/transformers/models/rag/retrieval_rag.py Adding this change as mentioned by lhoestq. Co-authored-by: Quentin Lhoest <42851186+lhoestq@users.noreply.github.com> * completed the minor changes suggested by the reviewers Co-authored-by: Quentin Lhoest <42851186+lhoestq@users.noreply.github.com>
This commit is contained in:
@@ -28,7 +28,7 @@ from transformers.models.bart.configuration_bart import BartConfig
|
||||
from transformers.models.bart.tokenization_bart import BartTokenizer
|
||||
from transformers.models.bert.tokenization_bert import VOCAB_FILES_NAMES as DPR_VOCAB_FILES_NAMES
|
||||
from transformers.models.dpr.configuration_dpr import DPRConfig
|
||||
from transformers.models.dpr.tokenization_dpr import DPRQuestionEncoderTokenizer
|
||||
from transformers.models.dpr.tokenization_dpr import DPRContextEncoderTokenizer, DPRQuestionEncoderTokenizer
|
||||
from transformers.models.rag.configuration_rag import RagConfig
|
||||
from transformers.models.rag.retrieval_rag import CustomHFIndex, RagRetriever
|
||||
from transformers.models.roberta.tokenization_roberta import VOCAB_FILES_NAMES as BART_VOCAB_FILES_NAMES
|
||||
@@ -115,6 +115,9 @@ class RagRetrieverTest(TestCase):
|
||||
def get_dpr_tokenizer(self) -> DPRQuestionEncoderTokenizer:
|
||||
return DPRQuestionEncoderTokenizer.from_pretrained(os.path.join(self.tmpdirname, "dpr_tokenizer"))
|
||||
|
||||
def get_dpr_ctx_encoder_tokenizer(self) -> DPRContextEncoderTokenizer:
|
||||
return DPRContextEncoderTokenizer.from_pretrained(os.path.join(self.tmpdirname, "dpr_tokenizer"))
|
||||
|
||||
def get_bart_tokenizer(self) -> BartTokenizer:
|
||||
return BartTokenizer.from_pretrained(os.path.join(self.tmpdirname, "bart_tokenizer"))
|
||||
|
||||
@@ -359,3 +362,26 @@ class RagRetrieverTest(TestCase):
|
||||
self.assertIsInstance(context_input_ids, torch.Tensor)
|
||||
self.assertIsInstance(context_attention_mask, torch.Tensor)
|
||||
self.assertIsInstance(retrieved_doc_embeds, torch.Tensor)
|
||||
|
||||
@require_torch
|
||||
@require_tokenizers
|
||||
@require_sentencepiece
|
||||
def test_custom_hf_index_end2end_retriever_call(self):
|
||||
|
||||
context_encoder_tokenizer = self.get_dpr_ctx_encoder_tokenizer()
|
||||
n_docs = 1
|
||||
retriever = self.get_dummy_custom_hf_index_retriever(from_disk=False)
|
||||
retriever.set_ctx_encoder_tokenizer(context_encoder_tokenizer)
|
||||
|
||||
question_input_ids = [[5, 7], [10, 11]]
|
||||
hidden_states = np.array(
|
||||
[np.ones(self.retrieval_vector_size), -np.ones(self.retrieval_vector_size)], dtype=np.float32
|
||||
)
|
||||
out = retriever(question_input_ids, hidden_states, prefix=retriever.config.generator.prefix, n_docs=n_docs)
|
||||
|
||||
self.assertEqual(
|
||||
len(out), 6
|
||||
) # check whether the retriever output consist of 6 attributes including tokenized docs
|
||||
self.assertEqual(
|
||||
all(k in out for k in ("tokenized_doc_ids", "tokenized_doc_attention_mask")), True
|
||||
) # check for doc token related keys in dictionary.
|
||||
|
||||
Reference in New Issue
Block a user