From 95119ad7b0369d76925cea12ffb29d3c31014570 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Fri, 17 Dec 2021 16:08:54 +0100 Subject: [PATCH] [Generate] Correct input_ids detection (#14815) * [Generate] Correct input_ids detection * correct --- src/transformers/generation_utils.py | 2 +- tests/test_generation_utils.py | 25 +++++++++++++++++++++++++ 2 files changed, 26 insertions(+), 1 deletion(-) diff --git a/src/transformers/generation_utils.py b/src/transformers/generation_utils.py index e0f8bd1651..8d639542af 100644 --- a/src/transformers/generation_utils.py +++ b/src/transformers/generation_utils.py @@ -457,7 +457,7 @@ class GenerationMixin: pad_token_id: int, eos_token_id: int, ) -> torch.LongTensor: - is_input_ids = isinstance(inputs, torch.LongTensor) and len(inputs.shape) == 2 + is_input_ids = len(inputs.shape) == 2 and inputs.dtype in [torch.int, torch.long] is_pad_token_in_inputs = (pad_token_id is not None) and (pad_token_id in inputs) is_pad_token_not_equal_to_eos_token_id = (eos_token_id is None) or ( (eos_token_id is not None) and (pad_token_id != eos_token_id) diff --git a/tests/test_generation_utils.py b/tests/test_generation_utils.py index edd6e4533c..2a72840d2a 100644 --- a/tests/test_generation_utils.py +++ b/tests/test_generation_utils.py @@ -1719,6 +1719,31 @@ class GenerationIntegrationTests(unittest.TestCase): # make sure model generated correctly until `max_length` self.assertEqual(output_sequences.shape, (1, 5)) + def test_encoder_decoder_generate_attention_mask(self): + articles = ["Timberlake", "Jessica Biel, welcome to parenthood among other things"] + tokenizer = BartTokenizer.from_pretrained("hf-internal-testing/tiny-random-bart") + # need extrem generation values here to force this test + # to fail when `attention_mask` is not correctly treated in generate + model = BartForConditionalGeneration.from_pretrained( + "hf-internal-testing/tiny-random-bart", max_length=50, num_beams=5, num_return_sequences=5 + ).to(torch_device) + + model.config.eos_token_id = None + input_ids = tokenizer(articles[0], return_tensors="pt").input_ids.to(torch_device) + input_ids_batched = tokenizer(articles, padding=True, return_tensors="pt").input_ids.to(torch_device) + + output_sequences_batched = model.generate( + input_ids=input_ids_batched, return_dict_in_generate=True, output_scores=True + ) + output_sequences = model.generate(input_ids=input_ids, return_dict_in_generate=True, output_scores=True) + + batched_out = output_sequences_batched.sequences_scores + out = output_sequences.sequences_scores + + diff = (batched_out[:5].sum() - out.sum()).abs() + + self.assertTrue(diff < 1e-4) + def test_decoder_generate_with_inputs_embeds(self): article = """I need input_ids to generate""" tokenizer = GPT2Tokenizer.from_pretrained("hf-internal-testing/tiny-random-gpt2")