Proposed Fix : [RagSequenceForGeneration] generate "without" input_ids (#9220)
* Create modeling_tf_dpr.py * Add TFDPR * Add back TFPegasus, TFMarian, TFMBart, TFBlenderBot last commit accidentally deleted these 4 lines, so I recover them back * Add TFDPR * Add TFDPR * clean up some comments, add TF input-style doc string * Add TFDPR * Make return_dict=False as default * Fix return_dict bug (in .from_pretrained) * Add get_input_embeddings() * Create test_modeling_tf_dpr.py The current version is already passed all 27 tests! Please see the test run at : https://colab.research.google.com/drive/1czS_m9zy5k-iSJbzA_DP1k1xAAC_sdkf?usp=sharing * fix quality * delete init weights * run fix copies * fix repo consis * del config_class, load_tf_weights They shoud be 'pytorch only' * add config_class back after removing it, test failed ... so totally only removing "use_tf_weights = None" on Lysandre suggestion * newline after .. note:: * import tf, np (Necessary for ModelIntegrationTest) * slow_test from_pretrained with from_pt=True At the moment we don't have TF weights (since we don't have official official TF model) Previously, I did not run slow test, so I missed this bug * Add simple TFDPRModelIntegrationTest Note that this is just a test that TF and Pytorch gives approx. the same output. However, I could not test with the official DPR repo's output yet * upload correct tf model * remove position_ids as missing keys * fix RagSeq generate with context_input_ids fix RagSeq generate with context_input_ids * apply style * delete unused lines * Add test_rag_sequence_generate_batch_from_context_input_ids * Readability improved * stylying * Stylize * typos * add check_model_generate_from_context_input_ids * make style * Apply suggestions from code review * make style2 Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com> Co-authored-by: patrickvonplaten <patrick@huggingface.co>
This commit is contained in:
committed by
GitHub
parent
2a18b70998
commit
f3a3b91d6f
@@ -246,6 +246,53 @@ class RagTestMixin:
|
||||
# doc scores
|
||||
self.assertEqual(outputs.doc_scores.shape, (input_ids.shape[0], self.n_docs))
|
||||
|
||||
def check_model_generate_from_context_input_ids(
|
||||
self, config, input_ids, attention_mask, decoder_input_ids, decoder_attention_mask, **kwargs
|
||||
):
|
||||
self.assertIsNotNone(config.question_encoder)
|
||||
self.assertIsNotNone(config.generator)
|
||||
|
||||
retriever = self.get_retriever(config)
|
||||
|
||||
for model_class in self.all_model_classes:
|
||||
model = model_class(config).to(torch_device)
|
||||
model.eval()
|
||||
self.assertTrue(model.config.is_encoder_decoder)
|
||||
|
||||
question_hidden_states = model.question_encoder(input_ids, attention_mask=attention_mask)[0]
|
||||
|
||||
out = retriever(
|
||||
input_ids,
|
||||
question_hidden_states.cpu().detach().to(torch.float32).numpy(),
|
||||
prefix=config.generator.prefix,
|
||||
return_tensors="pt",
|
||||
)
|
||||
|
||||
context_input_ids, context_attention_mask, retrieved_doc_embeds = (
|
||||
out["context_input_ids"],
|
||||
out["context_attention_mask"],
|
||||
out["retrieved_doc_embeds"],
|
||||
)
|
||||
|
||||
# cast
|
||||
retrieved_doc_embeds = retrieved_doc_embeds.to(question_hidden_states)
|
||||
context_input_ids = context_input_ids.to(input_ids)
|
||||
context_attention_mask = context_attention_mask.to(input_ids)
|
||||
|
||||
# compute doc_scores
|
||||
doc_scores = torch.bmm(question_hidden_states.unsqueeze(1), retrieved_doc_embeds.transpose(1, 2)).squeeze(
|
||||
1
|
||||
)
|
||||
|
||||
outputs = model.generate(
|
||||
context_input_ids=context_input_ids,
|
||||
context_attention_mask=context_attention_mask,
|
||||
doc_scores=doc_scores,
|
||||
do_deduplication=True,
|
||||
)
|
||||
|
||||
self.assertIsNotNone(outputs)
|
||||
|
||||
def check_model_generate(
|
||||
self, config, input_ids, attention_mask, decoder_input_ids, decoder_attention_mask, **kwargs
|
||||
):
|
||||
@@ -848,6 +895,63 @@ class RagModelIntegrationTests(unittest.TestCase):
|
||||
]
|
||||
self.assertListEqual(outputs, EXPECTED_OUTPUTS)
|
||||
|
||||
@slow
|
||||
def test_rag_sequence_generate_batch_from_context_input_ids(self):
|
||||
tokenizer = RagTokenizer.from_pretrained("facebook/rag-sequence-nq")
|
||||
retriever = RagRetriever.from_pretrained(
|
||||
"facebook/rag-sequence-nq", index_name="exact", use_dummy_dataset=True
|
||||
)
|
||||
rag_sequence = RagSequenceForGeneration.from_pretrained("facebook/rag-sequence-nq", retriever=retriever).to(
|
||||
torch_device
|
||||
)
|
||||
|
||||
input_dict = tokenizer(
|
||||
self.test_data_questions,
|
||||
return_tensors="pt",
|
||||
padding=True,
|
||||
truncation=True,
|
||||
)
|
||||
|
||||
input_ids = input_dict.input_ids.to(torch_device)
|
||||
attention_mask = input_dict.attention_mask.to(torch_device)
|
||||
|
||||
question_hidden_states = rag_sequence.question_encoder(input_ids, attention_mask=attention_mask)[0]
|
||||
docs_dict = retriever(
|
||||
input_ids.cpu().detach().numpy(), question_hidden_states.cpu().detach().numpy(), return_tensors="pt"
|
||||
)
|
||||
doc_scores = torch.bmm(
|
||||
question_hidden_states.unsqueeze(1),
|
||||
docs_dict["retrieved_doc_embeds"].to(torch_device).float().transpose(1, 2),
|
||||
).squeeze(1)
|
||||
|
||||
output_ids = rag_sequence.generate(
|
||||
context_input_ids=docs_dict["context_input_ids"].to(torch_device),
|
||||
context_attention_mask=docs_dict["context_attention_mask"].to(torch_device),
|
||||
doc_scores=doc_scores.to(torch_device),
|
||||
do_deduplication=True,
|
||||
)
|
||||
|
||||
outputs = tokenizer.batch_decode(output_ids, skip_special_tokens=True)
|
||||
|
||||
EXPECTED_OUTPUTS = [
|
||||
" albert einstein",
|
||||
" june 22, 2018",
|
||||
" amplitude modulation",
|
||||
" tim besley ( chairman )",
|
||||
" june 20, 2018",
|
||||
" 1980",
|
||||
" 7.0",
|
||||
" 8",
|
||||
" reticular formation",
|
||||
" walls of the abdomen",
|
||||
" spodumene",
|
||||
" obama",
|
||||
" new orleans",
|
||||
" japan",
|
||||
" old trafford",
|
||||
]
|
||||
self.assertListEqual(outputs, EXPECTED_OUTPUTS)
|
||||
|
||||
@slow
|
||||
def test_rag_token_generate_batch(self):
|
||||
tokenizer = RagTokenizer.from_pretrained("facebook/rag-token-nq")
|
||||
|
||||
Reference in New Issue
Block a user