From f25a9332e8d091398ce96c462e02a467943c8eb9 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Fri, 19 Nov 2021 15:35:06 +0100 Subject: [PATCH] [Generation] Allow `inputs_embeds` as an input (#14443) * up * finalize * finalize * finish * Update src/transformers/generation_utils.py * apply feedback --- src/transformers/generation_utils.py | 32 +++++++--- tests/test_generation_utils.py | 92 ++++++++++++++++++++-------- 2 files changed, 92 insertions(+), 32 deletions(-) diff --git a/src/transformers/generation_utils.py b/src/transformers/generation_utils.py index 1844d688a6..066fab894b 100644 --- a/src/transformers/generation_utils.py +++ b/src/transformers/generation_utils.py @@ -392,15 +392,26 @@ class GenerationMixin: return torch.ones((1, 1), dtype=torch.long, device=self.device) * bos_token_id def _prepare_attention_mask_for_generation( - self, input_ids: torch.Tensor, pad_token_id: int, eos_token_id: int + self, + input_ids: torch.Tensor, + pad_token_id: int, + eos_token_id: int, + inputs_embeds: Optional[torch.Tensor] = None, ) -> torch.LongTensor: + + # First if `inputs_embeds` are given, but no `attention_mask` assume that full attention_mask is used + if inputs_embeds is not None: + return torch.ones((inputs_embeds.shape[0], inputs_embeds.shape[1]), dtype=torch.long) + + # Otherwise, use `input_ids` is_pad_token_in_inputs_ids = (pad_token_id is not None) and (pad_token_id in input_ids) is_pad_token_not_equal_to_eos_token_id = (eos_token_id is None) or ( (eos_token_id is not None) and (pad_token_id != eos_token_id) ) if is_pad_token_in_inputs_ids and is_pad_token_not_equal_to_eos_token_id: return input_ids.ne(pad_token_id).long() - return input_ids.new_ones(input_ids.shape, dtype=torch.long) + else: + return input_ids.new_ones(input_ids.shape, dtype=torch.long) def _prepare_encoder_decoder_kwargs_for_generation( self, input_ids: torch.LongTensor, model_kwargs @@ -417,12 +428,11 @@ class GenerationMixin: return model_kwargs def 
_prepare_decoder_input_ids_for_generation( - self, input_ids: torch.LongTensor, decoder_start_token_id: int = None, bos_token_id: int = None + self, batch_size: int, decoder_start_token_id: int = None, bos_token_id: int = None ) -> torch.LongTensor: decoder_start_token_id = self._get_decoder_start_token_id(decoder_start_token_id, bos_token_id) - decoder_input_ids = ( - torch.ones((input_ids.shape[0], 1), dtype=torch.long, device=input_ids.device) * decoder_start_token_id - ) + + decoder_input_ids = torch.ones((batch_size, 1), dtype=torch.long, device=self.device) * decoder_start_token_id return decoder_input_ids def _get_pad_token_id(self, pad_token_id: int = None, eos_token_id: int = None) -> int: @@ -890,8 +900,9 @@ class GenerationMixin: if model_kwargs.get("attention_mask", None) is None: # init `attention_mask` depending on `pad_token_id` + inputs_embeds = model_kwargs.get("inputs_embeds", None) model_kwargs["attention_mask"] = self._prepare_attention_mask_for_generation( - input_ids, pad_token_id, eos_token_id + input_ids, pad_token_id, eos_token_id, inputs_embeds ) # special case if pad_token_id is not defined @@ -910,12 +921,17 @@ class GenerationMixin: if "decoder_input_ids" in model_kwargs: input_ids = model_kwargs.pop("decoder_input_ids") else: + # if word embeddings are provided directly, infer the batch size from them + batch_size = input_ids.shape[0] if input_ids is not None else model_kwargs["inputs_embeds"].shape[0] input_ids = self._prepare_decoder_input_ids_for_generation( - input_ids, decoder_start_token_id=decoder_start_token_id, bos_token_id=bos_token_id + batch_size, decoder_start_token_id=decoder_start_token_id, bos_token_id=bos_token_id ) if "encoder_outputs" not in model_kwargs or not isinstance(model_kwargs["encoder_outputs"], ModelOutput): raise ValueError("Make sure that `model_kwargs` include `encoder_outputs` of type `ModelOutput`.") + else: + if "inputs_embeds" in model_kwargs and input_ids is None: + raise ValueError("For decoder-only 
generation, one must pass `input_ids`.") # if `max_new_tokens` is passed, but not `max_length` -> set `max_length = max_new_tokens` if max_length is None and max_new_tokens is not None: diff --git a/tests/test_generation_utils.py b/tests/test_generation_utils.py index caf5ccf464..6f2b3c4a3c 100644 --- a/tests/test_generation_utils.py +++ b/tests/test_generation_utils.py @@ -1388,15 +1388,17 @@ class GenerationIntegrationTests(unittest.TestCase): def test_max_length_backward_compat_greedy(self): article = """Justin Timberlake and Jessica Biel, welcome to parenthood.""" - bart_tokenizer = BartTokenizer.from_pretrained("sshleifer/bart-tiny-random") - bart_model = BartForConditionalGeneration.from_pretrained("sshleifer/bart-tiny-random").to(torch_device) + bart_tokenizer = BartTokenizer.from_pretrained("hf-internal-testing/tiny-random-bart") + bart_model = BartForConditionalGeneration.from_pretrained("hf-internal-testing/tiny-random-bart").to( + torch_device + ) input_ids = bart_tokenizer(article, return_tensors="pt").input_ids.to(torch_device) max_length = 20 input_ids = input_ids.expand(2, -1) model_kwargs = bart_model._prepare_encoder_decoder_kwargs_for_generation(input_ids, {}) input_ids = bart_model._prepare_decoder_input_ids_for_generation( - input_ids, + input_ids.shape[0], decoder_start_token_id=bart_model.config.decoder_start_token_id, bos_token_id=bart_model.config.bos_token_id, ) @@ -1412,15 +1414,17 @@ class GenerationIntegrationTests(unittest.TestCase): def test_max_length_backward_compat_sample(self): article = """Justin Timberlake and Jessica Biel, welcome to parenthood.""" - bart_tokenizer = BartTokenizer.from_pretrained("sshleifer/bart-tiny-random") - bart_model = BartForConditionalGeneration.from_pretrained("sshleifer/bart-tiny-random").to(torch_device) + bart_tokenizer = BartTokenizer.from_pretrained("hf-internal-testing/tiny-random-bart") + bart_model = BartForConditionalGeneration.from_pretrained("hf-internal-testing/tiny-random-bart").to( + 
torch_device + ) input_ids = bart_tokenizer(article, return_tensors="pt").input_ids.to(torch_device) max_length = 20 input_ids = input_ids.expand(2, -1) model_kwargs = bart_model._prepare_encoder_decoder_kwargs_for_generation(input_ids, {}) input_ids = bart_model._prepare_decoder_input_ids_for_generation( - input_ids, + input_ids.shape[0], decoder_start_token_id=bart_model.config.decoder_start_token_id, bos_token_id=bart_model.config.bos_token_id, ) @@ -1436,8 +1440,10 @@ class GenerationIntegrationTests(unittest.TestCase): def test_max_length_backward_compat_beam_search(self): article = """Justin Timberlake and Jessica Biel, welcome to parenthood.""" - bart_tokenizer = BartTokenizer.from_pretrained("sshleifer/bart-tiny-random") - bart_model = BartForConditionalGeneration.from_pretrained("sshleifer/bart-tiny-random").to(torch_device) + bart_tokenizer = BartTokenizer.from_pretrained("hf-internal-testing/tiny-random-bart") + bart_model = BartForConditionalGeneration.from_pretrained("hf-internal-testing/tiny-random-bart").to( + torch_device + ) input_ids = bart_tokenizer(article, return_tensors="pt").input_ids.to(torch_device) batch_size = 1 @@ -1447,7 +1453,7 @@ class GenerationIntegrationTests(unittest.TestCase): input_ids = input_ids.expand(2, -1) model_kwargs = bart_model._prepare_encoder_decoder_kwargs_for_generation(input_ids, {}) input_ids = bart_model._prepare_decoder_input_ids_for_generation( - input_ids, + input_ids.shape[0], decoder_start_token_id=bart_model.config.decoder_start_token_id, bos_token_id=bart_model.config.bos_token_id, ) @@ -1464,8 +1470,10 @@ class GenerationIntegrationTests(unittest.TestCase): def test_max_length_backward_compat_group_beam_search(self): article = """Justin Timberlake and Jessica Biel, welcome to parenthood.""" - bart_tokenizer = BartTokenizer.from_pretrained("sshleifer/bart-tiny-random") - bart_model = BartForConditionalGeneration.from_pretrained("sshleifer/bart-tiny-random").to(torch_device) + bart_tokenizer = 
BartTokenizer.from_pretrained("hf-internal-testing/tiny-random-bart") + bart_model = BartForConditionalGeneration.from_pretrained("hf-internal-testing/tiny-random-bart").to( + torch_device + ) input_ids = bart_tokenizer(article, return_tensors="pt").input_ids.to(torch_device) batch_size = 1 @@ -1477,7 +1485,7 @@ class GenerationIntegrationTests(unittest.TestCase): input_ids = input_ids.expand(6, -1) model_kwargs = bart_model._prepare_encoder_decoder_kwargs_for_generation(input_ids, {}) input_ids = bart_model._prepare_decoder_input_ids_for_generation( - input_ids, + input_ids.shape[0], decoder_start_token_id=bart_model.config.decoder_start_token_id, bos_token_id=bart_model.config.bos_token_id, ) @@ -1496,8 +1504,10 @@ class GenerationIntegrationTests(unittest.TestCase): def test_max_length_warning_if_different(self): article = """Justin Timberlake and Jessica Biel, welcome to parenthood.""" - bart_tokenizer = BartTokenizer.from_pretrained("sshleifer/bart-tiny-random") - bart_model = BartForConditionalGeneration.from_pretrained("sshleifer/bart-tiny-random").to(torch_device) + bart_tokenizer = BartTokenizer.from_pretrained("hf-internal-testing/tiny-random-bart") + bart_model = BartForConditionalGeneration.from_pretrained("hf-internal-testing/tiny-random-bart").to( + torch_device + ) input_ids = bart_tokenizer(article, return_tensors="pt").input_ids.to(torch_device) batch_size = 1 @@ -1513,7 +1523,7 @@ class GenerationIntegrationTests(unittest.TestCase): input_ids = input_ids.expand(6, -1) model_kwargs = bart_model._prepare_encoder_decoder_kwargs_for_generation(input_ids, {}) input_ids = bart_model._prepare_decoder_input_ids_for_generation( - input_ids, + input_ids.shape[0], decoder_start_token_id=bart_model.config.decoder_start_token_id, bos_token_id=bart_model.config.bos_token_id, ) @@ -1577,8 +1587,10 @@ class GenerationIntegrationTests(unittest.TestCase): def test_beam_search_warning_if_max_length_is_passed(self): article = """Justin Timberlake and Jessica Biel, 
welcome to parenthood.""" - bart_tokenizer = BartTokenizer.from_pretrained("sshleifer/bart-tiny-random") - bart_model = BartForConditionalGeneration.from_pretrained("sshleifer/bart-tiny-random").to(torch_device) + bart_tokenizer = BartTokenizer.from_pretrained("hf-internal-testing/tiny-random-bart") + bart_model = BartForConditionalGeneration.from_pretrained("hf-internal-testing/tiny-random-bart").to( + torch_device + ) batch_size = 1 num_beams = 3 @@ -1587,6 +1599,9 @@ class GenerationIntegrationTests(unittest.TestCase): input_ids = input_ids.expand(num_beams, -1) model_kwargs = bart_model._prepare_encoder_decoder_kwargs_for_generation(input_ids, {}) + # pretend decoder_input_ids correspond to first encoder input id + decoder_input_ids = input_ids[:, :1] + stopping_criteria_max_length = 18 stopping_criteria = StoppingCriteriaList([MaxLengthCriteria(max_length=stopping_criteria_max_length)]) @@ -1599,7 +1614,7 @@ class GenerationIntegrationTests(unittest.TestCase): ) generated_ids = bart_model.beam_search( - input_ids, + decoder_input_ids, num_beams=num_beams, stopping_criteria=stopping_criteria, beam_scorer=beam_scorer, @@ -1613,7 +1628,7 @@ class GenerationIntegrationTests(unittest.TestCase): ) generated_ids_no_max_len = bart_model.beam_search( - input_ids, + decoder_input_ids, num_beams=num_beams, stopping_criteria=stopping_criteria, beam_scorer=beam_scorer_no_max_len, @@ -1625,14 +1640,17 @@ class GenerationIntegrationTests(unittest.TestCase): def test_max_new_tokens_encoder_decoder(self): article = """Justin Timberlake and Jessica Biel, welcome to parenthood.""" - bart_tokenizer = BartTokenizer.from_pretrained("sshleifer/bart-tiny-random") - bart_model = BartForConditionalGeneration.from_pretrained("sshleifer/bart-tiny-random").to(torch_device) + bart_tokenizer = BartTokenizer.from_pretrained("hf-internal-testing/tiny-random-bart") + bart_model = BartForConditionalGeneration.from_pretrained("hf-internal-testing/tiny-random-bart").to( + torch_device + ) 
input_ids = bart_tokenizer(article, return_tensors="pt").input_ids.to(torch_device) - self.assertEqual(list(input_ids.shape), [1, 15]) + self.assertEqual(list(input_ids.shape), [1, 29]) max_new_tokens = 3 bart_model.config.max_length = 20 + bart_model.config.eos_token_id = None # Encoder decoder call outputs = bart_model.generate(input_ids, max_new_tokens=max_new_tokens) @@ -1641,8 +1659,8 @@ class GenerationIntegrationTests(unittest.TestCase): # Decoder only call outputs = bart_model.generate(decoder_input_ids=input_ids, max_new_tokens=max_new_tokens) - # 15 + 3 new tokens - self.assertEqual(list(outputs.shape), [1, 18]) + # 29 + 3 new tokens + self.assertEqual(list(outputs.shape), [1, 32]) # Encoder decoder call > 20 outputs = bart_model.generate(max_new_tokens=max_new_tokens + 20) @@ -1680,3 +1698,29 @@ class GenerationIntegrationTests(unittest.TestCase): # max_new_tokens and max_length serve the same purpose and should not be used together. with self.assertWarns(UserWarning): gpt2_model.generate(decoder_input_ids=input_ids, max_new_tokens=10, max_length=20) + + def test_encoder_decoder_generate_with_inputs_embeds(self): + article = """Justin Timberlake and Jessica Biel, welcome to parenthood.""" + tokenizer = BartTokenizer.from_pretrained("hf-internal-testing/tiny-random-bart") + model = BartForConditionalGeneration.from_pretrained("hf-internal-testing/tiny-random-bart", max_length=5).to( + torch_device + ) + model.config.eos_token_id = None + input_ids = tokenizer(article, return_tensors="pt").input_ids.to(torch_device) + inputs_embeds = model.get_input_embeddings()(input_ids) + + output_sequences = model.generate(inputs_embeds=inputs_embeds) + + # make sure model generated correctly until `max_length` + self.assertEqual(output_sequences.shape, (1, 5)) + + def test_decoder_generate_with_inputs_embeds(self): + article = """I need input_ids to generate""" + tokenizer = GPT2Tokenizer.from_pretrained("hf-internal-testing/tiny-random-gpt2") + model = 
GPT2LMHeadModel.from_pretrained("hf-internal-testing/tiny-random-gpt2", max_length=5).to(torch_device) + input_ids = tokenizer(article, return_tensors="pt").input_ids.to(torch_device) + inputs_embeds = model.get_input_embeddings()(input_ids) + + # cannot generate from `inputs_embeds` for decoder only + with self.assertRaises(ValueError): + model.generate(inputs_embeds=inputs_embeds)