correct attention mask (#7373)

This commit is contained in:
Patrick von Platen
2020-09-24 23:22:04 +02:00
committed by GitHub
parent a8cbc4269c
commit 0804d077c6
3 changed files with 30 additions and 12 deletions

View File

@@ -115,11 +115,15 @@ def evaluate_batch_retrieval(args, rag_model, questions):
def evaluate_batch_e2e(args, rag_model, questions): def evaluate_batch_e2e(args, rag_model, questions):
with torch.no_grad(): with torch.no_grad():
input_ids = rag_model.retriever.question_encoder_tokenizer.batch_encode_plus( inputs_dict = rag_model.retriever.question_encoder_tokenizer.batch_encode_plus(
questions, return_tensors="pt", padding=True, truncation=True questions, return_tensors="pt", padding=True, truncation=True
)["input_ids"].to(args.device) )
input_ids = inputs_dict.input_ids.to(args.device)
attention_mask = inputs_dict.attention_mask.to(args.device)
outputs = rag_model.generate( # rag_model overwrites generate outputs = rag_model.generate( # rag_model overwrites generate
input_ids, input_ids,
attention_mask=attention_mask,
num_beams=args.num_beams, num_beams=args.num_beams,
min_length=args.min_length, min_length=args.min_length,
max_length=args.max_length, max_length=args.max_length,

View File

@@ -814,7 +814,8 @@ class RagSequenceForGeneration(RagPreTrainedModel):
@torch.no_grad() @torch.no_grad()
def generate( def generate(
self, self,
input_ids, input_ids: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.LongTensor] = None,
context_input_ids=None, context_input_ids=None,
do_deduplication=None, # defaults to True do_deduplication=None, # defaults to True
num_return_sequences=None, # defaults to 1 num_return_sequences=None, # defaults to 1
@@ -859,7 +860,7 @@ class RagSequenceForGeneration(RagPreTrainedModel):
# TODO(patrick) - clean up generate here # TODO(patrick) - clean up generate here
if self.retriever is not None and context_input_ids is None: if self.retriever is not None and context_input_ids is None:
question_hidden_states = self.question_encoder(input_ids)[0] question_hidden_states = self.question_encoder(input_ids, attention_mask=attention_mask)[0]
context_input_ids = self.retriever( context_input_ids = self.retriever(
input_ids, input_ids,
question_hidden_states.cpu().detach().to(torch.float32).numpy(), question_hidden_states.cpu().detach().to(torch.float32).numpy(),
@@ -1180,6 +1181,7 @@ class RagTokenForGeneration(RagPreTrainedModel):
def generate( def generate(
self, self,
input_ids: Optional[torch.LongTensor] = None, input_ids: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.LongTensor] = None,
context_input_ids=None, context_input_ids=None,
context_attention_mask=None, context_attention_mask=None,
doc_scores=None, doc_scores=None,
@@ -1293,7 +1295,7 @@ class RagTokenForGeneration(RagPreTrainedModel):
# retrieve docs # retrieve docs
if self.retriever is not None and context_input_ids is None: if self.retriever is not None and context_input_ids is None:
question_hidden_states = self.question_encoder(input_ids)[0] question_hidden_states = self.question_encoder(input_ids, attention_mask=attention_mask)[0]
out = self.retriever( out = self.retriever(
input_ids, input_ids,
question_hidden_states.cpu().detach().to(torch.float32).numpy(), question_hidden_states.cpu().detach().to(torch.float32).numpy(),

View File

@@ -416,7 +416,6 @@ class RagDPRT5Test(RagTestMixin, unittest.TestCase):
t5_config_and_inputs = generator_tester.prepare_config_and_inputs() t5_config_and_inputs = generator_tester.prepare_config_and_inputs()
(question_encoder_config, input_ids, _, input_mask, _, _, _) = dpr_config_and_inputs (question_encoder_config, input_ids, _, input_mask, _, _, _) = dpr_config_and_inputs
# import ipdb; ipdb.set_trace()
(generator_config, _, decoder_input_ids, _, decoder_attention_mask, _) = t5_config_and_inputs (generator_config, _, decoder_input_ids, _, decoder_attention_mask, _) = t5_config_and_inputs
config = RagConfig.from_question_encoder_generator_configs( config = RagConfig.from_question_encoder_generator_configs(
question_encoder_config, question_encoder_config,
@@ -620,18 +619,21 @@ class RagModelIntegrationTests(unittest.TestCase):
questions = [ questions = [
"who sings does he love me with reba", "who sings does he love me with reba",
"how many pages is invisible man by ralph ellison", "how many pages is invisible man by ralph ellison",
"what",
] ]
input_ids = rag_question_encoder_tokenizer.batch_encode_plus( input_dict = rag_question_encoder_tokenizer.batch_encode_plus(
questions, questions,
return_tensors="pt", return_tensors="pt",
padding=True, padding=True,
truncation=True, truncation=True,
).input_ids )
input_ids = input_ids.to(torch_device) input_ids = input_dict.input_ids.to(torch_device)
attention_mask = input_dict.attention_mask.to(torch_device)
output_ids = rag_token.generate( output_ids = rag_token.generate(
input_ids, input_ids,
attention_mask=attention_mask,
decoder_start_token_id=rag_token.generator.config.decoder_start_token_id, decoder_start_token_id=rag_token.generator.config.decoder_start_token_id,
num_beams=4, num_beams=4,
num_return_sequences=1, num_return_sequences=1,
@@ -641,13 +643,16 @@ class RagModelIntegrationTests(unittest.TestCase):
# sequence generate test # sequence generate test
output_text_1 = rag_decoder_tokenizer.decode(output_ids[0], skip_special_tokens=True) output_text_1 = rag_decoder_tokenizer.decode(output_ids[0], skip_special_tokens=True)
output_text_2 = rag_decoder_tokenizer.decode(output_ids[1], skip_special_tokens=True) output_text_2 = rag_decoder_tokenizer.decode(output_ids[1], skip_special_tokens=True)
output_text_3 = rag_decoder_tokenizer.decode(output_ids[2], skip_special_tokens=True)
# Expected outputs as given by model at integration time. # Expected outputs as given by model at integration time.
EXPECTED_OUTPUT_TEXT_1 = '"People Need Love" is the' EXPECTED_OUTPUT_TEXT_1 = '"People Need Love" is the'
EXPECTED_OUTPUT_TEXT_2 = '"How many pages is invisible man' EXPECTED_OUTPUT_TEXT_2 = '"How many pages is invisible man'
EXPECTED_OUTPUT_TEXT_3 = "Otis the Aardvark"
self.assertEqual(output_text_1, EXPECTED_OUTPUT_TEXT_1) self.assertEqual(output_text_1, EXPECTED_OUTPUT_TEXT_1)
self.assertEqual(output_text_2, EXPECTED_OUTPUT_TEXT_2) self.assertEqual(output_text_2, EXPECTED_OUTPUT_TEXT_2)
self.assertEqual(output_text_3, EXPECTED_OUTPUT_TEXT_3)
@slow @slow
def test_rag_sequence_generate_batch(self): def test_rag_sequence_generate_batch(self):
@@ -669,18 +674,22 @@ class RagModelIntegrationTests(unittest.TestCase):
questions = [ questions = [
"who sings does he love me with reba", "who sings does he love me with reba",
"how many pages is invisible man by ralph ellison", "how many pages is invisible man by ralph ellison",
"what",
] ]
input_ids = rag_question_encoder_tokenizer.batch_encode_plus(
input_dict = rag_question_encoder_tokenizer.batch_encode_plus(
questions, questions,
return_tensors="pt", return_tensors="pt",
padding=True, padding=True,
truncation=True, truncation=True,
).input_ids )
input_ids = input_ids.to(torch_device) input_ids = input_dict.input_ids.to(torch_device)
attention_mask = input_dict.attention_mask.to(torch_device)
output_ids = rag_sequence.generate( output_ids = rag_sequence.generate(
input_ids, input_ids,
attention_mask=attention_mask,
decoder_start_token_id=rag_sequence.generator.config.decoder_start_token_id, decoder_start_token_id=rag_sequence.generator.config.decoder_start_token_id,
num_beams=4, num_beams=4,
num_return_sequences=1, num_return_sequences=1,
@@ -690,13 +699,16 @@ class RagModelIntegrationTests(unittest.TestCase):
# sequence generate test # sequence generate test
output_text_1 = rag_decoder_tokenizer.decode(output_ids[0], skip_special_tokens=True) output_text_1 = rag_decoder_tokenizer.decode(output_ids[0], skip_special_tokens=True)
output_text_2 = rag_decoder_tokenizer.decode(output_ids[1], skip_special_tokens=True) output_text_2 = rag_decoder_tokenizer.decode(output_ids[1], skip_special_tokens=True)
output_text_3 = rag_decoder_tokenizer.decode(output_ids[2], skip_special_tokens=True)
# Expected outputs as given by model at integration time. # Expected outputs as given by model at integration time.
EXPECTED_OUTPUT_TEXT_1 = '"I Know Him So Well"' EXPECTED_OUTPUT_TEXT_1 = '"I Know Him So Well"'
EXPECTED_OUTPUT_TEXT_2 = '"Howl" chronicles the' EXPECTED_OUTPUT_TEXT_2 = '"Howl" chronicles the'
EXPECTED_OUTPUT_TEXT_3 = "Otis the Aardvark"
self.assertEqual(output_text_1, EXPECTED_OUTPUT_TEXT_1) self.assertEqual(output_text_1, EXPECTED_OUTPUT_TEXT_1)
self.assertEqual(output_text_2, EXPECTED_OUTPUT_TEXT_2) self.assertEqual(output_text_2, EXPECTED_OUTPUT_TEXT_2)
self.assertEqual(output_text_3, EXPECTED_OUTPUT_TEXT_3)
@slow @slow
def test_rag_sequence_generate_beam(self): def test_rag_sequence_generate_beam(self):