correct attention mask (#7373)
This commit is contained in:
committed by
GitHub
parent
a8cbc4269c
commit
0804d077c6
@@ -416,7 +416,6 @@ class RagDPRT5Test(RagTestMixin, unittest.TestCase):
|
||||
t5_config_and_inputs = generator_tester.prepare_config_and_inputs()
|
||||
|
||||
(question_encoder_config, input_ids, _, input_mask, _, _, _) = dpr_config_and_inputs
|
||||
# import ipdb; ipdb.set_trace()
|
||||
(generator_config, _, decoder_input_ids, _, decoder_attention_mask, _) = t5_config_and_inputs
|
||||
config = RagConfig.from_question_encoder_generator_configs(
|
||||
question_encoder_config,
|
||||
@@ -620,18 +619,21 @@ class RagModelIntegrationTests(unittest.TestCase):
|
||||
questions = [
|
||||
"who sings does he love me with reba",
|
||||
"how many pages is invisible man by ralph ellison",
|
||||
"what",
|
||||
]
|
||||
input_ids = rag_question_encoder_tokenizer.batch_encode_plus(
|
||||
input_dict = rag_question_encoder_tokenizer.batch_encode_plus(
|
||||
questions,
|
||||
return_tensors="pt",
|
||||
padding=True,
|
||||
truncation=True,
|
||||
).input_ids
|
||||
)
|
||||
|
||||
input_ids = input_ids.to(torch_device)
|
||||
input_ids = input_dict.input_ids.to(torch_device)
|
||||
attention_mask = input_dict.attention_mask.to(torch_device)
|
||||
|
||||
output_ids = rag_token.generate(
|
||||
input_ids,
|
||||
attention_mask=attention_mask,
|
||||
decoder_start_token_id=rag_token.generator.config.decoder_start_token_id,
|
||||
num_beams=4,
|
||||
num_return_sequences=1,
|
||||
@@ -641,13 +643,16 @@ class RagModelIntegrationTests(unittest.TestCase):
|
||||
# sequence generate test
|
||||
output_text_1 = rag_decoder_tokenizer.decode(output_ids[0], skip_special_tokens=True)
|
||||
output_text_2 = rag_decoder_tokenizer.decode(output_ids[1], skip_special_tokens=True)
|
||||
output_text_3 = rag_decoder_tokenizer.decode(output_ids[2], skip_special_tokens=True)
|
||||
|
||||
# Expected outputs as given by model at integration time.
|
||||
EXPECTED_OUTPUT_TEXT_1 = '"People Need Love" is the'
|
||||
EXPECTED_OUTPUT_TEXT_2 = '"How many pages is invisible man'
|
||||
EXPECTED_OUTPUT_TEXT_3 = "Otis the Aardvark"
|
||||
|
||||
self.assertEqual(output_text_1, EXPECTED_OUTPUT_TEXT_1)
|
||||
self.assertEqual(output_text_2, EXPECTED_OUTPUT_TEXT_2)
|
||||
self.assertEqual(output_text_3, EXPECTED_OUTPUT_TEXT_3)
|
||||
|
||||
@slow
|
||||
def test_rag_sequence_generate_batch(self):
|
||||
@@ -669,18 +674,22 @@ class RagModelIntegrationTests(unittest.TestCase):
|
||||
questions = [
|
||||
"who sings does he love me with reba",
|
||||
"how many pages is invisible man by ralph ellison",
|
||||
"what",
|
||||
]
|
||||
input_ids = rag_question_encoder_tokenizer.batch_encode_plus(
|
||||
|
||||
input_dict = rag_question_encoder_tokenizer.batch_encode_plus(
|
||||
questions,
|
||||
return_tensors="pt",
|
||||
padding=True,
|
||||
truncation=True,
|
||||
).input_ids
|
||||
)
|
||||
|
||||
input_ids = input_ids.to(torch_device)
|
||||
input_ids = input_dict.input_ids.to(torch_device)
|
||||
attention_mask = input_dict.attention_mask.to(torch_device)
|
||||
|
||||
output_ids = rag_sequence.generate(
|
||||
input_ids,
|
||||
attention_mask=attention_mask,
|
||||
decoder_start_token_id=rag_sequence.generator.config.decoder_start_token_id,
|
||||
num_beams=4,
|
||||
num_return_sequences=1,
|
||||
@@ -690,13 +699,16 @@ class RagModelIntegrationTests(unittest.TestCase):
|
||||
# sequence generate test
|
||||
output_text_1 = rag_decoder_tokenizer.decode(output_ids[0], skip_special_tokens=True)
|
||||
output_text_2 = rag_decoder_tokenizer.decode(output_ids[1], skip_special_tokens=True)
|
||||
output_text_3 = rag_decoder_tokenizer.decode(output_ids[2], skip_special_tokens=True)
|
||||
|
||||
# Expected outputs as given by model at integration time.
|
||||
EXPECTED_OUTPUT_TEXT_1 = '"I Know Him So Well"'
|
||||
EXPECTED_OUTPUT_TEXT_2 = '"Howl" chronicles the'
|
||||
EXPECTED_OUTPUT_TEXT_3 = "Otis the Aardvark"
|
||||
|
||||
self.assertEqual(output_text_1, EXPECTED_OUTPUT_TEXT_1)
|
||||
self.assertEqual(output_text_2, EXPECTED_OUTPUT_TEXT_2)
|
||||
self.assertEqual(output_text_3, EXPECTED_OUTPUT_TEXT_3)
|
||||
|
||||
@slow
|
||||
def test_rag_sequence_generate_beam(self):
|
||||
|
||||
Reference in New Issue
Block a user