correct attention mask (#7373)

This commit is contained in:
Patrick von Platen
2020-09-24 23:22:04 +02:00
committed by GitHub
parent a8cbc4269c
commit 0804d077c6
3 changed files with 30 additions and 12 deletions

View File

@@ -416,7 +416,6 @@ class RagDPRT5Test(RagTestMixin, unittest.TestCase):
t5_config_and_inputs = generator_tester.prepare_config_and_inputs()
(question_encoder_config, input_ids, _, input_mask, _, _, _) = dpr_config_and_inputs
# import ipdb; ipdb.set_trace()
(generator_config, _, decoder_input_ids, _, decoder_attention_mask, _) = t5_config_and_inputs
config = RagConfig.from_question_encoder_generator_configs(
question_encoder_config,
@@ -620,18 +619,21 @@ class RagModelIntegrationTests(unittest.TestCase):
questions = [
"who sings does he love me with reba",
"how many pages is invisible man by ralph ellison",
"what",
]
input_ids = rag_question_encoder_tokenizer.batch_encode_plus(
input_dict = rag_question_encoder_tokenizer.batch_encode_plus(
questions,
return_tensors="pt",
padding=True,
truncation=True,
).input_ids
)
input_ids = input_ids.to(torch_device)
input_ids = input_dict.input_ids.to(torch_device)
attention_mask = input_dict.attention_mask.to(torch_device)
output_ids = rag_token.generate(
input_ids,
attention_mask=attention_mask,
decoder_start_token_id=rag_token.generator.config.decoder_start_token_id,
num_beams=4,
num_return_sequences=1,
@@ -641,13 +643,16 @@ class RagModelIntegrationTests(unittest.TestCase):
# sequence generate test
output_text_1 = rag_decoder_tokenizer.decode(output_ids[0], skip_special_tokens=True)
output_text_2 = rag_decoder_tokenizer.decode(output_ids[1], skip_special_tokens=True)
output_text_3 = rag_decoder_tokenizer.decode(output_ids[2], skip_special_tokens=True)
# Expected outputs as given by model at integration time.
EXPECTED_OUTPUT_TEXT_1 = '"People Need Love" is the'
EXPECTED_OUTPUT_TEXT_2 = '"How many pages is invisible man'
EXPECTED_OUTPUT_TEXT_3 = "Otis the Aardvark"
self.assertEqual(output_text_1, EXPECTED_OUTPUT_TEXT_1)
self.assertEqual(output_text_2, EXPECTED_OUTPUT_TEXT_2)
self.assertEqual(output_text_3, EXPECTED_OUTPUT_TEXT_3)
@slow
def test_rag_sequence_generate_batch(self):
@@ -669,18 +674,22 @@ class RagModelIntegrationTests(unittest.TestCase):
questions = [
"who sings does he love me with reba",
"how many pages is invisible man by ralph ellison",
"what",
]
input_ids = rag_question_encoder_tokenizer.batch_encode_plus(
input_dict = rag_question_encoder_tokenizer.batch_encode_plus(
questions,
return_tensors="pt",
padding=True,
truncation=True,
).input_ids
)
input_ids = input_ids.to(torch_device)
input_ids = input_dict.input_ids.to(torch_device)
attention_mask = input_dict.attention_mask.to(torch_device)
output_ids = rag_sequence.generate(
input_ids,
attention_mask=attention_mask,
decoder_start_token_id=rag_sequence.generator.config.decoder_start_token_id,
num_beams=4,
num_return_sequences=1,
@@ -690,13 +699,16 @@ class RagModelIntegrationTests(unittest.TestCase):
# sequence generate test
output_text_1 = rag_decoder_tokenizer.decode(output_ids[0], skip_special_tokens=True)
output_text_2 = rag_decoder_tokenizer.decode(output_ids[1], skip_special_tokens=True)
output_text_3 = rag_decoder_tokenizer.decode(output_ids[2], skip_special_tokens=True)
# Expected outputs as given by model at integration time.
EXPECTED_OUTPUT_TEXT_1 = '"I Know Him So Well"'
EXPECTED_OUTPUT_TEXT_2 = '"Howl" chronicles the'
EXPECTED_OUTPUT_TEXT_3 = "Otis the Aardvark"
self.assertEqual(output_text_1, EXPECTED_OUTPUT_TEXT_1)
self.assertEqual(output_text_2, EXPECTED_OUTPUT_TEXT_2)
self.assertEqual(output_text_3, EXPECTED_OUTPUT_TEXT_3)
@slow
def test_rag_sequence_generate_beam(self):