correct attention mask (#7373)
This commit is contained in:
committed by
GitHub
parent
a8cbc4269c
commit
0804d077c6
@@ -115,11 +115,15 @@ def evaluate_batch_retrieval(args, rag_model, questions):
|
|||||||
|
|
||||||
def evaluate_batch_e2e(args, rag_model, questions):
|
def evaluate_batch_e2e(args, rag_model, questions):
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
input_ids = rag_model.retriever.question_encoder_tokenizer.batch_encode_plus(
|
inputs_dict = rag_model.retriever.question_encoder_tokenizer.batch_encode_plus(
|
||||||
questions, return_tensors="pt", padding=True, truncation=True
|
questions, return_tensors="pt", padding=True, truncation=True
|
||||||
)["input_ids"].to(args.device)
|
)
|
||||||
|
|
||||||
|
input_ids = inputs_dict.input_ids.to(args.device)
|
||||||
|
attention_mask = inputs_dict.attention_mask.to(args.device)
|
||||||
outputs = rag_model.generate( # rag_model overwrites generate
|
outputs = rag_model.generate( # rag_model overwrites generate
|
||||||
input_ids,
|
input_ids,
|
||||||
|
attention_mask=attention_mask,
|
||||||
num_beams=args.num_beams,
|
num_beams=args.num_beams,
|
||||||
min_length=args.min_length,
|
min_length=args.min_length,
|
||||||
max_length=args.max_length,
|
max_length=args.max_length,
|
||||||
|
|||||||
@@ -814,7 +814,8 @@ class RagSequenceForGeneration(RagPreTrainedModel):
|
|||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
def generate(
|
def generate(
|
||||||
self,
|
self,
|
||||||
input_ids,
|
input_ids: Optional[torch.LongTensor] = None,
|
||||||
|
attention_mask: Optional[torch.LongTensor] = None,
|
||||||
context_input_ids=None,
|
context_input_ids=None,
|
||||||
do_deduplication=None, # defaults to True
|
do_deduplication=None, # defaults to True
|
||||||
num_return_sequences=None, # defaults to 1
|
num_return_sequences=None, # defaults to 1
|
||||||
@@ -859,7 +860,7 @@ class RagSequenceForGeneration(RagPreTrainedModel):
|
|||||||
|
|
||||||
# TODO(patrick) - clean up generate here
|
# TODO(patrick) - clean up generate here
|
||||||
if self.retriever is not None and context_input_ids is None:
|
if self.retriever is not None and context_input_ids is None:
|
||||||
question_hidden_states = self.question_encoder(input_ids)[0]
|
question_hidden_states = self.question_encoder(input_ids, attention_mask=attention_mask)[0]
|
||||||
context_input_ids = self.retriever(
|
context_input_ids = self.retriever(
|
||||||
input_ids,
|
input_ids,
|
||||||
question_hidden_states.cpu().detach().to(torch.float32).numpy(),
|
question_hidden_states.cpu().detach().to(torch.float32).numpy(),
|
||||||
@@ -1180,6 +1181,7 @@ class RagTokenForGeneration(RagPreTrainedModel):
|
|||||||
def generate(
|
def generate(
|
||||||
self,
|
self,
|
||||||
input_ids: Optional[torch.LongTensor] = None,
|
input_ids: Optional[torch.LongTensor] = None,
|
||||||
|
attention_mask: Optional[torch.LongTensor] = None,
|
||||||
context_input_ids=None,
|
context_input_ids=None,
|
||||||
context_attention_mask=None,
|
context_attention_mask=None,
|
||||||
doc_scores=None,
|
doc_scores=None,
|
||||||
@@ -1293,7 +1295,7 @@ class RagTokenForGeneration(RagPreTrainedModel):
|
|||||||
|
|
||||||
# retrieve docs
|
# retrieve docs
|
||||||
if self.retriever is not None and context_input_ids is None:
|
if self.retriever is not None and context_input_ids is None:
|
||||||
question_hidden_states = self.question_encoder(input_ids)[0]
|
question_hidden_states = self.question_encoder(input_ids, attention_mask=attention_mask)[0]
|
||||||
out = self.retriever(
|
out = self.retriever(
|
||||||
input_ids,
|
input_ids,
|
||||||
question_hidden_states.cpu().detach().to(torch.float32).numpy(),
|
question_hidden_states.cpu().detach().to(torch.float32).numpy(),
|
||||||
|
|||||||
@@ -416,7 +416,6 @@ class RagDPRT5Test(RagTestMixin, unittest.TestCase):
|
|||||||
t5_config_and_inputs = generator_tester.prepare_config_and_inputs()
|
t5_config_and_inputs = generator_tester.prepare_config_and_inputs()
|
||||||
|
|
||||||
(question_encoder_config, input_ids, _, input_mask, _, _, _) = dpr_config_and_inputs
|
(question_encoder_config, input_ids, _, input_mask, _, _, _) = dpr_config_and_inputs
|
||||||
# import ipdb; ipdb.set_trace()
|
|
||||||
(generator_config, _, decoder_input_ids, _, decoder_attention_mask, _) = t5_config_and_inputs
|
(generator_config, _, decoder_input_ids, _, decoder_attention_mask, _) = t5_config_and_inputs
|
||||||
config = RagConfig.from_question_encoder_generator_configs(
|
config = RagConfig.from_question_encoder_generator_configs(
|
||||||
question_encoder_config,
|
question_encoder_config,
|
||||||
@@ -620,18 +619,21 @@ class RagModelIntegrationTests(unittest.TestCase):
|
|||||||
questions = [
|
questions = [
|
||||||
"who sings does he love me with reba",
|
"who sings does he love me with reba",
|
||||||
"how many pages is invisible man by ralph ellison",
|
"how many pages is invisible man by ralph ellison",
|
||||||
|
"what",
|
||||||
]
|
]
|
||||||
input_ids = rag_question_encoder_tokenizer.batch_encode_plus(
|
input_dict = rag_question_encoder_tokenizer.batch_encode_plus(
|
||||||
questions,
|
questions,
|
||||||
return_tensors="pt",
|
return_tensors="pt",
|
||||||
padding=True,
|
padding=True,
|
||||||
truncation=True,
|
truncation=True,
|
||||||
).input_ids
|
)
|
||||||
|
|
||||||
input_ids = input_ids.to(torch_device)
|
input_ids = input_dict.input_ids.to(torch_device)
|
||||||
|
attention_mask = input_dict.attention_mask.to(torch_device)
|
||||||
|
|
||||||
output_ids = rag_token.generate(
|
output_ids = rag_token.generate(
|
||||||
input_ids,
|
input_ids,
|
||||||
|
attention_mask=attention_mask,
|
||||||
decoder_start_token_id=rag_token.generator.config.decoder_start_token_id,
|
decoder_start_token_id=rag_token.generator.config.decoder_start_token_id,
|
||||||
num_beams=4,
|
num_beams=4,
|
||||||
num_return_sequences=1,
|
num_return_sequences=1,
|
||||||
@@ -641,13 +643,16 @@ class RagModelIntegrationTests(unittest.TestCase):
|
|||||||
# sequence generate test
|
# sequence generate test
|
||||||
output_text_1 = rag_decoder_tokenizer.decode(output_ids[0], skip_special_tokens=True)
|
output_text_1 = rag_decoder_tokenizer.decode(output_ids[0], skip_special_tokens=True)
|
||||||
output_text_2 = rag_decoder_tokenizer.decode(output_ids[1], skip_special_tokens=True)
|
output_text_2 = rag_decoder_tokenizer.decode(output_ids[1], skip_special_tokens=True)
|
||||||
|
output_text_3 = rag_decoder_tokenizer.decode(output_ids[2], skip_special_tokens=True)
|
||||||
|
|
||||||
# Expected outputs as given by model at integration time.
|
# Expected outputs as given by model at integration time.
|
||||||
EXPECTED_OUTPUT_TEXT_1 = '"People Need Love" is the'
|
EXPECTED_OUTPUT_TEXT_1 = '"People Need Love" is the'
|
||||||
EXPECTED_OUTPUT_TEXT_2 = '"How many pages is invisible man'
|
EXPECTED_OUTPUT_TEXT_2 = '"How many pages is invisible man'
|
||||||
|
EXPECTED_OUTPUT_TEXT_3 = "Otis the Aardvark"
|
||||||
|
|
||||||
self.assertEqual(output_text_1, EXPECTED_OUTPUT_TEXT_1)
|
self.assertEqual(output_text_1, EXPECTED_OUTPUT_TEXT_1)
|
||||||
self.assertEqual(output_text_2, EXPECTED_OUTPUT_TEXT_2)
|
self.assertEqual(output_text_2, EXPECTED_OUTPUT_TEXT_2)
|
||||||
|
self.assertEqual(output_text_3, EXPECTED_OUTPUT_TEXT_3)
|
||||||
|
|
||||||
@slow
|
@slow
|
||||||
def test_rag_sequence_generate_batch(self):
|
def test_rag_sequence_generate_batch(self):
|
||||||
@@ -669,18 +674,22 @@ class RagModelIntegrationTests(unittest.TestCase):
|
|||||||
questions = [
|
questions = [
|
||||||
"who sings does he love me with reba",
|
"who sings does he love me with reba",
|
||||||
"how many pages is invisible man by ralph ellison",
|
"how many pages is invisible man by ralph ellison",
|
||||||
|
"what",
|
||||||
]
|
]
|
||||||
input_ids = rag_question_encoder_tokenizer.batch_encode_plus(
|
|
||||||
|
input_dict = rag_question_encoder_tokenizer.batch_encode_plus(
|
||||||
questions,
|
questions,
|
||||||
return_tensors="pt",
|
return_tensors="pt",
|
||||||
padding=True,
|
padding=True,
|
||||||
truncation=True,
|
truncation=True,
|
||||||
).input_ids
|
)
|
||||||
|
|
||||||
input_ids = input_ids.to(torch_device)
|
input_ids = input_dict.input_ids.to(torch_device)
|
||||||
|
attention_mask = input_dict.attention_mask.to(torch_device)
|
||||||
|
|
||||||
output_ids = rag_sequence.generate(
|
output_ids = rag_sequence.generate(
|
||||||
input_ids,
|
input_ids,
|
||||||
|
attention_mask=attention_mask,
|
||||||
decoder_start_token_id=rag_sequence.generator.config.decoder_start_token_id,
|
decoder_start_token_id=rag_sequence.generator.config.decoder_start_token_id,
|
||||||
num_beams=4,
|
num_beams=4,
|
||||||
num_return_sequences=1,
|
num_return_sequences=1,
|
||||||
@@ -690,13 +699,16 @@ class RagModelIntegrationTests(unittest.TestCase):
|
|||||||
# sequence generate test
|
# sequence generate test
|
||||||
output_text_1 = rag_decoder_tokenizer.decode(output_ids[0], skip_special_tokens=True)
|
output_text_1 = rag_decoder_tokenizer.decode(output_ids[0], skip_special_tokens=True)
|
||||||
output_text_2 = rag_decoder_tokenizer.decode(output_ids[1], skip_special_tokens=True)
|
output_text_2 = rag_decoder_tokenizer.decode(output_ids[1], skip_special_tokens=True)
|
||||||
|
output_text_3 = rag_decoder_tokenizer.decode(output_ids[2], skip_special_tokens=True)
|
||||||
|
|
||||||
# Expected outputs as given by model at integration time.
|
# Expected outputs as given by model at integration time.
|
||||||
EXPECTED_OUTPUT_TEXT_1 = '"I Know Him So Well"'
|
EXPECTED_OUTPUT_TEXT_1 = '"I Know Him So Well"'
|
||||||
EXPECTED_OUTPUT_TEXT_2 = '"Howl" chronicles the'
|
EXPECTED_OUTPUT_TEXT_2 = '"Howl" chronicles the'
|
||||||
|
EXPECTED_OUTPUT_TEXT_3 = "Otis the Aardvark"
|
||||||
|
|
||||||
self.assertEqual(output_text_1, EXPECTED_OUTPUT_TEXT_1)
|
self.assertEqual(output_text_1, EXPECTED_OUTPUT_TEXT_1)
|
||||||
self.assertEqual(output_text_2, EXPECTED_OUTPUT_TEXT_2)
|
self.assertEqual(output_text_2, EXPECTED_OUTPUT_TEXT_2)
|
||||||
|
self.assertEqual(output_text_3, EXPECTED_OUTPUT_TEXT_3)
|
||||||
|
|
||||||
@slow
|
@slow
|
||||||
def test_rag_sequence_generate_beam(self):
|
def test_rag_sequence_generate_beam(self):
|
||||||
|
|||||||
Reference in New Issue
Block a user