RAG-2nd2end-revamp (#11893)
* initial * code quality test * code quality * added test functions in test_modeling_rag.py and test_retrieval_rag.py to test end2end retriever * minor change in test_modeling_rag * fixed tests * Update examples/research_projects/rag-end2end-retriever/README.md typo corrected as suggested by lhoestq Co-authored-by: Quentin Lhoest <42851186+lhoestq@users.noreply.github.com> * Update examples/research_projects/rag-end2end-retriever/finetune_rag.py type change suggested by lhoestq Co-authored-by: Quentin Lhoest <42851186+lhoestq@users.noreply.github.com> * Update src/transformers/models/rag/retrieval_rag.py Adding this change as mentioned by lhoestq. Co-authored-by: Quentin Lhoest <42851186+lhoestq@users.noreply.github.com> * completed the minor changes suggested by the reviewers Co-authored-by: Quentin Lhoest <42851186+lhoestq@users.noreply.github.com>
This commit is contained in:
@@ -26,7 +26,7 @@ import numpy as np
|
||||
from transformers import BartTokenizer, T5Tokenizer
|
||||
from transformers.file_utils import cached_property, is_datasets_available, is_faiss_available, is_torch_available
|
||||
from transformers.models.bert.tokenization_bert import VOCAB_FILES_NAMES as DPR_VOCAB_FILES_NAMES
|
||||
from transformers.models.dpr.tokenization_dpr import DPRQuestionEncoderTokenizer
|
||||
from transformers.models.dpr.tokenization_dpr import DPRContextEncoderTokenizer, DPRQuestionEncoderTokenizer
|
||||
from transformers.models.roberta.tokenization_roberta import VOCAB_FILES_NAMES as BART_VOCAB_FILES_NAMES
|
||||
from transformers.testing_utils import (
|
||||
require_sentencepiece,
|
||||
@@ -55,6 +55,7 @@ if is_torch_available() and is_datasets_available() and is_faiss_available():
|
||||
AutoConfig,
|
||||
AutoModel,
|
||||
AutoModelForSeq2SeqLM,
|
||||
DPRContextEncoder,
|
||||
RagConfig,
|
||||
RagModel,
|
||||
RagRetriever,
|
||||
@@ -179,6 +180,10 @@ class RagTestMixin:
|
||||
def dpr_tokenizer(self) -> DPRQuestionEncoderTokenizer:
    """Return a DPR question-encoder tokenizer loaded from ``<tmpdirname>/dpr_tokenizer``."""
    tokenizer_dir = os.path.join(self.tmpdirname, "dpr_tokenizer")
    return DPRQuestionEncoderTokenizer.from_pretrained(tokenizer_dir)
|
||||
|
||||
@cached_property
def dpr_ctx_encoder_tokenizer(self) -> DPRContextEncoderTokenizer:
    """Return a DPR context-encoder tokenizer loaded from ``<tmpdirname>/dpr_tokenizer``.

    Note that this reads the same ``dpr_tokenizer`` directory as the
    question-encoder tokenizer property.
    """
    tokenizer_dir = os.path.join(self.tmpdirname, "dpr_tokenizer")
    return DPRContextEncoderTokenizer.from_pretrained(tokenizer_dir)
|
||||
|
||||
@cached_property
def bart_tokenizer(self) -> BartTokenizer:
    """Return a BART tokenizer loaded from ``<tmpdirname>/bart_tokenizer``."""
    tokenizer_dir = os.path.join(self.tmpdirname, "bart_tokenizer")
    return BartTokenizer.from_pretrained(tokenizer_dir)
|
||||
@@ -246,6 +251,46 @@ class RagTestMixin:
|
||||
# doc scores
|
||||
self.assertEqual(outputs.doc_scores.shape, (input_ids.shape[0], self.n_docs))
|
||||
|
||||
def check_model_with_end2end_retriever(
    self, config, input_ids, attention_mask, decoder_input_ids, decoder_attention_mask, **kwargs
):
    """Forward each RAG generation class with a context encoder attached and check output shapes.

    For both ``RagTokenForGeneration`` and ``RagSequenceForGeneration`` a model is
    built with a retriever that has the context-encoder tokenizer set, a
    ``DPRContextEncoder`` is registered via ``set_context_encoder_for_training``,
    and one forward pass is run in eval mode. The shapes of ``logits``,
    ``generator_enc_last_hidden_state`` and ``doc_scores`` are then asserted.

    Args:
        config: RAG config; ``question_encoder`` and ``generator`` must not be None.
        input_ids: question input ids passed to the model.
        attention_mask: attention mask for ``input_ids``.
        decoder_input_ids: decoder input ids passed to the model.
        decoder_attention_mask: attention mask for ``decoder_input_ids``.
        **kwargs: accepted so a larger inputs dict can be splatted in; unused here.
    """
    for sub_config in (config.question_encoder, config.generator):
        self.assertIsNotNone(sub_config)

    ctx_tokenizer = self.dpr_ctx_encoder_tokenizer
    # dpr is a twin tower: the context encoder is built from the question-encoder config
    ctx_encoder = DPRContextEncoder(config.question_encoder)

    retriever = self.get_retriever(config)
    # the retriever needs the ctx_encoder_tokenizer set explicitly
    retriever.set_ctx_encoder_tokenizer(ctx_tokenizer)

    forward_kwargs = {
        "input_ids": input_ids,
        "attention_mask": attention_mask,
        "decoder_input_ids": decoder_input_ids,
        "decoder_attention_mask": decoder_attention_mask,
    }

    for model_class in (RagTokenForGeneration, RagSequenceForGeneration):
        model = model_class(config, retriever=retriever)
        # set the context_encoder for training
        model.set_context_encoder_for_training(ctx_encoder)
        model.to(torch_device)
        model.eval()

        self.assertTrue(model.config.is_encoder_decoder)

        outputs = model(**forward_kwargs)

        # logits
        expected_logits_shape = (
            self.n_docs * decoder_input_ids.shape[0],
            decoder_input_ids.shape[1],
            config.generator.vocab_size,
        )
        self.assertEqual(outputs.logits.shape, expected_logits_shape)

        # generator encoder last hidden states
        expected_enc_shape = (
            self.n_docs * decoder_input_ids.shape[0],
            self.max_combined_length,
            config.generator.hidden_size,
        )
        self.assertEqual(outputs.generator_enc_last_hidden_state.shape, expected_enc_shape)

        # doc scores
        expected_scores_shape = (input_ids.shape[0], self.n_docs)
        self.assertEqual(outputs.doc_scores.shape, expected_scores_shape)
|
||||
|
||||
def check_model_generate_from_context_input_ids(
|
||||
self, config, input_ids, attention_mask, decoder_input_ids, decoder_attention_mask, **kwargs
|
||||
):
|
||||
@@ -538,6 +583,10 @@ class RagTestMixin:
|
||||
inputs_dict = self.config_and_inputs
|
||||
self.check_model_with_retriever(**inputs_dict)
|
||||
|
||||
def test_model_with_end2end_retriever(self):
    """Run ``check_model_with_end2end_retriever`` on ``self.config_and_inputs``."""
    self.check_model_with_end2end_retriever(**self.config_and_inputs)
|
||||
|
||||
def test_model_without_retriever(self):
    """Run ``check_model_without_retriever`` on ``self.config_and_inputs``."""
    self.check_model_without_retriever(**self.config_and_inputs)
|
||||
|
||||
@@ -28,7 +28,7 @@ from transformers.models.bart.configuration_bart import BartConfig
|
||||
from transformers.models.bart.tokenization_bart import BartTokenizer
|
||||
from transformers.models.bert.tokenization_bert import VOCAB_FILES_NAMES as DPR_VOCAB_FILES_NAMES
|
||||
from transformers.models.dpr.configuration_dpr import DPRConfig
|
||||
from transformers.models.dpr.tokenization_dpr import DPRQuestionEncoderTokenizer
|
||||
from transformers.models.dpr.tokenization_dpr import DPRContextEncoderTokenizer, DPRQuestionEncoderTokenizer
|
||||
from transformers.models.rag.configuration_rag import RagConfig
|
||||
from transformers.models.rag.retrieval_rag import CustomHFIndex, RagRetriever
|
||||
from transformers.models.roberta.tokenization_roberta import VOCAB_FILES_NAMES as BART_VOCAB_FILES_NAMES
|
||||
@@ -115,6 +115,9 @@ class RagRetrieverTest(TestCase):
|
||||
def get_dpr_tokenizer(self) -> DPRQuestionEncoderTokenizer:
    """Return a DPR question-encoder tokenizer loaded from ``<tmpdirname>/dpr_tokenizer``."""
    tokenizer_dir = os.path.join(self.tmpdirname, "dpr_tokenizer")
    return DPRQuestionEncoderTokenizer.from_pretrained(tokenizer_dir)
|
||||
|
||||
def get_dpr_ctx_encoder_tokenizer(self) -> DPRContextEncoderTokenizer:
    """Return a DPR context-encoder tokenizer loaded from ``<tmpdirname>/dpr_tokenizer``.

    Note that this reads the same ``dpr_tokenizer`` directory as
    ``get_dpr_tokenizer``.
    """
    tokenizer_dir = os.path.join(self.tmpdirname, "dpr_tokenizer")
    return DPRContextEncoderTokenizer.from_pretrained(tokenizer_dir)
|
||||
|
||||
def get_bart_tokenizer(self) -> BartTokenizer:
    """Return a BART tokenizer loaded from ``<tmpdirname>/bart_tokenizer``."""
    tokenizer_dir = os.path.join(self.tmpdirname, "bart_tokenizer")
    return BartTokenizer.from_pretrained(tokenizer_dir)
|
||||
|
||||
@@ -359,3 +362,26 @@ class RagRetrieverTest(TestCase):
|
||||
self.assertIsInstance(context_input_ids, torch.Tensor)
|
||||
self.assertIsInstance(context_attention_mask, torch.Tensor)
|
||||
self.assertIsInstance(retrieved_doc_embeds, torch.Tensor)
|
||||
|
||||
@require_torch
@require_tokenizers
@require_sentencepiece
def test_custom_hf_index_end2end_retriever_call(self):
    """Check that a retriever with a context-encoder tokenizer set also returns tokenized docs.

    Builds the dummy custom-HF-index retriever (in memory, ``from_disk=False``),
    sets the DPR context-encoder tokenizer on it, calls it with two questions
    and asserts that the output has 6 entries, including
    ``tokenized_doc_ids`` and ``tokenized_doc_attention_mask``.
    """
    context_encoder_tokenizer = self.get_dpr_ctx_encoder_tokenizer()
    n_docs = 1
    retriever = self.get_dummy_custom_hf_index_retriever(from_disk=False)
    retriever.set_ctx_encoder_tokenizer(context_encoder_tokenizer)

    question_input_ids = [[5, 7], [10, 11]]
    # Two query vectors: all ones and all minus ones.
    hidden_states = np.array(
        [np.ones(self.retrieval_vector_size), -np.ones(self.retrieval_vector_size)], dtype=np.float32
    )
    out = retriever(question_input_ids, hidden_states, prefix=retriever.config.generator.prefix, n_docs=n_docs)

    # check whether the retriever output consists of 6 attributes including tokenized docs
    self.assertEqual(len(out), 6)

    # check for doc token related keys in dictionary; asserting each key
    # separately makes a failure name the key that is missing.
    for key in ("tokenized_doc_ids", "tokenized_doc_attention_mask"):
        self.assertIn(key, out)
|
||||
|
||||
Reference in New Issue
Block a user