From 0804d077c634b2149b833ecc7897959cab8bf650 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Thu, 24 Sep 2020 23:22:04 +0200 Subject: [PATCH] correct attention mask (#7373) --- examples/rag/eval_rag.py | 8 ++++++-- src/transformers/modeling_rag.py | 8 +++++--- tests/test_modeling_rag.py | 26 +++++++++++++++++++------- 3 files changed, 30 insertions(+), 12 deletions(-) diff --git a/examples/rag/eval_rag.py b/examples/rag/eval_rag.py index a2fc936489..452baf7cb6 100644 --- a/examples/rag/eval_rag.py +++ b/examples/rag/eval_rag.py @@ -115,11 +115,15 @@ def evaluate_batch_retrieval(args, rag_model, questions): def evaluate_batch_e2e(args, rag_model, questions): with torch.no_grad(): - input_ids = rag_model.retriever.question_encoder_tokenizer.batch_encode_plus( + inputs_dict = rag_model.retriever.question_encoder_tokenizer.batch_encode_plus( questions, return_tensors="pt", padding=True, truncation=True - )["input_ids"].to(args.device) + ) + + input_ids = inputs_dict.input_ids.to(args.device) + attention_mask = inputs_dict.attention_mask.to(args.device) outputs = rag_model.generate( # rag_model overwrites generate input_ids, + attention_mask=attention_mask, num_beams=args.num_beams, min_length=args.min_length, max_length=args.max_length, diff --git a/src/transformers/modeling_rag.py b/src/transformers/modeling_rag.py index 7d30ad9f5d..09fe472dd8 100644 --- a/src/transformers/modeling_rag.py +++ b/src/transformers/modeling_rag.py @@ -814,7 +814,8 @@ class RagSequenceForGeneration(RagPreTrainedModel): @torch.no_grad() def generate( self, - input_ids, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.LongTensor] = None, context_input_ids=None, do_deduplication=None, # defaults to True num_return_sequences=None, # defaults to 1 @@ -859,7 +860,7 @@ class RagSequenceForGeneration(RagPreTrainedModel): # TODO(patrick) - clean up generate here if self.retriever is not None and context_input_ids is None: - question_hidden_states = self.question_encoder(input_ids)[0] + question_hidden_states = self.question_encoder(input_ids, attention_mask=attention_mask)[0] context_input_ids = self.retriever( input_ids, question_hidden_states.cpu().detach().to(torch.float32).numpy(), @@ -1180,6 +1181,7 @@ class RagTokenForGeneration(RagPreTrainedModel): def generate( self, input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.LongTensor] = None, context_input_ids=None, context_attention_mask=None, doc_scores=None, @@ -1293,7 +1295,7 @@ class RagTokenForGeneration(RagPreTrainedModel): # retrieve docs if self.retriever is not None and context_input_ids is None: - question_hidden_states = self.question_encoder(input_ids)[0] + question_hidden_states = self.question_encoder(input_ids, attention_mask=attention_mask)[0] out = self.retriever( input_ids, question_hidden_states.cpu().detach().to(torch.float32).numpy(), diff --git a/tests/test_modeling_rag.py b/tests/test_modeling_rag.py index 0e44420315..dfc9ee65eb 100644 --- a/tests/test_modeling_rag.py +++ b/tests/test_modeling_rag.py @@ -416,7 +416,6 @@ class RagDPRT5Test(RagTestMixin, unittest.TestCase): t5_config_and_inputs = generator_tester.prepare_config_and_inputs() (question_encoder_config, input_ids, _, input_mask, _, _, _) = dpr_config_and_inputs - # import ipdb; ipdb.set_trace() (generator_config, _, decoder_input_ids, _, decoder_attention_mask, _) = t5_config_and_inputs config = RagConfig.from_question_encoder_generator_configs( question_encoder_config, @@ -620,18 +619,21 @@ class RagModelIntegrationTests(unittest.TestCase): questions = [ "who sings does he love me with reba", "how many pages is invisible man by ralph ellison", + "what", ] - input_ids = rag_question_encoder_tokenizer.batch_encode_plus( + input_dict = rag_question_encoder_tokenizer.batch_encode_plus( questions, return_tensors="pt", padding=True, truncation=True, - ).input_ids + ) - input_ids = input_ids.to(torch_device) + input_ids = input_dict.input_ids.to(torch_device) + attention_mask = input_dict.attention_mask.to(torch_device) output_ids = rag_token.generate( input_ids, + attention_mask=attention_mask, decoder_start_token_id=rag_token.generator.config.decoder_start_token_id, num_beams=4, num_return_sequences=1, @@ -641,13 +643,16 @@ class RagModelIntegrationTests(unittest.TestCase): # sequence generate test output_text_1 = rag_decoder_tokenizer.decode(output_ids[0], skip_special_tokens=True) output_text_2 = rag_decoder_tokenizer.decode(output_ids[1], skip_special_tokens=True) + output_text_3 = rag_decoder_tokenizer.decode(output_ids[2], skip_special_tokens=True) # Expected outputs as given by model at integration time. EXPECTED_OUTPUT_TEXT_1 = '"People Need Love" is the' EXPECTED_OUTPUT_TEXT_2 = '"How many pages is invisible man' + EXPECTED_OUTPUT_TEXT_3 = "Otis the Aardvark" self.assertEqual(output_text_1, EXPECTED_OUTPUT_TEXT_1) self.assertEqual(output_text_2, EXPECTED_OUTPUT_TEXT_2) + self.assertEqual(output_text_3, EXPECTED_OUTPUT_TEXT_3) @slow def test_rag_sequence_generate_batch(self): @@ -669,18 +674,22 @@ class RagModelIntegrationTests(unittest.TestCase): questions = [ "who sings does he love me with reba", "how many pages is invisible man by ralph ellison", + "what", ] - input_ids = rag_question_encoder_tokenizer.batch_encode_plus( + + input_dict = rag_question_encoder_tokenizer.batch_encode_plus( questions, return_tensors="pt", padding=True, truncation=True, - ).input_ids + ) - input_ids = input_ids.to(torch_device) + input_ids = input_dict.input_ids.to(torch_device) + attention_mask = input_dict.attention_mask.to(torch_device) output_ids = rag_sequence.generate( input_ids, + attention_mask=attention_mask, decoder_start_token_id=rag_sequence.generator.config.decoder_start_token_id, num_beams=4, num_return_sequences=1, @@ -690,13 +699,16 @@ class RagModelIntegrationTests(unittest.TestCase): # sequence generate test output_text_1 = rag_decoder_tokenizer.decode(output_ids[0], skip_special_tokens=True) output_text_2 = rag_decoder_tokenizer.decode(output_ids[1], skip_special_tokens=True) + output_text_3 = rag_decoder_tokenizer.decode(output_ids[2], skip_special_tokens=True) # Expected outputs as given by model at integration time. EXPECTED_OUTPUT_TEXT_1 = '"I Know Him So Well"' EXPECTED_OUTPUT_TEXT_2 = '"Howl" chronicles the' + EXPECTED_OUTPUT_TEXT_3 = "Otis the Aardvark" self.assertEqual(output_text_1, EXPECTED_OUTPUT_TEXT_1) self.assertEqual(output_text_2, EXPECTED_OUTPUT_TEXT_2) + self.assertEqual(output_text_3, EXPECTED_OUTPUT_TEXT_3) @slow def test_rag_sequence_generate_beam(self):