correct attention mask (#7373)
This commit is contained in:
committed by
GitHub
parent
a8cbc4269c
commit
0804d077c6
@@ -115,11 +115,15 @@ def evaluate_batch_retrieval(args, rag_model, questions):
|
||||
|
||||
def evaluate_batch_e2e(args, rag_model, questions):
|
||||
with torch.no_grad():
|
||||
input_ids = rag_model.retriever.question_encoder_tokenizer.batch_encode_plus(
|
||||
inputs_dict = rag_model.retriever.question_encoder_tokenizer.batch_encode_plus(
|
||||
questions, return_tensors="pt", padding=True, truncation=True
|
||||
)["input_ids"].to(args.device)
|
||||
)
|
||||
|
||||
input_ids = inputs_dict.input_ids.to(args.device)
|
||||
attention_mask = inputs_dict.attention_mask.to(args.device)
|
||||
outputs = rag_model.generate( # rag_model overwrites generate
|
||||
input_ids,
|
||||
attention_mask=attention_mask,
|
||||
num_beams=args.num_beams,
|
||||
min_length=args.min_length,
|
||||
max_length=args.max_length,
|
||||
|
||||
@@ -814,7 +814,8 @@ class RagSequenceForGeneration(RagPreTrainedModel):
|
||||
@torch.no_grad()
|
||||
def generate(
|
||||
self,
|
||||
input_ids,
|
||||
input_ids: Optional[torch.LongTensor] = None,
|
||||
attention_mask: Optional[torch.LongTensor] = None,
|
||||
context_input_ids=None,
|
||||
do_deduplication=None, # defaults to True
|
||||
num_return_sequences=None, # defaults to 1
|
||||
@@ -859,7 +860,7 @@ class RagSequenceForGeneration(RagPreTrainedModel):
|
||||
|
||||
# TODO(patrick) - clean up generate here
|
||||
if self.retriever is not None and context_input_ids is None:
|
||||
question_hidden_states = self.question_encoder(input_ids)[0]
|
||||
question_hidden_states = self.question_encoder(input_ids, attention_mask=attention_mask)[0]
|
||||
context_input_ids = self.retriever(
|
||||
input_ids,
|
||||
question_hidden_states.cpu().detach().to(torch.float32).numpy(),
|
||||
@@ -1180,6 +1181,7 @@ class RagTokenForGeneration(RagPreTrainedModel):
|
||||
def generate(
|
||||
self,
|
||||
input_ids: Optional[torch.LongTensor] = None,
|
||||
attention_mask: Optional[torch.LongTensor] = None,
|
||||
context_input_ids=None,
|
||||
context_attention_mask=None,
|
||||
doc_scores=None,
|
||||
@@ -1293,7 +1295,7 @@ class RagTokenForGeneration(RagPreTrainedModel):
|
||||
|
||||
# retrieve docs
|
||||
if self.retriever is not None and context_input_ids is None:
|
||||
question_hidden_states = self.question_encoder(input_ids)[0]
|
||||
question_hidden_states = self.question_encoder(input_ids, attention_mask=attention_mask)[0]
|
||||
out = self.retriever(
|
||||
input_ids,
|
||||
question_hidden_states.cpu().detach().to(torch.float32).numpy(),
|
||||
|
||||
@@ -416,7 +416,6 @@ class RagDPRT5Test(RagTestMixin, unittest.TestCase):
|
||||
t5_config_and_inputs = generator_tester.prepare_config_and_inputs()
|
||||
|
||||
(question_encoder_config, input_ids, _, input_mask, _, _, _) = dpr_config_and_inputs
|
||||
# import ipdb; ipdb.set_trace()
|
||||
(generator_config, _, decoder_input_ids, _, decoder_attention_mask, _) = t5_config_and_inputs
|
||||
config = RagConfig.from_question_encoder_generator_configs(
|
||||
question_encoder_config,
|
||||
@@ -620,18 +619,21 @@ class RagModelIntegrationTests(unittest.TestCase):
|
||||
questions = [
|
||||
"who sings does he love me with reba",
|
||||
"how many pages is invisible man by ralph ellison",
|
||||
"what",
|
||||
]
|
||||
input_ids = rag_question_encoder_tokenizer.batch_encode_plus(
|
||||
input_dict = rag_question_encoder_tokenizer.batch_encode_plus(
|
||||
questions,
|
||||
return_tensors="pt",
|
||||
padding=True,
|
||||
truncation=True,
|
||||
).input_ids
|
||||
)
|
||||
|
||||
input_ids = input_ids.to(torch_device)
|
||||
input_ids = input_dict.input_ids.to(torch_device)
|
||||
attention_mask = input_dict.attention_mask.to(torch_device)
|
||||
|
||||
output_ids = rag_token.generate(
|
||||
input_ids,
|
||||
attention_mask=attention_mask,
|
||||
decoder_start_token_id=rag_token.generator.config.decoder_start_token_id,
|
||||
num_beams=4,
|
||||
num_return_sequences=1,
|
||||
@@ -641,13 +643,16 @@ class RagModelIntegrationTests(unittest.TestCase):
|
||||
# sequence generate test
|
||||
output_text_1 = rag_decoder_tokenizer.decode(output_ids[0], skip_special_tokens=True)
|
||||
output_text_2 = rag_decoder_tokenizer.decode(output_ids[1], skip_special_tokens=True)
|
||||
output_text_3 = rag_decoder_tokenizer.decode(output_ids[2], skip_special_tokens=True)
|
||||
|
||||
# Expected outputs as given by model at integration time.
|
||||
EXPECTED_OUTPUT_TEXT_1 = '"People Need Love" is the'
|
||||
EXPECTED_OUTPUT_TEXT_2 = '"How many pages is invisible man'
|
||||
EXPECTED_OUTPUT_TEXT_3 = "Otis the Aardvark"
|
||||
|
||||
self.assertEqual(output_text_1, EXPECTED_OUTPUT_TEXT_1)
|
||||
self.assertEqual(output_text_2, EXPECTED_OUTPUT_TEXT_2)
|
||||
self.assertEqual(output_text_3, EXPECTED_OUTPUT_TEXT_3)
|
||||
|
||||
@slow
|
||||
def test_rag_sequence_generate_batch(self):
|
||||
@@ -669,18 +674,22 @@ class RagModelIntegrationTests(unittest.TestCase):
|
||||
questions = [
|
||||
"who sings does he love me with reba",
|
||||
"how many pages is invisible man by ralph ellison",
|
||||
"what",
|
||||
]
|
||||
input_ids = rag_question_encoder_tokenizer.batch_encode_plus(
|
||||
|
||||
input_dict = rag_question_encoder_tokenizer.batch_encode_plus(
|
||||
questions,
|
||||
return_tensors="pt",
|
||||
padding=True,
|
||||
truncation=True,
|
||||
).input_ids
|
||||
)
|
||||
|
||||
input_ids = input_ids.to(torch_device)
|
||||
input_ids = input_dict.input_ids.to(torch_device)
|
||||
attention_mask = input_dict.attention_mask.to(torch_device)
|
||||
|
||||
output_ids = rag_sequence.generate(
|
||||
input_ids,
|
||||
attention_mask=attention_mask,
|
||||
decoder_start_token_id=rag_sequence.generator.config.decoder_start_token_id,
|
||||
num_beams=4,
|
||||
num_return_sequences=1,
|
||||
@@ -690,13 +699,16 @@ class RagModelIntegrationTests(unittest.TestCase):
|
||||
# sequence generate test
|
||||
output_text_1 = rag_decoder_tokenizer.decode(output_ids[0], skip_special_tokens=True)
|
||||
output_text_2 = rag_decoder_tokenizer.decode(output_ids[1], skip_special_tokens=True)
|
||||
output_text_3 = rag_decoder_tokenizer.decode(output_ids[2], skip_special_tokens=True)
|
||||
|
||||
# Expected outputs as given by model at integration time.
|
||||
EXPECTED_OUTPUT_TEXT_1 = '"I Know Him So Well"'
|
||||
EXPECTED_OUTPUT_TEXT_2 = '"Howl" chronicles the'
|
||||
EXPECTED_OUTPUT_TEXT_3 = "Otis the Aardvark"
|
||||
|
||||
self.assertEqual(output_text_1, EXPECTED_OUTPUT_TEXT_1)
|
||||
self.assertEqual(output_text_2, EXPECTED_OUTPUT_TEXT_2)
|
||||
self.assertEqual(output_text_3, EXPECTED_OUTPUT_TEXT_3)
|
||||
|
||||
@slow
|
||||
def test_rag_sequence_generate_beam(self):
|
||||
|
||||
Reference in New Issue
Block a user