[RAG] Propagating of n_docs as parameter to all RagModel's related functions (#7891)
* Propagating n_docs as parameter to all RagModel's related functions that defaults to self.config.n_docs * Making n_docs parameter's default value to None in marginalize function * Fixing code quality issues * Handle the special case when generator is of T5PreTrainedModel instance type. T5PreTrainedModel do not have n_docs as parameter * T5PreTrainedModel do not have n_docs as parameter * Addressing review comment Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com> * Correcting comment by addressing review comment * Adding assert statement verifying that n_docs is correctly set. n_docs should be the same for both retriever and generator. * Fixing flake8 reported issue * Correcting test datasets for rag * Using doc_scores instead of context_input_ids to check assert as in RagSequenceForGeneration context_input_ids can be null * doc_scores second dimension have number of retrieved docs * Changing assert comment * Apply suggestions from code review Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
This commit is contained in:
@@ -82,7 +82,7 @@ def require_retrieval(test_case):
|
||||
|
||||
"""
|
||||
if not (is_torch_available() and is_datasets_available() and is_faiss_available()):
|
||||
test_case = unittest.skip("test requires PyTorch")(test_case)
|
||||
test_case = unittest.skip("test requires PyTorch, datasets and faiss")(test_case)
|
||||
return test_case
|
||||
|
||||
|
||||
@@ -98,7 +98,7 @@ class RagTestMixin:
|
||||
)
|
||||
|
||||
retrieval_vector_size = 32
|
||||
n_docs = 2
|
||||
n_docs = 3
|
||||
max_combined_length = 16
|
||||
|
||||
def setUp(self):
|
||||
@@ -186,10 +186,14 @@ class RagTestMixin:
|
||||
def get_retriever(self, config):
|
||||
dataset = Dataset.from_dict(
|
||||
{
|
||||
"id": ["0", "1"],
|
||||
"text": ["foo", "bar"],
|
||||
"title": ["Foo", "Bar"],
|
||||
"embeddings": [np.ones(self.retrieval_vector_size), 2 * np.ones(self.retrieval_vector_size)],
|
||||
"id": ["0", "1", "3"],
|
||||
"text": ["foo", "bar", "qux"],
|
||||
"title": ["Foo", "Bar", "Qux"],
|
||||
"embeddings": [
|
||||
np.ones(self.retrieval_vector_size),
|
||||
2 * np.ones(self.retrieval_vector_size),
|
||||
3 * np.ones(self.retrieval_vector_size),
|
||||
],
|
||||
}
|
||||
)
|
||||
dataset.add_faiss_index("embeddings", string_factory="Flat", metric_type=faiss.METRIC_INNER_PRODUCT)
|
||||
@@ -315,6 +319,125 @@ class RagTestMixin:
|
||||
# doc scores
|
||||
self.assertEqual(outputs.doc_scores.shape, (input_ids.shape[0], self.n_docs))
|
||||
|
||||
def check_model_custom_n_docs(
|
||||
self, config, input_ids, attention_mask, decoder_input_ids, decoder_attention_mask, n_docs, **kwargs
|
||||
):
|
||||
self.assertIsNotNone(config.question_encoder)
|
||||
self.assertIsNotNone(config.generator)
|
||||
|
||||
retriever = self.get_retriever(config)
|
||||
|
||||
for model_class in self.all_model_classes:
|
||||
model = model_class(config).to(torch_device)
|
||||
model.eval()
|
||||
self.assertTrue(model.config.is_encoder_decoder)
|
||||
|
||||
question_hidden_states = model.question_encoder(input_ids, attention_mask=attention_mask)[0]
|
||||
|
||||
out = retriever(
|
||||
input_ids,
|
||||
question_hidden_states.cpu().detach().to(torch.float32).numpy(),
|
||||
prefix=config.generator.prefix,
|
||||
return_tensors="pt",
|
||||
n_docs=n_docs,
|
||||
)
|
||||
|
||||
context_input_ids, context_attention_mask, retrieved_doc_embeds = (
|
||||
out["context_input_ids"],
|
||||
out["context_attention_mask"],
|
||||
out["retrieved_doc_embeds"],
|
||||
)
|
||||
|
||||
# cast
|
||||
retrieved_doc_embeds = retrieved_doc_embeds.to(question_hidden_states)
|
||||
context_input_ids = context_input_ids.to(input_ids)
|
||||
context_attention_mask = context_attention_mask.to(input_ids)
|
||||
|
||||
# compute doc_scores
|
||||
doc_scores = torch.bmm(question_hidden_states.unsqueeze(1), retrieved_doc_embeds.transpose(1, 2)).squeeze(
|
||||
1
|
||||
)
|
||||
|
||||
outputs = model(
|
||||
context_input_ids=context_input_ids,
|
||||
context_attention_mask=context_attention_mask,
|
||||
doc_scores=doc_scores,
|
||||
decoder_input_ids=decoder_input_ids,
|
||||
decoder_attention_mask=decoder_attention_mask,
|
||||
n_docs=n_docs,
|
||||
)
|
||||
|
||||
# logits
|
||||
self.assertEqual(
|
||||
outputs.logits.shape,
|
||||
(n_docs * decoder_input_ids.shape[0], decoder_input_ids.shape[1], config.generator.vocab_size),
|
||||
)
|
||||
# generator encoder last hidden states
|
||||
self.assertEqual(
|
||||
outputs.generator_enc_last_hidden_state.shape,
|
||||
(n_docs * decoder_input_ids.shape[0], self.max_combined_length, config.generator.hidden_size),
|
||||
)
|
||||
# doc scores
|
||||
self.assertEqual(outputs.doc_scores.shape, (input_ids.shape[0], n_docs))
|
||||
|
||||
def check_model_with_mismatch_n_docs_value(
|
||||
self,
|
||||
config,
|
||||
input_ids,
|
||||
attention_mask,
|
||||
decoder_input_ids,
|
||||
decoder_attention_mask,
|
||||
retriever_n_docs,
|
||||
generator_n_docs,
|
||||
**kwargs
|
||||
):
|
||||
self.assertIsNotNone(config.question_encoder)
|
||||
self.assertIsNotNone(config.generator)
|
||||
|
||||
retriever = self.get_retriever(config)
|
||||
|
||||
for model_class in self.all_model_classes:
|
||||
model = model_class(config).to(torch_device)
|
||||
model.eval()
|
||||
self.assertTrue(model.config.is_encoder_decoder)
|
||||
|
||||
question_hidden_states = model.question_encoder(input_ids, attention_mask=attention_mask)[0]
|
||||
|
||||
out = retriever(
|
||||
input_ids,
|
||||
question_hidden_states.cpu().detach().to(torch.float32).numpy(),
|
||||
prefix=config.generator.prefix,
|
||||
return_tensors="pt",
|
||||
n_docs=retriever_n_docs,
|
||||
)
|
||||
|
||||
context_input_ids, context_attention_mask, retrieved_doc_embeds = (
|
||||
out["context_input_ids"],
|
||||
out["context_attention_mask"],
|
||||
out["retrieved_doc_embeds"],
|
||||
)
|
||||
|
||||
# cast
|
||||
retrieved_doc_embeds = retrieved_doc_embeds.to(question_hidden_states)
|
||||
context_input_ids = context_input_ids.to(input_ids)
|
||||
context_attention_mask = context_attention_mask.to(input_ids)
|
||||
|
||||
# compute doc_scores
|
||||
doc_scores = torch.bmm(question_hidden_states.unsqueeze(1), retrieved_doc_embeds.transpose(1, 2)).squeeze(
|
||||
1
|
||||
)
|
||||
|
||||
self.assertRaises(
|
||||
AssertionError,
|
||||
model.__call__,
|
||||
context_input_ids=context_input_ids,
|
||||
context_attention_mask=context_attention_mask,
|
||||
doc_scores=doc_scores,
|
||||
decoder_input_ids=decoder_input_ids,
|
||||
decoder_attention_mask=decoder_attention_mask,
|
||||
n_docs=generator_n_docs,
|
||||
)
|
||||
|
||||
def check_model_with_encoder_outputs(
|
||||
self, config, input_ids, attention_mask, decoder_input_ids, decoder_attention_mask, **kwargs
|
||||
):
|
||||
@@ -373,6 +496,17 @@ class RagTestMixin:
|
||||
inputs_dict = self.config_and_inputs
|
||||
self.check_model_generate(**inputs_dict)
|
||||
|
||||
def test_model_with_custom_n_docs(self):
|
||||
inputs_dict = self.config_and_inputs
|
||||
inputs_dict["n_docs"] = 1
|
||||
self.check_model_custom_n_docs(**inputs_dict)
|
||||
|
||||
def test_model_with_mismatch_n_docs_value(self):
|
||||
inputs_dict = self.config_and_inputs
|
||||
inputs_dict["retriever_n_docs"] = 3
|
||||
inputs_dict["generator_n_docs"] = 2
|
||||
self.check_model_with_mismatch_n_docs_value(**inputs_dict)
|
||||
|
||||
|
||||
@require_torch
|
||||
@require_retrieval
|
||||
|
||||
Reference in New Issue
Block a user