From b18d8534ea62f144a4002b9e2afcb4588518e945 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Thu, 16 Dec 2021 18:03:55 +0100 Subject: [PATCH] [Generate] Make generate multi-modal (#14784) * finish refactor * refactor * add tests * add more tests * up * finish tests * finish * up * Apply suggestions from code review Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> * improve docstring * fix docs Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> --- src/transformers/generation_utils.py | 264 +++++++++++++++++---------- tests/test_generation_utils.py | 76 ++++++++ 2 files changed, 243 insertions(+), 97 deletions(-) diff --git a/src/transformers/generation_utils.py b/src/transformers/generation_utils.py index 4ba16f6be7..e0f8bd1651 100644 --- a/src/transformers/generation_utils.py +++ b/src/transformers/generation_utils.py @@ -359,12 +359,72 @@ BeamSearchOutput = Union[BeamSearchEncoderDecoderOutput, BeamSearchDecoderOnlyOu BeamSampleOutput = Union[BeamSampleEncoderDecoderOutput, BeamSampleDecoderOnlyOutput] +ENCODER_MODEL_INPUT_NAMES = ["input_ids", "inputs_embeds", "input_values", "input_features", "pixel_values"] + + class GenerationMixin: """ A class containing all of the functions supporting generation, to be used as a mixin in :class:`~transformers.PreTrainedModel`. """ + def _prepare_model_inputs( + self, + inputs: Optional[torch.Tensor] = None, + bos_token_id: Optional[int] = None, + model_kwargs: Optional[Dict[str, torch.Tensor]] = None, + ) -> Tuple[torch.Tensor, Optional[str]]: + """ + This function extracts the model-specific `inputs` for generation. + """ + # filter model input names that are `None` + model_kwargs = {k: v for k, v in model_kwargs.items() if k not in ENCODER_MODEL_INPUT_NAMES or v is not None} + # extract keyword arguments that are model input specific + model_input_kwarg_names = set(ENCODER_MODEL_INPUT_NAMES) & set(model_kwargs.keys()) + + # There are 5 possible scenarios + if inputs is not None and len(model_input_kwarg_names) == 0: + # 1. `inputs` are passed and no model-specific keyword inputs + # -> return input + model_input_name = None + return inputs, model_input_name, model_kwargs + elif inputs is not None and len(model_input_kwarg_names) > 0: + # 2. `inputs` are passed as well as model-specific keyword inputs + # -> not allowed, raise Error + raise ValueError( + f"`inputs`: {inputs}` were passed alongside " + f"{model_input_kwarg_names} which is not allowed." + f"Make sure to not pass any of {model_input_kwarg_names} " + "when `inputs` is defined." + ) + elif inputs is None and len(model_input_kwarg_names) == 0: + # 3. no `inputs` and no model-specific keyword inputs are passed + # -> try to create `input_ids` from BOS + input_tensor = self._prepare_input_ids_for_generation(bos_token_id, model_kwargs.get("encoder_outputs")) + return input_tensor, "input_ids", model_kwargs + elif inputs is None and len(model_input_kwarg_names) == 1: + # 4. no `inputs` are passed and exactly one model-specific keyword input + # -> return that model-specific keyword input tensor + model_input_name = model_input_kwarg_names.pop() + input_tensor = model_kwargs.pop(model_input_name) + + # make sure model is encoder decoder if not `input_ids` + if not self.config.is_encoder_decoder and model_input_name != "input_ids": + raise ValueError( + f"If {model_input_name} is passed as model-specific keyword " + "input then model has to be an encoder-decoder and not a " + f"{self.__class__.__name__}." + ) + return input_tensor, model_input_name, model_kwargs + else: + # 5. no `inputs` are passed and multiple model-specific keyword inputs + # -> not allowed, raise Error + raise ValueError( + f"Can only pass one of {ENCODER_MODEL_INPUT_NAMES}, " + f"but passed {model_input_kwarg_names}." + f"Make sure to only pass one of {model_input_kwarg_names}." + ) + def prepare_inputs_for_generation(self, input_ids: torch.LongTensor, **kwargs) -> Dict[str, Any]: """ Implement in subclasses of :class:`~transformers.PreTrainedModel` for custom behavior to prepare inputs in the @@ -393,47 +453,63 @@ class GenerationMixin: def _prepare_attention_mask_for_generation( self, - input_ids: torch.Tensor, + inputs: torch.Tensor, pad_token_id: int, eos_token_id: int, - inputs_embeds: Optional[torch.Tensor] = None, ) -> torch.LongTensor: - - # First if `inputs_embeds` are given, but no `attention_mask` assume that full attention_mask is used - if inputs_embeds is not None: - return torch.ones((inputs_embeds.shape[0], inputs_embeds.shape[1]), dtype=torch.long, device=self.device) - - # Otherwise, use `input_ids` - is_pad_token_in_inputs_ids = (pad_token_id is not None) and (pad_token_id in input_ids) + is_input_ids = isinstance(inputs, torch.LongTensor) and len(inputs.shape) == 2 + is_pad_token_in_inputs = (pad_token_id is not None) and (pad_token_id in inputs) is_pad_token_not_equal_to_eos_token_id = (eos_token_id is None) or ( (eos_token_id is not None) and (pad_token_id != eos_token_id) ) - if is_pad_token_in_inputs_ids and is_pad_token_not_equal_to_eos_token_id: - return input_ids.ne(pad_token_id).long() + # Check if input is input_ids and padded -> only then is attention_mask defined + if is_input_ids and is_pad_token_in_inputs and is_pad_token_not_equal_to_eos_token_id: + return inputs.ne(pad_token_id).long() else: - return input_ids.new_ones(input_ids.shape, dtype=torch.long) + return torch.ones(inputs.shape[:2], dtype=torch.long, device=self.device) def _prepare_encoder_decoder_kwargs_for_generation( - self, input_ids: torch.LongTensor, model_kwargs + self, inputs_tensor: torch.Tensor, model_kwargs, model_input_name: Optional[str] = None ) -> Dict[str, Any]: if "encoder_outputs" not in model_kwargs: - # retrieve encoder hidden states + # 1. get encoder encoder = self.get_encoder() + # 2. prepare encoder args and encoder kwargs from model kwargs + encoder_args = (inputs_tensor,) + irrelevant_prefix = ["decoder_", "cross_attn", "use_cache"] encoder_kwargs = { argument: value for argument, value in model_kwargs.items() - if not (argument.startswith("decoder_") or argument.startswith("cross_attn")) + if not any(argument.startswith(p) for p in irrelevant_prefix) } - model_kwargs["encoder_outputs"]: ModelOutput = encoder(input_ids, return_dict=True, **encoder_kwargs) + # 3. make sure that encoder returns `ModelOutput` + encoder_kwargs["return_dict"] = True + + # 4. if model_input_name is not defined then pass input_tensor as + # first input argument and remove from args + if model_input_name is not None: + # make sure inputs_tensor is None in case model + # accepts multiple model input arguments + encoder_kwargs[model_input_name] = inputs_tensor + encoder_args = () + + model_kwargs["encoder_outputs"]: ModelOutput = encoder(*encoder_args, **encoder_kwargs) + return model_kwargs def _prepare_decoder_input_ids_for_generation( - self, batch_size: int, decoder_start_token_id: int = None, bos_token_id: int = None + self, + batch_size: int, + decoder_start_token_id: int = None, + bos_token_id: int = None, + model_kwargs: Optional[Dict[str, torch.Tensor]] = None, ) -> torch.LongTensor: - decoder_start_token_id = self._get_decoder_start_token_id(decoder_start_token_id, bos_token_id) - decoder_input_ids = torch.ones((batch_size, 1), dtype=torch.long, device=self.device) * decoder_start_token_id - return decoder_input_ids + if model_kwargs is not None and "decoder_input_ids" in model_kwargs: + return model_kwargs.pop("decoder_input_ids") + else: + decoder_start_token_id = self._get_decoder_start_token_id(decoder_start_token_id, bos_token_id) + return torch.ones((batch_size, 1), dtype=torch.long, device=self.device) * decoder_start_token_id def _get_pad_token_id(self, pad_token_id: int = None, eos_token_id: int = None) -> int: if pad_token_id is None and eos_token_id is not None: @@ -649,7 +725,7 @@ class GenerationMixin: @torch.no_grad() def generate( self, - input_ids: Optional[torch.LongTensor] = None, + inputs: Optional[torch.Tensor] = None, max_length: Optional[int] = None, min_length: Optional[int] = None, do_sample: Optional[bool] = None, @@ -688,18 +764,20 @@ class GenerationMixin: Generates sequences for models with a language modeling head. The method currently supports greedy decoding, multinomial sampling, beam-search decoding, and beam-search multinomial sampling. - Apart from :obj:`input_ids` and :obj:`attention_mask`, all the arguments below will default to the value of the - attribute of the same name inside the :class:`~transformers.PretrainedConfig` of the model. The default values - indicated are the default values of those config. + Apart from :obj:`inputs`, all the arguments below will default to the value of the attribute of the same name + inside the :class:`~transformers.PretrainedConfig` of the model. The default values indicated are the default + values of those config. Most of these parameters are explained in more detail in `this blog post `__. Parameters: - input_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`): - The sequence used as a prompt for the generation. If :obj:`None` the method initializes it with - :obj:`bos_token_id` and a batch size of 1. + inputs (:obj:`torch.Tensor` of shape :obj:`(batch_size, sequence_length)`, :obj:`(batch_size, sequence_length, feature_dim)` or :obj:`(batch_size, num_channels, height, width)`, `optional`): + The sequence used as a prompt for the generation or as model inputs to the encoder. If :obj:`None` the + method initializes it with :obj:`bos_token_id` and a batch size of 1. For decoder-only models + :obj:`inputs` should of in the format of :obj:`input_ids`. For encoder-decoder models `inputs` can + represent any of :obj:`input_ids`, :obj:`input_values`, :obj:`input_features`, or :obj:`pixel_values`. max_length (:obj:`int`, `optional`, defaults to :obj:`model.config.max_length`): The maximum length of the sequence to be generated. max_new_tokens (:obj:`int`, `optional`, defaults to None): @@ -870,8 +948,11 @@ class GenerationMixin: >>> outputs = model.generate(input_ids=input_ids, max_length=20, do_sample=True, bad_words_ids=bad_words_ids) >>> print("Generated:", tokenizer.decode(outputs[0], skip_special_tokens=True)) """ - + # 1. Set generation parameters if not already defined + bos_token_id = bos_token_id if bos_token_id is not None else self.config.bos_token_id num_beams = num_beams if num_beams is not None else self.config.num_beams + length_penalty = length_penalty if length_penalty is not None else self.config.length_penalty + early_stopping = early_stopping if early_stopping is not None else self.config.early_stopping num_beam_groups = num_beam_groups if num_beam_groups is not None else self.config.num_beam_groups do_sample = do_sample if do_sample is not None else self.config.do_sample num_return_sequences = ( @@ -879,7 +960,6 @@ class GenerationMixin: ) pad_token_id = pad_token_id if pad_token_id is not None else self.config.pad_token_id - bos_token_id = bos_token_id if bos_token_id is not None else self.config.bos_token_id eos_token_id = eos_token_id if eos_token_id is not None else self.config.eos_token_id output_scores = output_scores if output_scores is not None else self.config.output_scores @@ -891,55 +971,52 @@ class GenerationMixin: return_dict_in_generate if return_dict_in_generate is not None else self.config.return_dict_in_generate ) - model_kwargs["output_attentions"] = output_attentions - model_kwargs["output_hidden_states"] = output_hidden_states - - if input_ids is None and "inputs_embeds" not in model_kwargs: - # init `input_ids` with bos_token_id - input_ids = self._prepare_input_ids_for_generation(bos_token_id, model_kwargs.get("encoder_outputs")) - - if model_kwargs.get("attention_mask", None) is None: - # init `attention_mask` depending on `pad_token_id` - inputs_embeds = model_kwargs.get("inputs_embeds", None) - model_kwargs["attention_mask"] = self._prepare_attention_mask_for_generation( - input_ids, pad_token_id, eos_token_id, inputs_embeds - ) - - # special case if pad_token_id is not defined if pad_token_id is None and eos_token_id is not None: + # special case if pad_token_id is not defined logger.warning(f"Setting `pad_token_id` to `eos_token_id`:{eos_token_id} for open-end generation.") pad_token_id = eos_token_id - # Storing encoder_input_ids for logits_processor that could use them - encoder_input_ids = input_ids if self.config.is_encoder_decoder else None + # 2. Define model inputs + # inputs_tensor has to be defined + # model_input_name is defined if model-specific keyword input is passed + # otherwise model_input_name is None + # all model-specific keyword inputs are removed from `model_kwargs` + inputs_tensor, model_input_name, model_kwargs = self._prepare_model_inputs(inputs, bos_token_id, model_kwargs) + batch_size = inputs_tensor.shape[0] + + # 3. Define other model kwargs + model_kwargs["output_attentions"] = output_attentions + model_kwargs["output_hidden_states"] = output_hidden_states + model_kwargs["use_cache"] = use_cache + + if model_kwargs.get("attention_mask", None) is None: + model_kwargs["attention_mask"] = self._prepare_attention_mask_for_generation( + inputs_tensor, pad_token_id, eos_token_id + ) if self.config.is_encoder_decoder: - # add encoder_outputs to model_kwargs - model_kwargs = self._prepare_encoder_decoder_kwargs_for_generation(input_ids, model_kwargs) + # if model is encoder decoder encoder_outputs are created + # and added to `model_kwargs` + model_kwargs = self._prepare_encoder_decoder_kwargs_for_generation( + inputs_tensor, model_kwargs, model_input_name + ) - # set input_ids as decoder_input_ids - if "decoder_input_ids" in model_kwargs: - input_ids = model_kwargs.pop("decoder_input_ids") - else: - # if word embeddings are provided directly, infere the batch size from it - batch_size = input_ids.shape[0] if input_ids is not None else model_kwargs["inputs_embeds"].shape[0] - input_ids = self._prepare_decoder_input_ids_for_generation( - batch_size, decoder_start_token_id=decoder_start_token_id, bos_token_id=bos_token_id - ) - - if "encoder_outputs" not in model_kwargs or not isinstance(model_kwargs["encoder_outputs"], ModelOutput): - raise ValueError("Make sure that `model_kwargs` include `encoder_outputs` of type `ModelOutput`.") + # 4. Prepare `input_ids` which will be used for auto-regressive generation + if self.config.is_encoder_decoder: + input_ids = self._prepare_decoder_input_ids_for_generation( + batch_size, + decoder_start_token_id=decoder_start_token_id, + bos_token_id=bos_token_id, + model_kwargs=model_kwargs, + ) else: - if "inputs_embeds" in model_kwargs and input_ids is None: - raise ValueError("For decoder-only generation, one must pass `input_ids`.") + # if decoder-only then inputs_tensor has to be `input_ids` + input_ids = inputs_tensor + # 5. Prepare `max_length` depending on other stopping criteria # if `max_new_tokens` is passed, but not `max_length` -> set `max_length = max_new_tokens` if max_length is None and max_new_tokens is not None: - max_length = ( - max_new_tokens + input_ids.shape[-1] - if input_ids is not None - else max_length + model_kwargs["inputs_embeds"].shape[1] - ) + max_length = max_new_tokens + input_ids.shape[-1] elif max_length is not None and max_new_tokens is not None: # Both are set, this is odd, raise a warning warnings.warn( @@ -948,7 +1025,6 @@ class GenerationMixin: f"will take priority over `max_new_tokens` {max_new_tokens}.", UserWarning, ) - # default to config if still None max_length = max_length if max_length is not None else self.config.max_length @@ -959,12 +1035,13 @@ class GenerationMixin: "This can lead to unexpected behavior. You should consider increasing ``config.max_length`` or ``max_length``." ) - # determine generation mode + # 6. determine generation mode is_greedy_gen_mode = (num_beams == 1) and (num_beam_groups == 1) and do_sample is False is_sample_gen_mode = (num_beams == 1) and (num_beam_groups == 1) and do_sample is True is_beam_gen_mode = (num_beams > 1) and (num_beam_groups == 1) and do_sample is False is_beam_sample_gen_mode = (num_beams > 1) and (num_beam_groups == 1) and do_sample is True is_group_beam_gen_mode = (num_beams > 1) and (num_beam_groups > 1) + if num_beam_groups > num_beams: raise ValueError("`num_beam_groups` has to be smaller or equal to `num_beams`") if is_group_beam_gen_mode and do_sample is True: @@ -972,15 +1049,12 @@ class GenerationMixin: "Diverse beam search cannot be used in sampling mode. Make sure that `do_sample` is set to `False`." ) - # set model_kwargs - model_kwargs["use_cache"] = use_cache - - # get distribution pre_processing samplers + # 7. prepare distribution pre_processing samplers logits_processor = self._get_logits_processor( repetition_penalty=repetition_penalty, no_repeat_ngram_size=no_repeat_ngram_size, encoder_no_repeat_ngram_size=encoder_no_repeat_ngram_size, - encoder_input_ids=encoder_input_ids, + encoder_input_ids=inputs_tensor, bad_words_ids=bad_words_ids, min_length=min_length, max_length=max_length, @@ -994,15 +1068,17 @@ class GenerationMixin: remove_invalid_values=remove_invalid_values, ) + # 8. prepare stopping criteria stopping_criteria = self._get_stopping_criteria(max_length=max_length, max_time=max_time) + # 9. go into different generation modes if is_greedy_gen_mode: if num_return_sequences > 1: raise ValueError( f"num_return_sequences has to be 1, but is {num_return_sequences} when doing greedy search." ) - # greedy search + # 10. run greedy search return self.greedy_search( input_ids, logits_processor=logits_processor, @@ -1016,12 +1092,12 @@ class GenerationMixin: ) elif is_sample_gen_mode: - # get probability distribution warper + # 10. prepare logits warper logits_warper = self._get_logits_warper( top_k=top_k, top_p=top_p, temperature=temperature, num_beams=num_beams ) - # expand input_ids with `num_return_sequences` additional sequences per batch + # 11. expand input_ids with `num_return_sequences` additional sequences per batch input_ids, model_kwargs = self._expand_inputs_for_generation( input_ids, expand_size=num_return_sequences, @@ -1029,7 +1105,7 @@ class GenerationMixin: **model_kwargs, ) - # sample + # 12. run sample return self.sample( input_ids, logits_processor=logits_processor, @@ -1044,17 +1120,13 @@ class GenerationMixin: ) elif is_beam_gen_mode: - batch_size = input_ids.shape[0] - - length_penalty = length_penalty if length_penalty is not None else self.config.length_penalty - early_stopping = early_stopping if early_stopping is not None else self.config.early_stopping - if num_return_sequences > num_beams: raise ValueError("`num_return_sequences` has to be smaller or equal to `num_beams`.") if stopping_criteria.max_length is None: raise ValueError("`max_length` needs to be a stopping_criteria for now.") + # 10. prepare beam search scorer beam_scorer = BeamSearchScorer( batch_size=batch_size, num_beams=num_beams, @@ -1063,10 +1135,11 @@ class GenerationMixin: do_early_stopping=early_stopping, num_beam_hyps_to_keep=num_return_sequences, ) - # interleave with `num_beams` + # 11. interleave input_ids with `num_beams` additional sequences per batch input_ids, model_kwargs = self._expand_inputs_for_generation( input_ids, expand_size=num_beams, is_encoder_decoder=self.config.is_encoder_decoder, **model_kwargs ) + # 12. run beam search return self.beam_search( input_ids, beam_scorer, @@ -1081,24 +1154,23 @@ class GenerationMixin: ) elif is_beam_sample_gen_mode: + # 10. prepare logits warper logits_warper = self._get_logits_warper( top_k=top_k, top_p=top_p, temperature=temperature, num_beams=num_beams ) - batch_size = input_ids.shape[0] * num_return_sequences - - length_penalty = length_penalty if length_penalty is not None else self.config.length_penalty if stopping_criteria.max_length is None: raise ValueError("`max_length` needs to be a stopping_criteria for now.") + # 11. prepare beam search scorer beam_scorer = BeamSearchScorer( - batch_size=batch_size, + batch_size=batch_size * num_return_sequences, num_beams=num_beams, device=self.device, length_penalty=length_penalty, do_early_stopping=early_stopping, ) - # interleave with `num_beams * num_return_sequences` + # 12. interleave input_ids with `num_beams` additional sequences per batch input_ids, model_kwargs = self._expand_inputs_for_generation( input_ids, expand_size=num_beams * num_return_sequences, @@ -1106,6 +1178,7 @@ class GenerationMixin: **model_kwargs, ) + # 13. run beam sample return self.beam_sample( input_ids, beam_scorer, @@ -1121,11 +1194,6 @@ class GenerationMixin: ) elif is_group_beam_gen_mode: - batch_size = input_ids.shape[0] - - length_penalty = length_penalty if length_penalty is not None else self.config.length_penalty - early_stopping = early_stopping if early_stopping is not None else self.config.early_stopping - if num_return_sequences > num_beams: raise ValueError("`num_return_sequences` has to be smaller or equal to `num_beams`.") @@ -1135,7 +1203,8 @@ class GenerationMixin: if stopping_criteria.max_length is None: raise ValueError("`max_length` needs to be a stopping_criteria for now.") - diverse_beam_scorer = BeamSearchScorer( + # 10. prepare beam search scorer + beam_scorer = BeamSearchScorer( batch_size=batch_size, num_beams=num_beams, max_length=stopping_criteria.max_length, @@ -1145,13 +1214,14 @@ class GenerationMixin: num_beam_hyps_to_keep=num_return_sequences, num_beam_groups=num_beam_groups, ) - # interleave with `num_beams` + # 11. interleave input_ids with `num_beams` additional sequences per batch input_ids, model_kwargs = self._expand_inputs_for_generation( input_ids, expand_size=num_beams, is_encoder_decoder=self.config.is_encoder_decoder, **model_kwargs ) + # 12. run beam search return self.group_beam_search( input_ids, - diverse_beam_scorer, + beam_scorer, logits_processor=logits_processor, stopping_criteria=stopping_criteria, pad_token_id=pad_token_id, diff --git a/tests/test_generation_utils.py b/tests/test_generation_utils.py index 6f2b3c4a3c..edd6e4533c 100644 --- a/tests/test_generation_utils.py +++ b/tests/test_generation_utils.py @@ -20,6 +20,8 @@ import unittest from transformers import is_torch_available from transformers.testing_utils import require_torch, slow, torch_device +from .test_modeling_common import floats_tensor + if is_torch_available(): import torch @@ -29,6 +31,9 @@ if is_torch_available(): BartTokenizer, GPT2LMHeadModel, GPT2Tokenizer, + Speech2TextForConditionalGeneration, + SpeechEncoderDecoderModel, + VisionEncoderDecoderModel, top_k_top_p_filtering, ) from transformers.generation_beam_search import BeamSearchScorer @@ -1724,3 +1729,74 @@ class GenerationIntegrationTests(unittest.TestCase): # cannot generate from `inputs_embeds` for decoder only with self.assertRaises(ValueError): model.generate(inputs_embeds=inputs_embeds) + + def test_generate_input_ids_as_kwarg(self): + article = """I need input_ids to generate""" + tokenizer = GPT2Tokenizer.from_pretrained("hf-internal-testing/tiny-random-gpt2") + model = GPT2LMHeadModel.from_pretrained("hf-internal-testing/tiny-random-gpt2", max_length=15).to(torch_device) + input_ids = tokenizer(article, return_tensors="pt").input_ids.to(torch_device) + output_sequences_kwargs = model.generate(input_ids=input_ids).cpu() + output_sequences = model.generate(input_ids).cpu() + + self.assertListEqual(output_sequences.tolist(), output_sequences_kwargs.tolist()) + self.assertEqual(output_sequences.shape, (1, 15)) + + def test_generate_input_ids_as_encoder_kwarg(self): + article = """Justin Timberlake and Jessica Biel, welcome to parenthood.""" + tokenizer = BartTokenizer.from_pretrained("hf-internal-testing/tiny-random-bart") + model = BartForConditionalGeneration.from_pretrained("hf-internal-testing/tiny-random-bart", max_length=5).to( + torch_device + ) + model.config.eos_token_id = None + input_ids = tokenizer(article, return_tensors="pt").input_ids.to(torch_device) + output_sequences_kwargs = model.generate(input_ids=input_ids).cpu() + output_sequences = model.generate(input_ids).cpu() + + self.assertListEqual(output_sequences.tolist(), output_sequences_kwargs.tolist()) + self.assertEqual(output_sequences.shape, (1, 5)) + + def test_generate_inputs_and_encoder_kwargs(self): + article = """I need input_ids to generate""" + tokenizer = GPT2Tokenizer.from_pretrained("hf-internal-testing/tiny-random-gpt2") + model = GPT2LMHeadModel.from_pretrained("hf-internal-testing/tiny-random-gpt2", max_length=10).to(torch_device) + input_ids = tokenizer(article, return_tensors="pt").input_ids.to(torch_device) + with self.assertRaises(ValueError): + model.generate(input_ids, input_ids=input_ids) + + def test_generate_too_many_encoder_kwargs(self): + article = """I need input_ids to generate""" + tokenizer = GPT2Tokenizer.from_pretrained("hf-internal-testing/tiny-random-gpt2") + model = GPT2LMHeadModel.from_pretrained("hf-internal-testing/tiny-random-gpt2", max_length=10).to(torch_device) + input_ids = tokenizer(article, return_tensors="pt").input_ids.to(torch_device) + with self.assertRaises(ValueError): + model.generate(input_ids=input_ids, input_values=input_ids) + + def test_generate_input_values_as_encoder_kwarg(self): + input_values = floats_tensor((2, 250)) + model = SpeechEncoderDecoderModel.from_pretrained("hf-internal-testing/tiny-random-speech-encoder-decoder") + model = model.to(torch_device) + output_sequences_kwargs = model.generate(input_values=input_values, max_length=5).cpu() + output_sequences = model.generate(input_values, max_length=5).cpu() + + self.assertListEqual(output_sequences.tolist(), output_sequences_kwargs.tolist()) + self.assertEqual(output_sequences.shape, (2, 5)) + + def test_generate_input_features_as_encoder_kwarg(self): + input_features = floats_tensor((3, 20, 24)) + model = Speech2TextForConditionalGeneration.from_pretrained("hf-internal-testing/tiny-random-speech_to_text") + model = model.to(torch_device) + output_sequences_kwargs = model.generate(input_features=input_features, max_length=5).cpu() + output_sequences = model.generate(input_features, max_length=5).cpu() + + self.assertListEqual(output_sequences.tolist(), output_sequences_kwargs.tolist()) + self.assertEqual(output_sequences.shape, (3, 5)) + + def test_generate_pixel_values_as_encoder_kwarg(self): + pixel_values = floats_tensor((2, 3, 30, 30)) + model = VisionEncoderDecoderModel.from_pretrained("hf-internal-testing/tiny-random-vision-encoder-decoder") + model = model.to(torch_device) + output_sequences_kwargs = model.generate(pixel_values=pixel_values, max_length=5).cpu() + output_sequences = model.generate(pixel_values, max_length=5).cpu() + + self.assertListEqual(output_sequences.tolist(), output_sequences_kwargs.tolist()) + self.assertEqual(output_sequences.shape, (2, 5))