diff --git a/docs/source/internal/generation_utils.mdx b/docs/source/internal/generation_utils.mdx index 9eb4abe06d..5ee321c0a4 100644 --- a/docs/source/internal/generation_utils.mdx +++ b/docs/source/internal/generation_utils.mdx @@ -148,6 +148,24 @@ generation. [[autodoc]] InfNanRemoveLogitsProcessor - __call__ +[[autodoc]] TFLogitsProcessor + - __call__ + +[[autodoc]] TFLogitsProcessorList + - __call__ + +[[autodoc]] TFMinLengthLogitsProcessor + - __call__ + +[[autodoc]] TFNoBadWordsLogitsProcessor + - __call__ + +[[autodoc]] TFNoRepeatNGramLogitsProcessor + - __call__ + +[[autodoc]] TFRepetitionPenaltyLogitsProcessor + - __call__ + [[autodoc]] FlaxLogitsProcessor - __call__ diff --git a/src/transformers/__init__.py b/src/transformers/__init__.py index f4b0e2908b..ad05486104 100755 --- a/src/transformers/__init__.py +++ b/src/transformers/__init__.py @@ -1592,6 +1592,14 @@ if is_tf_available(): _import_structure["activations_tf"] = [] _import_structure["benchmark.benchmark_args_tf"] = ["TensorFlowBenchmarkArguments"] _import_structure["benchmark.benchmark_tf"] = ["TensorFlowBenchmark"] + _import_structure["generation_tf_logits_process"] = [ + "TFLogitsProcessor", + "TFLogitsProcessorList", + "TFMinLengthLogitsProcessor", + "TFNoBadWordsLogitsProcessor", + "TFNoRepeatNGramLogitsProcessor", + "TFRepetitionPenaltyLogitsProcessor", + ] _import_structure["generation_tf_utils"] = ["tf_top_k_top_p_filtering"] _import_structure["keras_callbacks"] = ["KerasMetricCallback", "PushToHubCallback"] _import_structure["modeling_tf_outputs"] = [] @@ -2046,6 +2054,7 @@ if is_tf_available(): ] ) _import_structure["optimization_tf"] = ["AdamWeightDecay", "GradientAccumulator", "WarmUp", "create_optimizer"] + _import_structure["tf_utils"] = [] _import_structure["trainer_tf"] = ["TFTrainer"] else: @@ -3572,6 +3581,14 @@ if TYPE_CHECKING: # Benchmarks from .benchmark.benchmark_tf import TensorFlowBenchmark + from .generation_tf_logits_process import ( + TFLogitsProcessor, + 
TFLogitsProcessorList, + TFMinLengthLogitsProcessor, + TFNoBadWordsLogitsProcessor, + TFNoRepeatNGramLogitsProcessor, + TFRepetitionPenaltyLogitsProcessor, + ) from .generation_tf_utils import tf_top_k_top_p_filtering from .keras_callbacks import KerasMetricCallback, PushToHubCallback from .modeling_tf_layoutlm import ( diff --git a/src/transformers/generation_flax_logits_process.py b/src/transformers/generation_flax_logits_process.py index 1d66953413..76a09ed012 100644 --- a/src/transformers/generation_flax_logits_process.py +++ b/src/transformers/generation_flax_logits_process.py @@ -14,7 +14,6 @@ # limitations under the License. import inspect -from abc import ABC import jax import jax.lax as lax @@ -48,7 +47,7 @@ LOGITS_PROCESSOR_INPUTS_DOCSTRING = r""" """ -class FlaxLogitsProcessor(ABC): +class FlaxLogitsProcessor: """Abstract base class for all logit processors that can be applied during generation.""" @add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING) @@ -59,7 +58,7 @@ class FlaxLogitsProcessor(ABC): ) -class FlaxLogitsWarper(ABC): +class FlaxLogitsWarper: """Abstract base class for all logit warpers that can be applied during generation with multinomial sampling.""" @add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING) diff --git a/src/transformers/generation_logits_process.py b/src/transformers/generation_logits_process.py index 10fbd24ece..18f8c5971f 100644 --- a/src/transformers/generation_logits_process.py +++ b/src/transformers/generation_logits_process.py @@ -15,7 +15,6 @@ import inspect import math -from abc import ABC from typing import Callable, Iterable, List, Optional import numpy as np @@ -49,7 +48,7 @@ LOGITS_PROCESSOR_INPUTS_DOCSTRING = r""" """ -class LogitsProcessor(ABC): +class LogitsProcessor: """Abstract base class for all logit processors that can be applied during generation.""" @add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING) @@ -60,7 +59,7 @@ class LogitsProcessor(ABC): ) -class LogitsWarper(ABC): +class 
LogitsWarper: """Abstract base class for all logit warpers that can be applied during generation with multinomial sampling.""" @add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING) diff --git a/src/transformers/generation_tf_logits_process.py b/src/transformers/generation_tf_logits_process.py new file mode 100644 index 0000000000..74a6176856 --- /dev/null +++ b/src/transformers/generation_tf_logits_process.py @@ -0,0 +1,295 @@ +# coding=utf-8 +# Copyright 2022 The HuggingFace Inc. team +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +from typing import List + +import numpy as np +import tensorflow as tf + +from .file_utils import add_start_docstrings +from .tf_utils import set_tensor_by_indices_to_value +from .utils.logging import get_logger + + +logger = get_logger(__name__) + + +TF_LOGITS_PROCESSOR_INPUTS_DOCSTRING = r""" + Args: + input_ids (`tf.Tensor` of shape `(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. + + Indices can be obtained using [`PreTrainedTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + [What are input IDs?](../glossary#input-ids) + scores (`tf.Tensor` of shape `(batch_size, config.vocab_size)`): + Prediction scores of a language modeling head. These can be logits for each vocabulary when not using beam + search or log softmax for each vocabulary token when using beam search + kwargs: + Additional logits processor specific kwargs. 
+ + Return: + `tf.Tensor` of shape `(batch_size, config.vocab_size)`: The processed prediction scores. +""" + + +class TFLogitsProcessor: + """Abstract base class for all logit processors that can be applied during generation.""" + + @add_start_docstrings(TF_LOGITS_PROCESSOR_INPUTS_DOCSTRING) + def __call__(self, input_ids: tf.Tensor, scores: tf.Tensor) -> tf.Tensor: + """TF method for processing logits.""" + raise NotImplementedError( + f"{self.__class__} is an abstract class. Only classes inheriting this class can be called." + ) + + +class TFLogitsProcessorList(list): + """ + This class can be used to create a list of [`TFLogitsProcessor`] to subsequently process a `scores` input tensor. + This class inherits from list and adds a specific *__call__* method to apply each [`TFLogitsProcessor`] to the + inputs. + """ + + @add_start_docstrings(TF_LOGITS_PROCESSOR_INPUTS_DOCSTRING) + def __call__(self, input_ids: tf.Tensor, scores: tf.Tensor, **kwargs) -> tf.Tensor: + for processor in self: + function_args = inspect.signature(processor.__call__).parameters + if len(function_args) > 2: + if not all(arg in kwargs for arg in list(function_args.keys())[2:]): + raise ValueError( + f"Make sure that all the required parameters: {list(function_args.keys())} for " + f"{processor.__class__} are passed to the logits processor." + ) + scores = processor(input_ids, scores, **kwargs) + else: + scores = processor(input_ids, scores) + return scores + + +class TFMinLengthLogitsProcessor(TFLogitsProcessor): + r""" + [`TFLogitsProcessor`] enforcing a min-length by setting EOS probability to 0. + + Args: + min_length (`int`): + The minimum length below which the score of `eos_token_id` is set to `-float("Inf")`. + eos_token_id (`int`): + The id of the *end-of-sequence* token. 
+ """ + + def __init__(self, min_length: int, eos_token_id: int): + if not isinstance(min_length, int) or min_length < 0: + raise ValueError(f"`min_length` has to be a positive integer, but is {min_length}") + + if not isinstance(eos_token_id, int) or eos_token_id < 0: + raise ValueError(f"`eos_token_id` has to be a positive integer, but is {eos_token_id}") + + self.min_length = min_length + self.eos_token_id = eos_token_id + + def __call__(self, input_ids: tf.Tensor, scores: tf.Tensor) -> tf.Tensor: + # create boolean flag to decide if min length penalty should be applied + cur_len = input_ids.shape[-1] + apply_penalty = 1 - tf.clip_by_value(cur_len - self.min_length, 0, 1) + + # TODO(Matt) - this if statement has to be rewritten for XLA. Leaving it now though since + # generate is not XLA - compileable anyways + if apply_penalty: + eos_token_id_mask = tf.broadcast_to(tf.range(scores.shape[-1]) == self.eos_token_id, scores.shape) + scores = set_tensor_by_indices_to_value(scores, eos_token_id_mask, float("-inf")) + + return scores + + +class TFRepetitionPenaltyLogitsProcessor(TFLogitsProcessor): + r""" + [`TFLogitsProcessor`] enforcing an exponential penalty on repeated sequences. + + Args: + repetition_penalty (`float`): + The parameter for repetition penalty. 1.0 means no penalty. See [this + paper](https://arxiv.org/pdf/1909.05858.pdf) for more details. 
+ """ + + def __init__(self, penalty: float): + if not isinstance(penalty, float) or not (penalty > 0): + raise ValueError(f"`penalty` has to be a strictly positive float, but is {penalty}") + + self.penalty = penalty + + def _create_score_penalties(self, input_ids, logits): + # create logit penalties for already seen input_ids + token_penalties = np.ones(logits.shape) + prev_input_ids = [np.unique(input_id) for input_id in input_ids.numpy()] + for i, prev_input_id in enumerate(prev_input_ids): + logit_penalized = logits[i].numpy()[prev_input_id] + logit_penalties = np.zeros(logit_penalized.shape) + # if previous logit score is < 0 then multiply repetition penalty else divide + logit_penalties[logit_penalized < 0] = self.penalty + logit_penalties[logit_penalized > 0] = 1 / self.penalty + np.put(token_penalties[i], prev_input_id, logit_penalties) + return tf.convert_to_tensor(token_penalties, dtype=tf.float32) + + def __call__(self, input_ids: tf.Tensor, scores: tf.Tensor) -> tf.Tensor: + + score_penalties = self._create_score_penalties(input_ids, scores) + + scores = tf.math.multiply(scores, score_penalties) + + return scores + + +class TFNoBadWordsLogitsProcessor(TFLogitsProcessor): + """ + [`TFLogitsProcessor`] that enforces that specified sequences will never be sampled. + + Args: + bad_words_ids (`List[List[int]]`): + List of list of token ids that are not allowed to be generated. In order to get the tokens of the words + that should not appear in the generated text, use `tokenizer(bad_word, add_prefix_space=True).input_ids`. + eos_token_id (`int`): + The id of the *end-of-sequence* token. 
+ """ + + def __init__(self, bad_words_ids: List[List[int]], eos_token_id: int): + + if not isinstance(bad_words_ids, List) or len(bad_words_ids) == 0: + raise ValueError(f"`bad_words_ids` has to be a non-emtpy list, but is {bad_words_ids}.") + if any(not isinstance(bad_word_ids, list) for bad_word_ids in bad_words_ids): + raise ValueError(f"`bad_words_ids` has to be a list of lists, but is {bad_words_ids}.") + if any( + any((not isinstance(token_id, (int, np.integer)) or token_id < 0) for token_id in bad_word_ids) + for bad_word_ids in bad_words_ids + ): + raise ValueError( + f"Each list in `bad_words_ids` has to be a list of positive integers, but is {bad_words_ids}." + ) + + self.bad_words_ids = bad_words_ids + + def calc_banned_bad_words_ids(self, prev_input_ids): + banned_tokens = [] + + def _tokens_match(prev_tokens, tokens): + if len(tokens) == 0: + # if bad word tokens is just one token always ban it + return True + if len(tokens) > len(prev_tokens): + # if bad word tokens are longer than prev tokens they can't be equal + return False + + if prev_tokens[-len(tokens) :] == tokens: + # if tokens match + return True + else: + return False + + for prev_input_ids_slice in prev_input_ids: + banned_tokens_slice = [] + + for banned_token_seq in self.bad_words_ids: + assert ( + len(banned_token_seq) > 0 + ), f"Banned words token sequences {self.bad_words_ids} cannot have an empty list" + + if _tokens_match(prev_input_ids_slice.numpy().tolist(), banned_token_seq[:-1]) is False: + # if tokens do not match continue + continue + + banned_tokens_slice.append(banned_token_seq[-1]) + + banned_tokens.append(banned_tokens_slice) + + return banned_tokens + + def __call__(self, input_ids: tf.Tensor, scores: tf.Tensor) -> tf.Tensor: + + vocab_size = scores.shape[-1] + + # calculate a list of banned tokens according to bad words + banned_tokens = self.calc_banned_bad_words_ids(input_ids) + + banned_tokens_indices_mask = [] + for banned_tokens_slice in banned_tokens: + 
banned_tokens_indices_mask.append( + [True if token in banned_tokens_slice else False for token in range(vocab_size)] + ) + + scores = set_tensor_by_indices_to_value( + scores, tf.convert_to_tensor(banned_tokens_indices_mask, dtype=tf.bool), -float("inf") + ) + + return scores + + +class TFNoRepeatNGramLogitsProcessor(TFLogitsProcessor): + r""" + [`TFLogitsProcessor`] that enforces no repetition of n-grams. See + [Fairseq](https://github.com/pytorch/fairseq/blob/a07cb6f40480928c9e0548b737aadd36ee66ac76/fairseq/sequence_generator.py#L345). + + Args: + ngram_size (`int`): + All ngrams of size `ngram_size` can only occur once. + """ + + def __init__(self, ngram_size: int): + if not isinstance(ngram_size, int) or ngram_size <= 0: + raise ValueError(f"`ngram_size` has to be a strictly positive integer, but is {ngram_size}") + self.ngram_size = ngram_size + + def calc_banned_ngram_tokens(self, prev_input_ids, num_hypos, cur_len): + # Copied from fairseq for no_repeat_ngram in beam_search + if cur_len + 1 < self.ngram_size: + # return no banned tokens if we haven't generated ngram_size tokens yet + return [[] for _ in range(num_hypos)] + generated_ngrams = [{} for _ in range(num_hypos)] + for idx in range(num_hypos): + gen_tokens = prev_input_ids[idx].numpy().tolist() + generated_ngram = generated_ngrams[idx] + for ngram in zip(*[gen_tokens[i:] for i in range(self.ngram_size)]): + prev_ngram_tuple = tuple(ngram[:-1]) + generated_ngram[prev_ngram_tuple] = generated_ngram.get(prev_ngram_tuple, []) + [ngram[-1]] + + def _get_generated_ngrams(hypo_idx): + # Before decoding the next token, prevent decoding of ngrams that have already appeared + start_idx = cur_len + 1 - self.ngram_size + ngram_idx = tuple(prev_input_ids[hypo_idx, start_idx:cur_len].numpy().tolist()) + return generated_ngrams[hypo_idx].get(ngram_idx, []) + + banned_tokens = [_get_generated_ngrams(hypo_idx) for hypo_idx in range(num_hypos)] + + return banned_tokens + + def __call__(self, input_ids: tf.Tensor, 
scores: tf.Tensor) -> tf.Tensor: + + batch_size, vocab_size = scores.shape + cur_len = input_ids.shape[-1] + banned_tokens = self.calc_banned_ngram_tokens(input_ids, batch_size, cur_len) + + # create banned_tokens boolean mask + banned_tokens_indices_mask = [] + for banned_tokens_slice in banned_tokens: + banned_tokens_indices_mask.append( + [True if token in banned_tokens_slice else False for token in range(vocab_size)] + ) + + scores = set_tensor_by_indices_to_value( + scores, tf.convert_to_tensor(banned_tokens_indices_mask, dtype=tf.bool), -float("inf") + ) + + return scores diff --git a/src/transformers/generation_tf_utils.py b/src/transformers/generation_tf_utils.py index d10b5817a4..b8d4746fe2 100644 --- a/src/transformers/generation_tf_utils.py +++ b/src/transformers/generation_tf_utils.py @@ -16,12 +16,20 @@ import inspect from dataclasses import dataclass -from typing import Optional, Tuple, Union +from typing import Any, Dict, List, Optional, Tuple, Union import numpy as np import tensorflow as tf from .file_utils import ModelOutput +from .generation_tf_logits_process import ( + TFLogitsProcessorList, + TFMinLengthLogitsProcessor, + TFNoBadWordsLogitsProcessor, + TFNoRepeatNGramLogitsProcessor, + TFRepetitionPenaltyLogitsProcessor, +) +from .tf_utils import set_tensor_by_indices_to_value, shape_list from .utils import logging @@ -476,18 +484,18 @@ class TFGenerationMixin: If the model is *not* an encoder-decoder model (`model.config.is_encoder_decoder=False`), the possible [`~file_utils.ModelOutput`] types are: - - [`~generation_utils.TFGreedySearchDecoderOnlyOutput`], - - [`~generation_utils.TFSampleDecoderOnlyOutput`], - - [`~generation_utils.TFBeamSearchDecoderOnlyOutput`], - - [`~generation_utils.TFBeamSampleDecoderOnlyOutput`] + - [`~generation_tf_utils.TFGreedySearchDecoderOnlyOutput`], + - [`~generation_tf_utils.TFSampleDecoderOnlyOutput`], + - [`~generation_tf_utils.TFBeamSearchDecoderOnlyOutput`], + - 
[`~generation_tf_utils.TFBeamSampleDecoderOnlyOutput`] If the model is an encoder-decoder model (`model.config.is_encoder_decoder=True`), the possible [`~file_utils.ModelOutput`] types are: - - [`~generation_utils.TFGreedySearchEncoderDecoderOutput`], - - [`~generation_utils.TFSampleEncoderDecoderOutput`], - - [`~generation_utils.TFBeamSearchEncoderDecoderOutput`], - - [`~generation_utils.TFBeamSampleEncoderDecoderOutput`] + - [`~generation_tf_utils.TFGreedySearchEncoderDecoderOutput`], + - [`~generation_tf_utils.TFSampleEncoderDecoderOutput`], + - [`~generation_tf_utils.TFBeamSearchEncoderDecoderOutput`], + - [`~generation_tf_utils.TFBeamSampleEncoderDecoderOutput`] Examples: @@ -547,6 +555,38 @@ class TFGenerationMixin: input_ids=input_ids, max_length=100, do_sample=True, bad_words_ids=bad_words_ids ) # generate sequences without allowing bad_words to be generated ```""" + num_beams = num_beams if num_beams is not None else self.config.num_beams + do_sample = do_sample if do_sample is not None else self.config.do_sample + + is_greedy_gen_mode = num_beams == 1 and do_sample is False + + if is_greedy_gen_mode: + return self._generate( + input_ids=input_ids, + max_length=max_length, + min_length=min_length, + do_sample=do_sample, + early_stopping=early_stopping, + num_beams=num_beams, + temperature=temperature, + top_k=top_k, + top_p=top_p, + repetition_penalty=repetition_penalty, + bad_words_ids=bad_words_ids, + bos_token_id=bos_token_id, + pad_token_id=pad_token_id, + eos_token_id=eos_token_id, + length_penalty=length_penalty, + no_repeat_ngram_size=no_repeat_ngram_size, + num_return_sequences=num_return_sequences, + attention_mask=attention_mask, + decoder_start_token_id=decoder_start_token_id, + use_cache=use_cache, + output_scores=output_scores, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict_in_generate=return_dict_in_generate, + ) # We cannot generate if the model does not have a LM head if 
self.get_output_embeddings() is None: @@ -557,12 +597,11 @@ class TFGenerationMixin: max_length = max_length if max_length is not None else self.config.max_length min_length = min_length if min_length is not None else self.config.min_length - do_sample = do_sample if do_sample is not None else self.config.do_sample early_stopping = early_stopping if early_stopping is not None else self.config.early_stopping - num_beams = num_beams if num_beams is not None else self.config.num_beams temperature = temperature if temperature is not None else self.config.temperature top_k = top_k if top_k is not None else self.config.top_k top_p = top_p if top_p is not None else self.config.top_p + repetition_penalty = repetition_penalty if repetition_penalty is not None else self.config.repetition_penalty bos_token_id = bos_token_id if bos_token_id is not None else self.config.bos_token_id pad_token_id = pad_token_id if pad_token_id is not None else self.config.pad_token_id @@ -632,7 +671,7 @@ class TFGenerationMixin: bad_words_ids is None or isinstance(bad_words_ids, list) and isinstance(bad_words_ids[0], list) ), "`bad_words_ids` is either `None` or a list of lists of tokens that should not be generated" - # This block corresponds to the following line in `generation_utils`: + # This block corresponds to the following line in `generation_tf_utils`: # "input_ids = self._prepare_input_ids_for_generation(bos_token_id, model_kwargs.get("encoder_outputs"))" # with the following differences: # 1. In PT, `generate()`'s `model_kwargs` can accept `encoder_outputs`, but not the case in TF. @@ -751,8 +790,31 @@ class TFGenerationMixin: cur_len < max_length ), f"The context has {cur_len} number of tokens, but `max_length` is only {max_length}. 
Please make sure that `max_length` is bigger than the number of tokens, by setting either `generate(max_length=...,...)` or `config.max_length = ...`" - if num_beams > 1: - output = self._generate_beam_search( + if num_beams == 1: + return self._generate_no_beam_search( + input_ids, + cur_len=cur_len, + max_length=max_length, + min_length=min_length, + do_sample=do_sample, + temperature=temperature, + top_k=top_k, + top_p=top_p, + repetition_penalty=repetition_penalty, + no_repeat_ngram_size=no_repeat_ngram_size, + bad_words_ids=bad_words_ids, + pad_token_id=pad_token_id, + eos_token_id=eos_token_id, + batch_size=effective_batch_size, + vocab_size=vocab_size, + encoder_outputs=encoder_outputs, + attention_mask=attention_mask, + use_cache=use_cache, + return_dict_in_generate=return_dict_in_generate, + **model_kwargs, + ) + else: + return self._generate_beam_search( input_ids, cur_len=cur_len, max_length=max_length, @@ -780,31 +842,6 @@ class TFGenerationMixin: return_dict_in_generate=return_dict_in_generate, **model_kwargs, ) - else: - output = self._generate_no_beam_search( - input_ids, - cur_len=cur_len, - max_length=max_length, - min_length=min_length, - do_sample=do_sample, - temperature=temperature, - top_k=top_k, - top_p=top_p, - repetition_penalty=repetition_penalty, - no_repeat_ngram_size=no_repeat_ngram_size, - bad_words_ids=bad_words_ids, - pad_token_id=pad_token_id, - eos_token_id=eos_token_id, - batch_size=effective_batch_size, - vocab_size=vocab_size, - encoder_outputs=encoder_outputs, - attention_mask=attention_mask, - use_cache=use_cache, - return_dict_in_generate=return_dict_in_generate, - **model_kwargs, - ) - - return output def _generate_no_beam_search( self, @@ -1488,6 +1525,676 @@ class TFGenerationMixin: else: return logits + def _generate( + self, + input_ids=None, + max_length=None, + min_length=None, + do_sample=None, + early_stopping=None, + num_beams=None, + temperature=None, + top_k=None, + top_p=None, + repetition_penalty=None, + 
bad_words_ids=None, + bos_token_id=None, + pad_token_id=None, + eos_token_id=None, + length_penalty=None, + no_repeat_ngram_size=None, + num_return_sequences=None, + attention_mask=None, + decoder_start_token_id=None, + use_cache=None, + output_scores=None, + output_attentions=None, + output_hidden_states=None, + return_dict_in_generate=None, + forced_bos_token_id=None, + forced_eos_token_id=None, + **model_kwargs, + ) -> Union[TFGreedySearchOutput, TFSampleOutput, TFBeamSearchOutput, TFBeamSampleOutput, tf.Tensor]: + r""" + Generates sequences for models with a language modeling head. The method currently supports greedy decoding, + beam-search decoding, sampling with temperature, sampling with top-k or nucleus sampling. + + Adapted in part from [Facebook's XLM beam search + code](https://github.com/facebookresearch/XLM/blob/9e6f6814d17be4fe5b15f2e6c43eb2b2d76daeb4/src/model/transformer.py#L529). + + Apart from `input_ids` and `attention_mask`, all the arguments below will default to the value of the attribute + of the same name inside the [`PretrainedConfig`] of the model. The default values indicated are the default + values of those config. + + Most of these parameters are explained in more detail in [this blog + post](https://huggingface.co/blog/how-to-generate). + + Parameters: + + input_ids (`tf.Tensor` of `dtype=tf.int32` and shape `(batch_size, sequence_length)`, *optional*): + The sequence used as a prompt for the generation. If `None` the method initializes it with + `bos_token_id` and a batch size of 1. + max_length (`int`, *optional*, defaults to 20): + The maximum length of the sequence to be generated. + min_length (`int`, *optional*, defaults to 10): + The minimum length of the sequence to be generated. + do_sample (`bool`, *optional*, defaults to `False`): + Whether or not to use sampling ; use greedy decoding otherwise. 
+ early_stopping (`bool`, *optional*, defaults to `False`): + Whether to stop the beam search when at least `num_beams` sentences are finished per batch or not. + num_beams (`int`, *optional*, defaults to 1): + Number of beams for beam search. 1 means no beam search. + temperature (`float`, *optional*, defaults to 1.0): + The value used to module the next token probabilities. + top_k (`int`, *optional*, defaults to 50): + The number of highest probability vocabulary tokens to keep for top-k-filtering. + top_p (`float`, *optional*, defaults to 1.0): + If set to float < 1, only the most probable tokens with probabilities that add up to `top_p` or higher + are kept for generation. + repetition_penalty (`float`, *optional*, defaults to 1.0): + The parameter for repetition penalty. 1.0 means no penalty. See [this + paper](https://arxiv.org/pdf/1909.05858.pdf) for more details. + pad_token_id (`int`, *optional*): + The id of the *padding* token. + bos_token_id (`int`, *optional*): + The id of the *beginning-of-sequence* token. + eos_token_id (`int`, *optional*): + The id of the *end-of-sequence* token. + length_penalty (`float`, *optional*, defaults to 1.0): + Exponential penalty to the length. 1.0 means no penalty. + + Set to values < 1.0 in order to encourage the model to generate shorter sequences, to a value > 1.0 in + order to encourage the model to produce longer sequences. + no_repeat_ngram_size (`int`, *optional*, defaults to 0): + If set to int > 0, all ngrams of that size can only occur once. + bad_words_ids(`List[int]`, *optional*): + List of token ids that are not allowed to be generated. In order to get the tokens of the words that + should not appear in the generated text, use `tokenizer.encode(bad_word, add_prefix_space=True)`. + num_return_sequences(`int`, *optional*, defaults to 1): + The number of independently computed returned sequences for each element in the batch. 
+ attention_mask (`tf.Tensor` of `dtype=tf.int32` and shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values are in `[0, 1]`, 1 for tokens + that are not masked, and 0 for masked tokens. + + If not provided, will default to a tensor the same shape as `input_ids` that masks the pad token. + + [What are attention masks?](../glossary#attention-mask) + decoder_start_token_id (`int`, *optional*): + If an encoder-decoder model starts decoding with a different token than *bos*, the id of that token. + use_cache (`bool`, *optional*, defaults to `True`): + Whether or not the model should use the past last key/values attentions (if applicable to the model) to + speed up decoding. + output_attentions (`bool`, *optional*, defaults to `False`): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more details. + output_hidden_states (`bool`, *optional*, defaults to `False`): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors + for more details. + output_scores (`bool`, *optional*, defaults to `False`): + Whether or not to return the prediction scores. See `scores` under returned tensors for more details. + return_dict_in_generate (`bool`, *optional*, defaults to `False`): + Whether or not to return a [`~file_utils.ModelOutput`] instead of a plain tuple. + forced_bos_token_id (`int`, *optional*): + The id of the token to force as the first generated token after the `decoder_start_token_id`. Useful + for multilingual models like [mBART](../model_doc/mbart) where the first generated token needs to be + the target language token. + forced_eos_token_id (`int`, *optional*): + The id of the token to force as the last generated token when `max_length` is reached. + model_specific_kwargs: + Additional model specific kwargs will be forwarded to the `forward` function of the model. 
+ + Return: + [`~file_utils.ModelOutput`] or `tf.Tensor`: A [`~file_utils.ModelOutput`] (if + `return_dict_in_generate=True` or when `config.return_dict_in_generate=True`) or a `tf.Tensor`. + + If the model is *not* an encoder-decoder model (`model.config.is_encoder_decoder=False`), the possible + [`~file_utils.ModelOutput`] types are: + + - [`~generation_tf_utils.TFGreedySearchDecoderOnlyOutput`], + - [`~generation_tf_utils.TFSampleDecoderOnlyOutput`], + - [`~generation_tf_utils.TFBeamSearchDecoderOnlyOutput`], + - [`~generation_tf_utils.TFBeamSampleDecoderOnlyOutput`] + + If the model is an encoder-decoder model (`model.config.is_encoder_decoder=True`), the possible + [`~file_utils.ModelOutput`] types are: + + - [`~generation_tf_utils.TFGreedySearchEncoderDecoderOutput`], + - [`~generation_tf_utils.TFSampleEncoderDecoderOutput`], + - [`~generation_tf_utils.TFBeamSearchEncoderDecoderOutput`], + - [`~generation_tf_utils.TFBeamSampleEncoderDecoderOutput`] + + Examples: + + ```python + tokenizer = AutoTokenizer.from_pretrained("distilgpt2") # Initialize tokenizer + model = TFAutoModelWithLMHead.from_pretrained("distilgpt2") + # Greedy decoding + outputs = model.generate(max_length=40) + print(f"Generated: {tokenizer.decode(outputs[0], skip_special_tokens=True)}") + + tokenizer = AutoTokenizer.from_pretrained("openai-gpt") + model = TFAutoModelWithLMHead.from_pretrained("openai-gpt") + input_context = "The dog" + input_ids = tokenizer.encode(input_context, return_tensors="tf") # encode input context + # Generate 3 independent sequences using beam search decoding (5 beams) with sampling from initial context 'The dog' + outputs = model.generate(input_ids=input_ids, num_beams=5, num_return_sequences=3, temperature=1.5) + # 3 output sequences were generated + for i in range(3): + print(f"Generated {i}: {tokenizer.decode(outputs[i], skip_special_tokens=True)}") + + tokenizer = AutoTokenizer.from_pretrained("distilgpt2") + model = 
TFAutoModelWithLMHead.from_pretrained("distilgpt2") + input_context = "The dog" + input_ids = tokenizer.encode(input_context, return_tensors="tf") + # Generate 3 candidates using sampling + outputs = model.generate( + input_ids=input_ids, max_length=40, temperature=0.7, num_return_sequences=3, do_sample=True + ) + # 3 output sequences were generated + for i in range(3): + print(f"Generated {i}: {tokenizer.decode(outputs[i], skip_special_tokens=True)}") + + tokenizer = AutoTokenizer.from_pretrained("ctrl") + model = TFAutoModelWithLMHead.from_pretrained("ctrl") + # "Legal" is one of the control codes for ctrl + input_context = "Legal My neighbor is" + input_ids = tokenizer.encode(input_context, return_tensors="tf") + outputs = model.generate(input_ids=input_ids, max_length=50, temperature=0.7, repetition_penalty=1.2) + print(f"Generated: {tokenizer.decode(outputs[0], skip_special_tokens=True)}") + + tokenizer = AutoTokenizer.from_pretrained("gpt2") + model = TFAutoModelWithLMHead.from_pretrained("gpt2") + input_context = "My cute dog" + bad_words_ids = [ + tokenizer.encode(bad_word, add_prefix_space=True) for bad_word in ["idiot", "stupid", "shut up"] + ] + input_ids = tokenizer.encode(input_context, return_tensors="tf") + # generate sequences without allowing bad_words to be generated + outputs = model.generate(input_ids=input_ids, max_length=100, do_sample=True, bad_words_ids=bad_words_ids) + ```""" + # 1. 
Set generation parameters if not already defined + max_length = max_length if max_length is not None else self.config.max_length + min_length = min_length if min_length is not None else self.config.min_length + early_stopping = early_stopping if early_stopping is not None else self.config.early_stopping + + bos_token_id = bos_token_id if bos_token_id is not None else self.config.bos_token_id + pad_token_id = pad_token_id if pad_token_id is not None else self.config.pad_token_id + eos_token_id = eos_token_id if eos_token_id is not None else self.config.eos_token_id + + output_scores = output_scores if output_scores is not None else self.config.output_scores + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict_in_generate = ( + return_dict_in_generate if return_dict_in_generate is not None else self.config.return_dict_in_generate + ) + + num_beams = num_beams if num_beams is not None else self.config.num_beams + do_sample = do_sample if do_sample is not None else self.config.do_sample + num_return_sequences = ( + num_return_sequences if num_return_sequences is not None else self.config.num_return_sequences + ) + + if pad_token_id is None and eos_token_id is not None: + logger.warning(f"Setting `pad_token_id` to {eos_token_id} (first `eos_token_id`) to generate sequence") + pad_token_id = eos_token_id + + # 2. Define model inputs + input_ids = self._prepare_model_inputs(input_ids, bos_token_id) + # inputs_ids now has to be defined and cannot be None anymore + batch_size = input_ids.shape[0] + + # 3. 
Prepare other model kwargs + model_kwargs["output_attentions"] = output_attentions + model_kwargs["output_hidden_states"] = output_hidden_states + model_kwargs["use_cache"] = use_cache + + requires_attention_mask = "encoder_outputs" not in model_kwargs + + if model_kwargs.get("attention_mask", None) is None and requires_attention_mask: + model_kwargs["attention_mask"] = self._prepare_attention_mask_for_generation(input_ids, pad_token_id) + + if self.config.is_encoder_decoder: + # if model is encoder decoder model, we create encoder_outputs and add to `model_kwargs` + model_kwargs = self._prepare_encoder_decoder_kwargs_for_generation( + input_ids, return_dict_in_generate, model_kwargs + ) + + # TODO(Patrick) - ugly `past`/`encoder_output` hack here which requires a bigger + # refactor of all generation models in TF. `past` should be + # optional everywhere and not be set equal to encoder_outputs + model_kwargs["past"] = model_kwargs.get("encoder_outputs")[:1] if self.config.is_encoder_decoder else None + + # 4. Prepare `input_ids` which will be used for auto-regressive generation + if self.config.is_encoder_decoder: + # if encoder-decoder then `input_ids` come from `decoder_start_token_id` + input_ids = self._prepare_decoder_input_ids_for_generation( + batch_size, + decoder_start_token_id=decoder_start_token_id, + bos_token_id=bos_token_id, + model_kwargs=model_kwargs, + ) + + if input_ids.shape[-1] >= max_length: + raise ValueError( + f"The context has {input_ids.shape[-1]} number of tokens, " + f"but `max_length` is only {max_length}. " + "Please make sure that `max_length` is bigger than the number of tokens, " + "by setting either `generate(max_length=...,...)` or `config.max_length = ...`" + ) + + # 5. determine generation mode + # TODO(Matt, Joao, Patrick) - add more use cases here + is_greedy_gen_mode = (num_beams == 1) and do_sample is False + + # 6. 
def _prepare_attention_mask_for_generation(
    self,
    input_ids: tf.Tensor,
    pad_token_id: Optional[int],
) -> tf.Tensor:
    """
    Build a default `attention_mask` for `input_ids` when the caller did not pass one.

    Args:
        input_ids (`tf.Tensor`): The prompt token ids; its first two dimensions give the mask shape.
        pad_token_id (`int`, *optional*): The id of the padding token, if any.

    Returns:
        `tf.Tensor`: An `int32` mask. When `pad_token_id` is set and occurs in `input_ids`, positions equal to
        `pad_token_id` are 0 and all others 1; otherwise the mask is all ones.
    """
    # NOTE(review): `.numpy()` presumably requires eager execution - confirm before calling under `tf.function`.
    if (pad_token_id is not None) and (pad_token_id in input_ids.numpy()):
        return tf.cast(tf.math.not_equal(input_ids, pad_token_id), dtype=tf.int32)
    return tf.ones(input_ids.shape[:2], dtype=tf.int32)


def _prepare_encoder_decoder_kwargs_for_generation(
    self, input_ids: tf.Tensor, return_dict_in_generate, model_kwargs
) -> Dict[str, Any]:
    """
    Run the encoder on `input_ids` and store its outputs in `model_kwargs`.

    Args:
        input_ids (`tf.Tensor`): The encoder input ids.
        return_dict_in_generate: When truthy, the encoder `attentions` and `hidden_states` are additionally
            stored under `encoder_attentions` / `encoder_hidden_states`.
        model_kwargs: Keyword arguments collected by `generate`. Updated in place and returned.

    Returns:
        `Dict[str, Any]`: `model_kwargs` with `encoder_outputs` (and optionally the two extra keys) set.
    """
    # TODO(Patrick) - remove `return_dict_in_generate` flag input once `past`/`encoder_outputs`
    # is cleaned

    # get encoder and store encoder outputs
    encoder = self.get_encoder()

    # prepare encoder args and encoder kwargs from model kwargs:
    # anything meant for the decoder (or cache handling) is not forwarded to the encoder
    irrelevant_prefix = ("decoder_", "cross_attn", "use_cache")
    encoder_kwargs = {
        argument: value for argument, value in model_kwargs.items() if not argument.startswith(irrelevant_prefix)
    }

    # vision models don't use `attention_mask`.
    signature = dict(inspect.signature(encoder.call).parameters)
    if "attention_mask" not in signature:
        # default of `None` avoids a `KeyError` when no attention mask was prepared
        encoder_kwargs.pop("attention_mask", None)

    encoder_outputs = encoder(input_ids, **encoder_kwargs)

    model_kwargs["encoder_outputs"] = encoder_outputs

    # TODO(Patrick): `encoder_outputs`, `past` hack. Currently, `encoder_attentions` and
    # `encoder_hidden_states` have to be separated from encoder_outputs and passed
    # under other names because of `encoder_outputs`, `past` hack. Need to clean-up
    # all encoder-decoder prepare_inputs_for_generation method to clean this
    if return_dict_in_generate:
        model_kwargs["encoder_attentions"] = encoder_outputs.get("attentions", None)
        model_kwargs["encoder_hidden_states"] = encoder_outputs.get("hidden_states", None)

    return model_kwargs


def _prepare_decoder_input_ids_for_generation(
    self,
    batch_size: int,
    decoder_start_token_id: Optional[int] = None,
    bos_token_id: Optional[int] = None,
    model_kwargs: Optional[Dict[str, tf.Tensor]] = None,
) -> tf.Tensor:
    """
    Return the initial decoder `input_ids` for an encoder-decoder model.

    Args:
        batch_size (`int`): Number of rows of the returned tensor when it has to be created.
        decoder_start_token_id (`int`, *optional*): Token id used to start decoding.
        bos_token_id (`int`, *optional*): Fallback start token id.
        model_kwargs (`Dict[str, tf.Tensor]`, *optional*): If it contains `decoder_input_ids`, that entry is
            popped from the dict and returned as is.

    Returns:
        `tf.Tensor`: The user-provided `decoder_input_ids`, or an `int32` tensor of shape `(batch_size, 1)`
        filled with the resolved decoder start token id.
    """
    # user-provided decoder inputs take precedence (and are removed from `model_kwargs`)
    if model_kwargs is not None and "decoder_input_ids" in model_kwargs:
        return model_kwargs.pop("decoder_input_ids")

    decoder_start_token_id = self._get_decoder_start_token_id(decoder_start_token_id, bos_token_id)
    return tf.ones((batch_size, 1), dtype=tf.int32) * decoder_start_token_id


def _get_decoder_start_token_id(
    self, decoder_start_token_id: Optional[int] = None, bos_token_id: Optional[int] = None
) -> int:
    """
    Resolve the token id that starts decoding for encoder-decoder models.

    The first value that is not `None` is returned, in this order: `decoder_start_token_id` (argument, then
    `self.config`), `self.config.decoder.decoder_start_token_id`, `bos_token_id` (argument, then `self.config`),
    `self.config.decoder.bos_token_id`.

    Raises:
        ValueError: If none of the candidates is defined.
    """
    decoder_start_token_id = (
        decoder_start_token_id if decoder_start_token_id is not None else self.config.decoder_start_token_id
    )
    bos_token_id = bos_token_id if bos_token_id is not None else self.config.bos_token_id

    if decoder_start_token_id is not None:
        return decoder_start_token_id
    elif (
        hasattr(self.config, "decoder")
        and hasattr(self.config.decoder, "decoder_start_token_id")
        and self.config.decoder.decoder_start_token_id is not None
    ):
        return self.config.decoder.decoder_start_token_id
    elif bos_token_id is not None:
        return bos_token_id
    elif (
        hasattr(self.config, "decoder")
        and hasattr(self.config.decoder, "bos_token_id")
        and self.config.decoder.bos_token_id is not None
    ):
        return self.config.decoder.bos_token_id
    raise ValueError(
        "`decoder_start_token_id` or `bos_token_id` has to be defined for encoder-decoder generation."
    )


def _prepare_model_inputs(self, inputs: Optional[tf.Tensor] = None, bos_token_id: Optional[int] = None):
    """
    Return `inputs` unchanged, or create a `(1, 1)` `int32` prompt filled with `bos_token_id` when `inputs` is
    `None`.

    Raises:
        ValueError: If `inputs` is `None` and `bos_token_id` is not a non-negative integer.
    """
    # TODO(Patrick) - adapt this function when making `generate` more flexible
    # for all kinds of input types
    if inputs is not None:
        return inputs

    # if no `inputs` are passed create prompt of size (1,1) filled with BOS token
    if not isinstance(bos_token_id, int) or bos_token_id < 0:
        raise ValueError(
            "you should either supply a context to complete as `input_ids` input "
            "or a `bos_token_id` (integer >= 0) as a first token to start the generation."
        )
    return tf.cast(tf.fill((1, 1), bos_token_id), dtype=tf.int32)


def _update_model_kwargs_for_generation(
    self, outputs: ModelOutput, model_kwargs: Dict[str, Any], is_encoder_decoder: bool = False
) -> Dict[str, Any]:
    """
    Update `model_kwargs` after a forward pass so they can be used for the next generation step.

    Args:
        outputs (`ModelOutput`): The outputs of the forward pass that was just run.
        model_kwargs (`Dict[str, Any]`): The current model kwargs; must contain `use_cache`. Updated in place.
        is_encoder_decoder (`bool`, *optional*, defaults to `False`): When `False`, an existing
            `attention_mask` is extended by one column of ones for the newly generated token.

    Returns:
        `Dict[str, Any]`: The updated `model_kwargs`.
    """
    # update past; the branch order matters, the first matching source of cached states wins
    if self._use_cache(outputs, model_kwargs["use_cache"]):
        # TODO(Patrick): `past`/`encoder_outputs` hack. This should be
        # removed when cleaning up the encoder-decoder models
        # if model has past, then set the past variable to speed up decoding
        # make this method static then as well
        model_kwargs["past"] = outputs[1]
    elif "past_key_values" in outputs:
        model_kwargs["past"] = outputs.past_key_values
    elif "mems" in outputs:
        model_kwargs["past"] = outputs.mems
    elif "past_buckets_states" in outputs:
        model_kwargs["past"] = outputs.past_buckets_states
    elif "past" in model_kwargs:
        # TODO(Patrick) `past`/`encoder_outputs` hack. This should be
        # removed when cleaning up the encoder-decoder models.
        # The line should not be necessary.
        pass
    else:
        model_kwargs["past"] = None

    # update attention mask: append a column of ones for the token that was just generated
    if not is_encoder_decoder and "attention_mask" in model_kwargs:
        attention_mask = model_kwargs["attention_mask"]
        model_kwargs["attention_mask"] = tf.concat(
            [attention_mask, tf.ones((shape_list(attention_mask)[0], 1), dtype=tf.int32)], axis=-1
        )

    return model_kwargs


def _get_logits_processor(
    self,
    repetition_penalty: float,
    no_repeat_ngram_size: int,
    bad_words_ids: List[List[int]],
    min_length: int,
    eos_token_id: int,
) -> TFLogitsProcessorList:
    """
    This method returns a [`TFLogitsProcessorList`] list object that contains all relevant [`TFLogitsProcessor`]
    instances used to modify the scores of the language model head.

    Arguments that are `None` fall back to the value of the same name in `self.config`, except `min_length`,
    which is used as passed.
    """
    processors = TFLogitsProcessorList()

    repetition_penalty = repetition_penalty if repetition_penalty is not None else self.config.repetition_penalty
    no_repeat_ngram_size = (
        no_repeat_ngram_size if no_repeat_ngram_size is not None else self.config.no_repeat_ngram_size
    )
    bad_words_ids = bad_words_ids if bad_words_ids is not None else self.config.bad_words_ids
    eos_token_id = eos_token_id if eos_token_id is not None else self.config.eos_token_id

    # instantiate processors list; a processor is only added when its setting is active
    if repetition_penalty is not None and repetition_penalty != 1.0:
        processors.append(TFRepetitionPenaltyLogitsProcessor(penalty=repetition_penalty))
    if no_repeat_ngram_size is not None and no_repeat_ngram_size > 0:
        processors.append(TFNoRepeatNGramLogitsProcessor(no_repeat_ngram_size))
    if bad_words_ids is not None:
        processors.append(TFNoBadWordsLogitsProcessor(bad_words_ids, eos_token_id))
    if min_length is not None and eos_token_id is not None and min_length > -1:
        processors.append(TFMinLengthLogitsProcessor(min_length, eos_token_id))

    return processors


def greedy_search(
    self,
    input_ids: tf.Tensor,
    max_length: Optional[int] = None,
    pad_token_id: Optional[int] = None,
    eos_token_id: Optional[int] = None,
    logits_processor: Optional[TFLogitsProcessorList] = None,
    output_attentions: Optional[bool] = None,
    output_hidden_states: Optional[bool] = None,
    output_scores: Optional[bool] = None,
    return_dict_in_generate: Optional[bool] = None,
    **model_kwargs,
) -> Union[TFGreedySearchOutput, tf.Tensor]:
    r"""
    Generates sequences for models with a language modeling head using greedy decoding.

    Parameters:

        input_ids (`tf.Tensor` of shape `(batch_size, sequence_length)`):
            The sequence used as a prompt for the generation.
        logits_processor (`TFLogitsProcessorList`, *optional*):
            An instance of [`TFLogitsProcessorList`]. List of instances of class derived from [`TFLogitsProcessor`]
            used to modify the prediction scores of the language modeling head applied at each generation step.
        max_length (`int`, *optional*, defaults to `model.config.max_length`):
            The maximum length of the sequence to be generated.
        pad_token_id (`int`, *optional*):
            The id of the *padding* token.
        eos_token_id (`int`, *optional*):
            The id of the *end-of-sequence* token.
        output_attentions (`bool`, *optional*, defaults to `False`):
            Whether or not to return the attentions tensors of all attention layers. See `attentions` under
            returned tensors for more details.
        output_hidden_states (`bool`, *optional*, defaults to `False`):
            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
            for more details.
        output_scores (`bool`, *optional*, defaults to `False`):
            Whether or not to return the prediction scores. See `scores` under returned tensors for more details.
        return_dict_in_generate (`bool`, *optional*, defaults to `False`):
            Whether or not to return a [`~file_utils.ModelOutput`] instead of a plain tuple.
        model_kwargs:
            Additional model specific keyword arguments will be forwarded to the `call` function of the model.
            If model is an encoder-decoder model the kwargs should include `encoder_outputs`.

    Return:
        [`~generation_tf_utils.TFGreedySearchDecoderOnlyOutput`],
        [`~generation_tf_utils.TFGreedySearchEncoderDecoderOutput`] or `tf.Tensor`: A `tf.Tensor` containing the
        generated tokens (default behaviour) or a [`~generation_tf_utils.TFGreedySearchDecoderOnlyOutput`] if
        `model.config.is_encoder_decoder=False` and `return_dict_in_generate=True` or a
        [`~generation_tf_utils.TFGreedySearchEncoderDecoderOutput`] if `model.config.is_encoder_decoder=True`.

    Examples:

    ```python
    >>> from transformers import (
    ...     AutoTokenizer,
    ...     TFAutoModelForCausalLM,
    ...     TFLogitsProcessorList,
    ...     TFMinLengthLogitsProcessor,
    ... )

    >>> tokenizer = AutoTokenizer.from_pretrained("gpt2")
    >>> model = TFAutoModelForCausalLM.from_pretrained("gpt2")

    >>> # set pad_token_id to eos_token_id because GPT2 does not have a PAD token
    >>> model.config.pad_token_id = model.config.eos_token_id

    >>> input_prompt = "Today is a beautiful day, and"
    >>> input_ids = tokenizer(input_prompt, return_tensors="tf").input_ids

    >>> # instantiate logits processors
    >>> logits_processor = TFLogitsProcessorList(
    ...     [
    ...         TFMinLengthLogitsProcessor(15, eos_token_id=model.config.eos_token_id),
    ...     ]
    ... )

    >>> outputs = model.greedy_search(input_ids, logits_processor=logits_processor)

    >>> print("Generated:", tokenizer.batch_decode(outputs, skip_special_tokens=True))
    ```"""
    # init values
    logits_processor = logits_processor if logits_processor is not None else TFLogitsProcessorList()

    # fall back to the config like `generate` does; without this the loop condition below would compare
    # an int against `None` whenever `max_length` is not passed explicitly
    max_length = max_length if max_length is not None else self.config.max_length
    pad_token_id = pad_token_id if pad_token_id is not None else self.config.pad_token_id
    eos_token_id = eos_token_id if eos_token_id is not None else self.config.eos_token_id
    output_scores = output_scores if output_scores is not None else self.config.output_scores
    output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
    output_hidden_states = (
        output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
    )
    return_dict_in_generate = (
        return_dict_in_generate if return_dict_in_generate is not None else self.config.return_dict_in_generate
    )

    # init attention / hidden states / scores tuples
    scores = () if (return_dict_in_generate and output_scores) else None
    decoder_attentions = () if (return_dict_in_generate and output_attentions) else None
    cross_attentions = () if (return_dict_in_generate and output_attentions) else None
    decoder_hidden_states = () if (return_dict_in_generate and output_hidden_states) else None

    # TODO(Patrick): `encoder_outputs`, `past` hack. Currently T5, Bart expect `encoder_outputs`
    # to be wrapped into `past` variable. This is a bad design and needs
    # to be updated.
    # Remove the following lines when updating all encoder-decoder models
    encoder_outputs = model_kwargs.pop("encoder_outputs", None)

    # if model is an encoder-decoder, retrieve encoder attention weights and hidden states
    if return_dict_in_generate and self.config.is_encoder_decoder:
        encoder_attentions = encoder_outputs.get("attentions") if output_attentions else None
        encoder_hidden_states = encoder_outputs.get("hidden_states") if output_hidden_states else None

    # keep track of which sequences are already finished (1 = still generating, 0 = finished)
    unfinished_sequences = tf.ones_like(input_ids[:, 0])
    cur_len = input_ids.shape[-1]

    while cur_len < max_length:
        # TODO(Patrick): remove following line by cleaning up `prepare_inputs_for_generation`
        # in all models
        model_kwargs.setdefault("use_cache", None)

        # prepare model inputs
        model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)

        # forward pass to get next token
        outputs = self(
            **model_inputs,
            return_dict=True,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
        )

        next_token_logits = outputs.logits[:, -1, :]

        # Store scores, attentions and hidden_states when required
        if return_dict_in_generate:
            if output_scores:
                scores += (next_token_logits,)
            if output_attentions:
                decoder_attentions += (
                    (outputs.decoder_attentions,) if self.config.is_encoder_decoder else (outputs.attentions,)
                )
                if self.config.is_encoder_decoder:
                    cross_attentions += (outputs.cross_attentions,)

            if output_hidden_states:
                decoder_hidden_states += (
                    (outputs.decoder_hidden_states,)
                    if self.config.is_encoder_decoder
                    else (outputs.hidden_states,)
                )

        # pre-process distribution
        next_tokens_scores = logits_processor(input_ids, next_token_logits)

        # argmax
        next_tokens = tf.cast(tf.argmax(next_tokens_scores, axis=-1), tf.int32)

        # finished sentences should have their next token be a padding token
        if eos_token_id is not None:
            if pad_token_id is None:
                raise ValueError("If `eos_token_id` is defined, make sure that `pad_token_id` is defined.")
            next_tokens = next_tokens * unfinished_sequences + pad_token_id * (1 - unfinished_sequences)

        # update generated ids, model inputs, and length for next step
        input_ids = tf.concat([input_ids, next_tokens[:, None]], axis=-1)
        model_kwargs = self._update_model_kwargs_for_generation(
            outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder
        )
        cur_len = cur_len + 1

        # if eos_token was found in one sentence, set sentence to finished
        if eos_token_id is not None:
            eos_in_sents = next_tokens == eos_token_id
            # if sentence is unfinished and the token to add is eos
            is_sents_unfinished_and_token_to_add_is_eos = tf.math.multiply(
                unfinished_sequences, tf.cast(eos_in_sents, tf.int32)
            )

            # unfinished_sequences is set to zero if eos in sentence
            unfinished_sequences -= is_sents_unfinished_and_token_to_add_is_eos

        # stop when each sentence is finished (the `while` condition handles the maximum length)
        if tf.math.reduce_max(unfinished_sequences) == 0:
            break

    if not return_dict_in_generate:
        return input_ids

    if self.config.is_encoder_decoder:
        return TFGreedySearchEncoderDecoderOutput(
            sequences=input_ids,
            scores=scores,
            encoder_attentions=encoder_attentions,
            encoder_hidden_states=encoder_hidden_states,
            decoder_attentions=decoder_attentions,
            cross_attentions=cross_attentions,
            decoder_hidden_states=decoder_hidden_states,
        )
    return TFGreedySearchDecoderOnlyOutput(
        sequences=input_ids,
        scores=scores,
        attentions=decoder_attentions,
        hidden_states=decoder_hidden_states,
    )
set_tensor_by_indices_to_value(tensor, indices, value): - # create value_tensor since tensor value assignment is not possible in TF - value_tensor = tf.zeros_like(tensor) + value - return tf.where(indices, value_tensor, tensor) - - def sample_without_replacement(logits, num_samples): """ categorical sampling without replacement is currently not implemented the gumbel-max trick will do for now see @@ -1644,13 +2345,6 @@ def sample_without_replacement(logits, num_samples): return indices -def shape_list(x): - """Deal with dynamic shape in tensorflow cleanly.""" - static = x.shape.as_list() - dynamic = tf.shape(x) - return [dynamic[i] if s is None else s for i, s in enumerate(static)] - - class BeamHypotheses(object): def __init__(self, num_beams, max_length, length_penalty, early_stopping): """ diff --git a/src/transformers/modeling_tf_utils.py b/src/transformers/modeling_tf_utils.py index e34b7aa2ca..79d71bcb86 100644 --- a/src/transformers/modeling_tf_utils.py +++ b/src/transformers/modeling_tf_utils.py @@ -54,6 +54,7 @@ from .file_utils import ( ) from .generation_tf_utils import TFGenerationMixin from .modeling_tf_outputs import TFSeq2SeqLMOutput +from .tf_utils import shape_list from .tokenization_utils_base import BatchEncoding from .utils import logging @@ -2041,29 +2042,6 @@ class TFSequenceSummary(tf.keras.layers.Layer): cls._auto_class = auto_class -def shape_list(tensor: Union[tf.Tensor, np.ndarray]) -> List[int]: - """ - Deal with dynamic shape in tensorflow cleanly. - - Args: - tensor (`tf.Tensor` or `np.ndarray`): The tensor we want the shape of. - - Returns: - `List[int]`: The shape of the tensor as a list. 
- """ - if isinstance(tensor, np.ndarray): - return list(tensor.shape) - - dynamic = tf.shape(tensor) - - if tensor.shape == tf.TensorShape(None): - return dynamic - - static = tensor.shape.as_list() - - return [dynamic[i] if s is None else s for i, s in enumerate(static)] - - def get_initializer(initializer_range: float = 0.02) -> tf.initializers.TruncatedNormal: """ Creates a `tf.initializers.TruncatedNormal` with the given range. diff --git a/src/transformers/models/albert/modeling_tf_albert.py b/src/transformers/models/albert/modeling_tf_albert.py index f2659e817a..42f1e5b34d 100644 --- a/src/transformers/models/albert/modeling_tf_albert.py +++ b/src/transformers/models/albert/modeling_tf_albert.py @@ -51,8 +51,8 @@ from ...modeling_tf_utils import ( get_initializer, input_processing, keras_serializable, - shape_list, ) +from ...tf_utils import shape_list from ...utils import logging from .configuration_albert import AlbertConfig diff --git a/src/transformers/models/bart/modeling_tf_bart.py b/src/transformers/models/bart/modeling_tf_bart.py index b9abc647ab..058fdb99f2 100644 --- a/src/transformers/models/bart/modeling_tf_bart.py +++ b/src/transformers/models/bart/modeling_tf_bart.py @@ -44,8 +44,8 @@ from ...modeling_tf_utils import ( TFWrappedEmbeddings, input_processing, keras_serializable, - shape_list, ) +from ...tf_utils import shape_list from ...utils import logging from .configuration_bart import BartConfig diff --git a/src/transformers/models/bert/modeling_tf_bert.py b/src/transformers/models/bert/modeling_tf_bert.py index 7d7d431c7e..bf5ddb365b 100644 --- a/src/transformers/models/bert/modeling_tf_bert.py +++ b/src/transformers/models/bert/modeling_tf_bert.py @@ -57,8 +57,8 @@ from ...modeling_tf_utils import ( get_initializer, input_processing, keras_serializable, - shape_list, ) +from ...tf_utils import shape_list from ...utils import logging from .configuration_bert import BertConfig diff --git 
a/src/transformers/models/blenderbot/modeling_tf_blenderbot.py b/src/transformers/models/blenderbot/modeling_tf_blenderbot.py index 6d50492062..65135a1d07 100644 --- a/src/transformers/models/blenderbot/modeling_tf_blenderbot.py +++ b/src/transformers/models/blenderbot/modeling_tf_blenderbot.py @@ -46,8 +46,8 @@ from ...modeling_tf_utils import ( TFWrappedEmbeddings, input_processing, keras_serializable, - shape_list, ) +from ...tf_utils import shape_list from ...utils import logging from .configuration_blenderbot import BlenderbotConfig diff --git a/src/transformers/models/blenderbot_small/modeling_tf_blenderbot_small.py b/src/transformers/models/blenderbot_small/modeling_tf_blenderbot_small.py index fdf0c63c0a..0243030a43 100644 --- a/src/transformers/models/blenderbot_small/modeling_tf_blenderbot_small.py +++ b/src/transformers/models/blenderbot_small/modeling_tf_blenderbot_small.py @@ -44,8 +44,8 @@ from ...modeling_tf_utils import ( TFWrappedEmbeddings, input_processing, keras_serializable, - shape_list, ) +from ...tf_utils import shape_list from ...utils import logging from .configuration_blenderbot_small import BlenderbotSmallConfig diff --git a/src/transformers/models/clip/modeling_tf_clip.py b/src/transformers/models/clip/modeling_tf_clip.py index 3a1621ba9d..4902248b25 100644 --- a/src/transformers/models/clip/modeling_tf_clip.py +++ b/src/transformers/models/clip/modeling_tf_clip.py @@ -39,8 +39,8 @@ from ...modeling_tf_utils import ( get_initializer, input_processing, keras_serializable, - shape_list, ) +from ...tf_utils import shape_list from ...utils import logging from .configuration_clip import CLIPConfig, CLIPTextConfig, CLIPVisionConfig diff --git a/src/transformers/models/convbert/modeling_tf_convbert.py b/src/transformers/models/convbert/modeling_tf_convbert.py index 84967b5fba..0c4d265dcd 100644 --- a/src/transformers/models/convbert/modeling_tf_convbert.py +++ b/src/transformers/models/convbert/modeling_tf_convbert.py @@ -43,8 +43,8 @@ from 
...modeling_tf_utils import ( get_initializer, input_processing, keras_serializable, - shape_list, ) +from ...tf_utils import shape_list from ...utils import logging from .configuration_convbert import ConvBertConfig diff --git a/src/transformers/models/ctrl/modeling_tf_ctrl.py b/src/transformers/models/ctrl/modeling_tf_ctrl.py index acfce53c8a..c72448310a 100644 --- a/src/transformers/models/ctrl/modeling_tf_ctrl.py +++ b/src/transformers/models/ctrl/modeling_tf_ctrl.py @@ -30,8 +30,8 @@ from ...modeling_tf_utils import ( get_initializer, input_processing, keras_serializable, - shape_list, ) +from ...tf_utils import shape_list from ...utils import logging from .configuration_ctrl import CTRLConfig diff --git a/src/transformers/models/deberta/modeling_tf_deberta.py b/src/transformers/models/deberta/modeling_tf_deberta.py index 25a6c07d42..0d36de4895 100644 --- a/src/transformers/models/deberta/modeling_tf_deberta.py +++ b/src/transformers/models/deberta/modeling_tf_deberta.py @@ -39,8 +39,8 @@ from ...modeling_tf_utils import ( TFTokenClassificationLoss, get_initializer, input_processing, - shape_list, ) +from ...tf_utils import shape_list from ...utils import logging from .configuration_deberta import DebertaConfig diff --git a/src/transformers/models/deberta_v2/modeling_tf_deberta_v2.py b/src/transformers/models/deberta_v2/modeling_tf_deberta_v2.py index 1a8f8c94ba..445cb76256 100644 --- a/src/transformers/models/deberta_v2/modeling_tf_deberta_v2.py +++ b/src/transformers/models/deberta_v2/modeling_tf_deberta_v2.py @@ -38,8 +38,8 @@ from ...modeling_tf_utils import ( TFTokenClassificationLoss, get_initializer, input_processing, - shape_list, ) +from ...tf_utils import shape_list from ...utils import logging from .configuration_deberta_v2 import DebertaV2Config diff --git a/src/transformers/models/distilbert/modeling_tf_distilbert.py b/src/transformers/models/distilbert/modeling_tf_distilbert.py index 05da8b3061..86a814a749 100644 --- 
a/src/transformers/models/distilbert/modeling_tf_distilbert.py +++ b/src/transformers/models/distilbert/modeling_tf_distilbert.py @@ -45,8 +45,8 @@ from ...modeling_tf_utils import ( get_initializer, input_processing, keras_serializable, - shape_list, ) +from ...tf_utils import shape_list from ...utils import logging from .configuration_distilbert import DistilBertConfig diff --git a/src/transformers/models/electra/modeling_tf_electra.py b/src/transformers/models/electra/modeling_tf_electra.py index f24b003b60..68c639de91 100644 --- a/src/transformers/models/electra/modeling_tf_electra.py +++ b/src/transformers/models/electra/modeling_tf_electra.py @@ -50,8 +50,8 @@ from ...modeling_tf_utils import ( get_initializer, input_processing, keras_serializable, - shape_list, ) +from ...tf_utils import shape_list from ...utils import logging from .configuration_electra import ElectraConfig diff --git a/src/transformers/models/encoder_decoder/modeling_tf_encoder_decoder.py b/src/transformers/models/encoder_decoder/modeling_tf_encoder_decoder.py index 8ba4ae31b8..a2668b75b1 100644 --- a/src/transformers/models/encoder_decoder/modeling_tf_encoder_decoder.py +++ b/src/transformers/models/encoder_decoder/modeling_tf_encoder_decoder.py @@ -30,13 +30,8 @@ from ...file_utils import ( replace_return_docstrings, ) from ...modeling_tf_outputs import TFBaseModelOutput, TFSeq2SeqLMOutput -from ...modeling_tf_utils import ( - TFCausalLanguageModelingLoss, - TFPreTrainedModel, - get_initializer, - input_processing, - shape_list, -) +from ...modeling_tf_utils import TFCausalLanguageModelingLoss, TFPreTrainedModel, get_initializer, input_processing +from ...tf_utils import shape_list from ...utils import logging from ..auto.configuration_auto import AutoConfig from ..auto.modeling_tf_auto import TFAutoModel, TFAutoModelForCausalLM diff --git a/src/transformers/models/flaubert/modeling_tf_flaubert.py b/src/transformers/models/flaubert/modeling_tf_flaubert.py index 87c1c7e6b0..c681277a80 
100644 --- a/src/transformers/models/flaubert/modeling_tf_flaubert.py +++ b/src/transformers/models/flaubert/modeling_tf_flaubert.py @@ -38,8 +38,8 @@ from ...modeling_tf_utils import ( get_initializer, input_processing, keras_serializable, - shape_list, ) +from ...tf_utils import shape_list from ...utils import logging from ..xlm.modeling_tf_xlm import ( TFXLMForMultipleChoice, diff --git a/src/transformers/models/funnel/modeling_tf_funnel.py b/src/transformers/models/funnel/modeling_tf_funnel.py index b3d9a8506e..9b4b6e7083 100644 --- a/src/transformers/models/funnel/modeling_tf_funnel.py +++ b/src/transformers/models/funnel/modeling_tf_funnel.py @@ -47,8 +47,8 @@ from ...modeling_tf_utils import ( get_initializer, input_processing, keras_serializable, - shape_list, ) +from ...tf_utils import shape_list from ...utils import logging from .configuration_funnel import FunnelConfig diff --git a/src/transformers/models/gpt2/modeling_tf_gpt2.py b/src/transformers/models/gpt2/modeling_tf_gpt2.py index ab32cc0e83..d4939594d5 100644 --- a/src/transformers/models/gpt2/modeling_tf_gpt2.py +++ b/src/transformers/models/gpt2/modeling_tf_gpt2.py @@ -44,8 +44,8 @@ from ...modeling_tf_utils import ( get_initializer, input_processing, keras_serializable, - shape_list, ) +from ...tf_utils import shape_list from ...utils import logging from .configuration_gpt2 import GPT2Config diff --git a/src/transformers/models/hubert/modeling_tf_hubert.py b/src/transformers/models/hubert/modeling_tf_hubert.py index 548ea5e385..936f2ab0dc 100644 --- a/src/transformers/models/hubert/modeling_tf_hubert.py +++ b/src/transformers/models/hubert/modeling_tf_hubert.py @@ -28,13 +28,8 @@ from ...file_utils import ( replace_return_docstrings, ) from ...modeling_tf_outputs import TFBaseModelOutput, TFCausalLMOutput -from ...modeling_tf_utils import ( - TFPreTrainedModel, - booleans_processing, - get_initializer, - keras_serializable, - shape_list, -) +from ...modeling_tf_utils import TFPreTrainedModel, 
booleans_processing, get_initializer, keras_serializable +from ...tf_utils import shape_list from ...tokenization_utils_base import BatchEncoding from ...utils import logging from .configuration_hubert import HubertConfig diff --git a/src/transformers/models/layoutlm/modeling_tf_layoutlm.py b/src/transformers/models/layoutlm/modeling_tf_layoutlm.py index dbc9b21b0b..6f30883500 100644 --- a/src/transformers/models/layoutlm/modeling_tf_layoutlm.py +++ b/src/transformers/models/layoutlm/modeling_tf_layoutlm.py @@ -39,8 +39,8 @@ from ...modeling_tf_utils import ( get_initializer, input_processing, keras_serializable, - shape_list, ) +from ...tf_utils import shape_list from ...utils import logging from .configuration_layoutlm import LayoutLMConfig diff --git a/src/transformers/models/led/modeling_tf_led.py b/src/transformers/models/led/modeling_tf_led.py index 924a62f7d9..e282db0e81 100644 --- a/src/transformers/models/led/modeling_tf_led.py +++ b/src/transformers/models/led/modeling_tf_led.py @@ -39,8 +39,8 @@ from ...modeling_tf_utils import ( get_initializer, input_processing, keras_serializable, - shape_list, ) +from ...tf_utils import shape_list from ...utils import logging from .configuration_led import LEDConfig diff --git a/src/transformers/models/longformer/modeling_tf_longformer.py b/src/transformers/models/longformer/modeling_tf_longformer.py index da34d11b80..458133a9b4 100644 --- a/src/transformers/models/longformer/modeling_tf_longformer.py +++ b/src/transformers/models/longformer/modeling_tf_longformer.py @@ -38,8 +38,8 @@ from ...modeling_tf_utils import ( get_initializer, input_processing, keras_serializable, - shape_list, ) +from ...tf_utils import shape_list from ...utils import logging from .configuration_longformer import LongformerConfig diff --git a/src/transformers/models/marian/modeling_tf_marian.py b/src/transformers/models/marian/modeling_tf_marian.py index be9be08fb1..ba094d6a0a 100644 --- 
a/src/transformers/models/marian/modeling_tf_marian.py +++ b/src/transformers/models/marian/modeling_tf_marian.py @@ -45,8 +45,8 @@ from ...modeling_tf_utils import ( TFWrappedEmbeddings, input_processing, keras_serializable, - shape_list, ) +from ...tf_utils import shape_list from ...utils import logging from .configuration_marian import MarianConfig diff --git a/src/transformers/models/mbart/modeling_tf_mbart.py b/src/transformers/models/mbart/modeling_tf_mbart.py index f98408f8e1..59e41bd694 100644 --- a/src/transformers/models/mbart/modeling_tf_mbart.py +++ b/src/transformers/models/mbart/modeling_tf_mbart.py @@ -44,8 +44,8 @@ from ...modeling_tf_utils import ( TFWrappedEmbeddings, input_processing, keras_serializable, - shape_list, ) +from ...tf_utils import shape_list from ...utils import logging from .configuration_mbart import MBartConfig diff --git a/src/transformers/models/mobilebert/modeling_tf_mobilebert.py b/src/transformers/models/mobilebert/modeling_tf_mobilebert.py index 928e7e8b16..9b16c79f18 100644 --- a/src/transformers/models/mobilebert/modeling_tf_mobilebert.py +++ b/src/transformers/models/mobilebert/modeling_tf_mobilebert.py @@ -51,8 +51,8 @@ from ...modeling_tf_utils import ( get_initializer, input_processing, keras_serializable, - shape_list, ) +from ...tf_utils import shape_list from ...utils import logging from .configuration_mobilebert import MobileBertConfig diff --git a/src/transformers/models/mpnet/modeling_tf_mpnet.py b/src/transformers/models/mpnet/modeling_tf_mpnet.py index 0ed54a2ab1..196a47b1fb 100644 --- a/src/transformers/models/mpnet/modeling_tf_mpnet.py +++ b/src/transformers/models/mpnet/modeling_tf_mpnet.py @@ -47,8 +47,8 @@ from ...modeling_tf_utils import ( get_initializer, input_processing, keras_serializable, - shape_list, ) +from ...tf_utils import shape_list from ...utils import logging from .configuration_mpnet import MPNetConfig diff --git a/src/transformers/models/openai/modeling_tf_openai.py 
b/src/transformers/models/openai/modeling_tf_openai.py index a924fb4023..cb680603a1 100644 --- a/src/transformers/models/openai/modeling_tf_openai.py +++ b/src/transformers/models/openai/modeling_tf_openai.py @@ -39,8 +39,8 @@ from ...modeling_tf_utils import ( get_initializer, input_processing, keras_serializable, - shape_list, ) +from ...tf_utils import shape_list from ...utils import logging from .configuration_openai import OpenAIGPTConfig diff --git a/src/transformers/models/pegasus/modeling_tf_pegasus.py b/src/transformers/models/pegasus/modeling_tf_pegasus.py index 86f922e7bb..cb14687406 100644 --- a/src/transformers/models/pegasus/modeling_tf_pegasus.py +++ b/src/transformers/models/pegasus/modeling_tf_pegasus.py @@ -45,8 +45,8 @@ from ...modeling_tf_utils import ( TFWrappedEmbeddings, input_processing, keras_serializable, - shape_list, ) +from ...tf_utils import shape_list from ...utils import logging from .configuration_pegasus import PegasusConfig diff --git a/src/transformers/models/rag/modeling_tf_rag.py b/src/transformers/models/rag/modeling_tf_rag.py index 4059b09cd8..7ea2d3521b 100644 --- a/src/transformers/models/rag/modeling_tf_rag.py +++ b/src/transformers/models/rag/modeling_tf_rag.py @@ -1269,6 +1269,8 @@ class TFRagTokenForGeneration(TFRagPreTrainedModel, TFCausalLanguageModelingLoss ) if return_dict_in_generate: + # TODO(Patrick): `encoder_outputs`, `past` hack. 
+ # Remove after cleaning encoder-decoder outputs if output_attentions: model_kwargs["encoder_attentions"] = encoder_outputs.attentions if output_hidden_states: @@ -1350,28 +1352,35 @@ class TFRagTokenForGeneration(TFRagPreTrainedModel, TFCausalLanguageModelingLoss **model_kwargs, # encoder_outputs is here as in Pytorch's version ) else: - return self._generate_no_beam_search( - decoder_input_ids, - cur_len=cur_len, - max_length=max_length, - min_length=min_length, - do_sample=do_sample, - temperature=temperature, - top_k=top_k, - top_p=top_p, + pre_processor = self._get_logits_processor( repetition_penalty=repetition_penalty, no_repeat_ngram_size=no_repeat_ngram_size, bad_words_ids=bad_words_ids, + min_length=min_length, + eos_token_id=eos_token_id, + ) + # TODO(Patrick) clean-up once generate is fully cleaned up + model_kwargs["attention_mask"] = context_attention_mask + # TODO(Patrick) remove once generate is fully cleaned up + model_kwargs.pop("output_hidden_states", None) + model_kwargs.pop("output_attentions", None) + model_kwargs.pop("output_scores", None) + + # TODO(Patrick): `encoder_outputs`, `past` hack. 
+ # Remove after cleaning encoder-decoder outputs + model_kwargs["past"] = encoder_outputs + + return self.greedy_search( + input_ids=decoder_input_ids, + max_length=max_length, pad_token_id=pad_token_id, eos_token_id=eos_token_id, - batch_size=batch_size, - vocab_size=vocab_size, - attention_mask=context_attention_mask, - use_cache=use_cache, - forced_bos_token_id=None, - forced_eos_token_id=None, + logits_processor=pre_processor, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + output_scores=output_scores, return_dict_in_generate=return_dict_in_generate, - **model_kwargs, # encoder_outputs is here as in Pytorch's version + **model_kwargs, ) def get_input_embeddings(self): diff --git a/src/transformers/models/rembert/modeling_tf_rembert.py b/src/transformers/models/rembert/modeling_tf_rembert.py index 9bf6ba6ede..24a6387cd7 100644 --- a/src/transformers/models/rembert/modeling_tf_rembert.py +++ b/src/transformers/models/rembert/modeling_tf_rembert.py @@ -51,8 +51,8 @@ from ...modeling_tf_utils import ( get_initializer, input_processing, keras_serializable, - shape_list, ) +from ...tf_utils import shape_list from ...utils import logging from .configuration_rembert import RemBertConfig diff --git a/src/transformers/models/roberta/modeling_tf_roberta.py b/src/transformers/models/roberta/modeling_tf_roberta.py index 9aeb0a1eef..b74863fb20 100644 --- a/src/transformers/models/roberta/modeling_tf_roberta.py +++ b/src/transformers/models/roberta/modeling_tf_roberta.py @@ -52,8 +52,8 @@ from ...modeling_tf_utils import ( get_initializer, input_processing, keras_serializable, - shape_list, ) +from ...tf_utils import shape_list from ...utils import logging from .configuration_roberta import RobertaConfig diff --git a/src/transformers/models/roformer/modeling_tf_roformer.py b/src/transformers/models/roformer/modeling_tf_roformer.py index 57a40a2905..393114df01 100644 --- a/src/transformers/models/roformer/modeling_tf_roformer.py +++ 
b/src/transformers/models/roformer/modeling_tf_roformer.py @@ -51,8 +51,8 @@ from ...modeling_tf_utils import ( get_initializer, input_processing, keras_serializable, - shape_list, ) +from ...tf_utils import shape_list from ...utils import logging from .configuration_roformer import RoFormerConfig diff --git a/src/transformers/models/speech_to_text/modeling_tf_speech_to_text.py b/src/transformers/models/speech_to_text/modeling_tf_speech_to_text.py index 7c69684e06..0eba94521d 100755 --- a/src/transformers/models/speech_to_text/modeling_tf_speech_to_text.py +++ b/src/transformers/models/speech_to_text/modeling_tf_speech_to_text.py @@ -39,8 +39,8 @@ from ...modeling_tf_utils import ( TFSharedEmbeddings, input_processing, keras_serializable, - shape_list, ) +from ...tf_utils import shape_list from ...utils import logging from .configuration_speech_to_text import Speech2TextConfig diff --git a/src/transformers/models/t5/modeling_tf_t5.py b/src/transformers/models/t5/modeling_tf_t5.py index 5b030342ff..ca307df70e 100644 --- a/src/transformers/models/t5/modeling_tf_t5.py +++ b/src/transformers/models/t5/modeling_tf_t5.py @@ -44,8 +44,8 @@ from ...modeling_tf_utils import ( TFWrappedEmbeddings, input_processing, keras_serializable, - shape_list, ) +from ...tf_utils import shape_list from ...utils import logging from .configuration_t5 import T5Config diff --git a/src/transformers/models/tapas/modeling_tf_tapas.py b/src/transformers/models/tapas/modeling_tf_tapas.py index cdb7e8c113..46baba2627 100644 --- a/src/transformers/models/tapas/modeling_tf_tapas.py +++ b/src/transformers/models/tapas/modeling_tf_tapas.py @@ -45,8 +45,8 @@ from ...modeling_tf_utils import ( get_initializer, input_processing, keras_serializable, - shape_list, ) +from ...tf_utils import shape_list from ...utils import logging from .configuration_tapas import TapasConfig diff --git a/src/transformers/models/transfo_xl/modeling_tf_transfo_xl.py 
b/src/transformers/models/transfo_xl/modeling_tf_transfo_xl.py index ab8fb6f11b..f1e23f77ec 100644 --- a/src/transformers/models/transfo_xl/modeling_tf_transfo_xl.py +++ b/src/transformers/models/transfo_xl/modeling_tf_transfo_xl.py @@ -34,8 +34,8 @@ from ...modeling_tf_utils import ( get_initializer, input_processing, keras_serializable, - shape_list, ) +from ...tf_utils import shape_list from ...utils import logging from .configuration_transfo_xl import TransfoXLConfig from .modeling_tf_transfo_xl_utilities import TFAdaptiveSoftmaxMask diff --git a/src/transformers/models/transfo_xl/modeling_tf_transfo_xl_utilities.py b/src/transformers/models/transfo_xl/modeling_tf_transfo_xl_utilities.py index 53eb8239a5..af95f348ec 100644 --- a/src/transformers/models/transfo_xl/modeling_tf_transfo_xl_utilities.py +++ b/src/transformers/models/transfo_xl/modeling_tf_transfo_xl_utilities.py @@ -20,7 +20,7 @@ import tensorflow as tf -from ...modeling_tf_utils import shape_list +from ...tf_utils import shape_list class TFAdaptiveSoftmaxMask(tf.keras.layers.Layer): diff --git a/src/transformers/models/vision_encoder_decoder/modeling_tf_vision_encoder_decoder.py b/src/transformers/models/vision_encoder_decoder/modeling_tf_vision_encoder_decoder.py index 06bcbf7c4b..244c836b8c 100644 --- a/src/transformers/models/vision_encoder_decoder/modeling_tf_vision_encoder_decoder.py +++ b/src/transformers/models/vision_encoder_decoder/modeling_tf_vision_encoder_decoder.py @@ -30,13 +30,8 @@ from ...file_utils import ( replace_return_docstrings, ) from ...modeling_tf_outputs import TFBaseModelOutput, TFSeq2SeqLMOutput -from ...modeling_tf_utils import ( - TFCausalLanguageModelingLoss, - TFPreTrainedModel, - get_initializer, - input_processing, - shape_list, -) +from ...modeling_tf_utils import TFCausalLanguageModelingLoss, TFPreTrainedModel, get_initializer, input_processing +from ...tf_utils import shape_list from ...utils import logging from ..auto.configuration_auto import AutoConfig from 
..auto.modeling_tf_auto import TFAutoModel, TFAutoModelForCausalLM diff --git a/src/transformers/models/vit/modeling_tf_vit.py b/src/transformers/models/vit/modeling_tf_vit.py index b1e027c964..9a7025c662 100644 --- a/src/transformers/models/vit/modeling_tf_vit.py +++ b/src/transformers/models/vit/modeling_tf_vit.py @@ -32,8 +32,8 @@ from ...modeling_tf_utils import ( get_initializer, input_processing, keras_serializable, - shape_list, ) +from ...tf_utils import shape_list from ...utils import logging from .configuration_vit import ViTConfig diff --git a/src/transformers/models/wav2vec2/modeling_tf_wav2vec2.py b/src/transformers/models/wav2vec2/modeling_tf_wav2vec2.py index 6c079fcbf2..6ef3a3f98d 100644 --- a/src/transformers/models/wav2vec2/modeling_tf_wav2vec2.py +++ b/src/transformers/models/wav2vec2/modeling_tf_wav2vec2.py @@ -30,13 +30,8 @@ from ...file_utils import ( replace_return_docstrings, ) from ...modeling_tf_outputs import TFBaseModelOutput, TFCausalLMOutput -from ...modeling_tf_utils import ( - TFPreTrainedModel, - booleans_processing, - get_initializer, - keras_serializable, - shape_list, -) +from ...modeling_tf_utils import TFPreTrainedModel, booleans_processing, get_initializer, keras_serializable +from ...tf_utils import shape_list from ...tokenization_utils_base import BatchEncoding from ...utils import logging from .configuration_wav2vec2 import Wav2Vec2Config diff --git a/src/transformers/models/xlm/modeling_tf_xlm.py b/src/transformers/models/xlm/modeling_tf_xlm.py index 6d6ff088ec..1554fa3103 100644 --- a/src/transformers/models/xlm/modeling_tf_xlm.py +++ b/src/transformers/models/xlm/modeling_tf_xlm.py @@ -50,8 +50,8 @@ from ...modeling_tf_utils import ( get_initializer, input_processing, keras_serializable, - shape_list, ) +from ...tf_utils import shape_list from ...utils import logging from .configuration_xlm import XLMConfig diff --git a/src/transformers/models/xlnet/modeling_tf_xlnet.py 
b/src/transformers/models/xlnet/modeling_tf_xlnet.py index c31b82d786..ea0f6b6baf 100644 --- a/src/transformers/models/xlnet/modeling_tf_xlnet.py +++ b/src/transformers/models/xlnet/modeling_tf_xlnet.py @@ -44,8 +44,8 @@ from ...modeling_tf_utils import ( get_initializer, input_processing, keras_serializable, - shape_list, ) +from ...tf_utils import shape_list from ...utils import logging from .configuration_xlnet import XLNetConfig diff --git a/src/transformers/tf_utils.py b/src/transformers/tf_utils.py new file mode 100644 index 0000000000..42c744be7a --- /dev/null +++ b/src/transformers/tf_utils.py @@ -0,0 +1,51 @@ +# Copyright 2022 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import List, Union + +import numpy as np +import tensorflow as tf + +from .utils import logging + + +logger = logging.get_logger(__name__) + + +def set_tensor_by_indices_to_value(tensor: tf.Tensor, indices: tf.Tensor, value: Union[tf.Tensor, int, float]): + # create value_tensor since tensor value assignment is not possible in TF + return tf.where(indices, value, tensor) + + +def shape_list(tensor: Union[tf.Tensor, np.ndarray]) -> List[int]: + """ + Deal with dynamic shape in tensorflow cleanly. + + Args: + tensor (`tf.Tensor` or `np.ndarray`): The tensor we want the shape of. + + Returns: + `List[int]`: The shape of the tensor as a list. 
+ """ + if isinstance(tensor, np.ndarray): + return list(tensor.shape) + + dynamic = tf.shape(tensor) + + if tensor.shape == tf.TensorShape(None): + return dynamic + + static = tensor.shape.as_list() + + return [dynamic[i] if s is None else s for i, s in enumerate(static)] diff --git a/src/transformers/utils/dummy_tf_objects.py b/src/transformers/utils/dummy_tf_objects.py index 02b401ef39..6bba825a88 100644 --- a/src/transformers/utils/dummy_tf_objects.py +++ b/src/transformers/utils/dummy_tf_objects.py @@ -17,6 +17,48 @@ class TensorFlowBenchmark(metaclass=DummyObject): requires_backends(self, ["tf"]) +class TFLogitsProcessor(metaclass=DummyObject): + _backends = ["tf"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["tf"]) + + +class TFLogitsProcessorList(metaclass=DummyObject): + _backends = ["tf"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["tf"]) + + +class TFMinLengthLogitsProcessor(metaclass=DummyObject): + _backends = ["tf"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["tf"]) + + +class TFNoBadWordsLogitsProcessor(metaclass=DummyObject): + _backends = ["tf"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["tf"]) + + +class TFNoRepeatNGramLogitsProcessor(metaclass=DummyObject): + _backends = ["tf"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["tf"]) + + +class TFRepetitionPenaltyLogitsProcessor(metaclass=DummyObject): + _backends = ["tf"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["tf"]) + + def tf_top_k_top_p_filtering(*args, **kwargs): requires_backends(tf_top_k_top_p_filtering, ["tf"]) diff --git a/templates/adding_a_new_model/cookiecutter-template-{{cookiecutter.modelname}}/modeling_tf_{{cookiecutter.lowercase_modelname}}.py b/templates/adding_a_new_model/cookiecutter-template-{{cookiecutter.modelname}}/modeling_tf_{{cookiecutter.lowercase_modelname}}.py index 37b62d5772..3dbe073e68 100644 --- 
a/templates/adding_a_new_model/cookiecutter-template-{{cookiecutter.modelname}}/modeling_tf_{{cookiecutter.lowercase_modelname}}.py +++ b/templates/adding_a_new_model/cookiecutter-template-{{cookiecutter.modelname}}/modeling_tf_{{cookiecutter.lowercase_modelname}}.py @@ -53,8 +53,8 @@ from ...modeling_tf_utils import ( get_initializer, input_processing, keras_serializable, - shape_list, ) +from ...tf_utils import shape_list from ...utils import logging from .configuration_{{cookiecutter.lowercase_modelname}} import {{cookiecutter.camelcase_modelname}}Config @@ -1803,7 +1803,7 @@ from ...modeling_tf_utils import ( TFWrappedEmbeddings, input_processing, keras_serializable, - shape_list, +); from ...tf_utils import (shape_list, ) from ...utils import logging from .configuration_{{cookiecutter.lowercase_modelname}} import {{cookiecutter.camelcase_modelname}}Config diff --git a/tests/test_generation_tf_logits_process.py b/tests/test_generation_tf_logits_process.py new file mode 100644 index 0000000000..fb9eb086e4 --- /dev/null +++ b/tests/test_generation_tf_logits_process.py @@ -0,0 +1,172 @@ +# coding=utf-8 +# Copyright 2020 The HuggingFace Team Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License.
+ + +import unittest + +from transformers import is_tf_available +from transformers.testing_utils import require_tf + + +if is_tf_available(): + import tensorflow as tf + + from transformers.generation_tf_logits_process import ( + TFLogitsProcessorList, + TFMinLengthLogitsProcessor, + TFNoBadWordsLogitsProcessor, + TFNoRepeatNGramLogitsProcessor, + TFRepetitionPenaltyLogitsProcessor, + ) + from transformers.tf_utils import set_tensor_by_indices_to_value + + from .test_modeling_tf_common import ids_tensor + + +@require_tf +class TFLogitsProcessorTest(unittest.TestCase): + def _get_uniform_logits(self, batch_size: int, length: int): + scores = tf.ones((batch_size, length), dtype=tf.float32) / length + return scores + + def test_min_length_dist_processor(self): + vocab_size = 20 + batch_size = 4 + eos_token_id = 0 + + min_dist_processor = TFMinLengthLogitsProcessor(min_length=10, eos_token_id=eos_token_id) + + # check that min length is applied at length 5 + input_ids = ids_tensor((batch_size, 5), vocab_size=20) + scores = self._get_uniform_logits(batch_size, vocab_size) + scores_before_min_length = min_dist_processor(input_ids, scores) + self.assertListEqual(scores_before_min_length[:, eos_token_id].numpy().tolist(), 4 * [-float("inf")]) + + # check that min length is not applied anymore at length 15 + input_ids = ids_tensor((batch_size, 15), vocab_size=20) + scores = self._get_uniform_logits(batch_size, vocab_size) + scores_before_min_length = min_dist_processor(input_ids, scores) + self.assertFalse(tf.math.reduce_any(tf.math.is_inf(scores_before_min_length)).numpy()) + + def test_repetition_penalty_dist_process(self): + input_ids = tf.constant([[0, 1], [5, 0]], dtype=tf.int32) + vocab_size = 10 + + scores = self._get_uniform_logits(batch_size=2, length=vocab_size) + + mask = tf.cast(tf.constant([[1] + 9 * [0], 10 * [0]]), tf.bool) + scores = set_tensor_by_indices_to_value(scores, mask, -1 / vocab_size) + mask = tf.cast(tf.constant([10 * [0], 5 * [0] + [1] + 4 * 
[0]]), tf.bool) + scores = set_tensor_by_indices_to_value(scores, mask, 4 / vocab_size) + + rep_penalty_proc = TFRepetitionPenaltyLogitsProcessor(penalty=2.0) + + scores = rep_penalty_proc(input_ids, tf.identity(scores)) + + # check that values were correctly changed + self.assertAlmostEqual(scores[0, 0].numpy(), -(1 / vocab_size) * 2) + self.assertAlmostEqual(scores[0, 1].numpy(), (1 / vocab_size) / 2) + + self.assertAlmostEqual(scores[1, 0].numpy(), (1 / vocab_size) / 2) + self.assertAlmostEqual(scores[1, 5].numpy(), (4 / vocab_size) / 2) + + def test_no_repeat_ngram_dist_processor(self): + vocab_size = 3 + batch_size = 2 + + input_ids = tf.constant([[1, 1, 2, 1], [0, 1, 0, 1]], dtype=tf.int32) + scores = self._get_uniform_logits(batch_size, vocab_size) + + no_repeat_proc_2_gram = TFNoRepeatNGramLogitsProcessor(2) + no_repeat_proc_3_gram = TFNoRepeatNGramLogitsProcessor(3) + + filtered_scores_2_gram = no_repeat_proc_2_gram(input_ids, tf.identity(scores)) + filtered_scores_3_gram = no_repeat_proc_3_gram(input_ids, tf.identity(scores)) + + # 2-gram would forbid 2nd and 3rd token (1,2) at 1st batch and 1st token (0) at 2nd batch + self.assertListEqual( + tf.math.is_inf(filtered_scores_2_gram).numpy().tolist(), [[False, True, True], [True, False, False]] + ) + + # 3-gram would forbid no token at 1st batch and 1st token (0) at 2nd batch + self.assertListEqual( + tf.math.is_inf(filtered_scores_3_gram).numpy().tolist(), [[False, False, False], [True, False, False]] + ) + + def test_no_bad_words_dist_processor(self): + vocab_size = 5 + batch_size = 2 + eos_token_id = 4 + + input_ids = tf.constant([[0, 1, 3, 1], [0, 1, 0, 1]], dtype=tf.int32) + bad_word_tokens = [[1], [4], [1, 0], [0, 1, 2], [1, 3, 1, 3]] + scores = self._get_uniform_logits(batch_size, vocab_size) + + no_bad_words_dist_proc = TFNoBadWordsLogitsProcessor(bad_words_ids=bad_word_tokens, eos_token_id=eos_token_id) + + filtered_scores = no_bad_words_dist_proc(input_ids, tf.identity(scores)) + + # batch 1: 1st, 
2nd, and 4th (0, 1, 3) token are forbidden + # batch 2: 1st, 2nd, and 3rd (0, 1, 2) token are forbidden + self.assertListEqual( + tf.math.is_inf(filtered_scores).numpy().tolist(), + [[True, True, False, True, True], [True, True, True, False, True]], + ) + + def test_processor_list(self): + batch_size = 4 + sequence_length = 10 + vocab_size = 15 + eos_token_id = 0 + + # dummy input_ids and scores + input_ids = ids_tensor((batch_size, sequence_length), vocab_size) + input_ids_comp = tf.identity(input_ids) + + scores = self._get_uniform_logits(batch_size, vocab_size) + scores_comp = tf.identity(scores) + + # instantiate all dist processors + min_dist_proc = TFMinLengthLogitsProcessor(min_length=10, eos_token_id=eos_token_id) + rep_penalty_proc = TFRepetitionPenaltyLogitsProcessor(penalty=2.0) + no_repeat_proc = TFNoRepeatNGramLogitsProcessor(2) + no_bad_words_dist_proc = TFNoBadWordsLogitsProcessor(bad_words_ids=[[1]], eos_token_id=eos_token_id) + + # no processor list + scores = min_dist_proc(input_ids, scores) + scores = rep_penalty_proc(input_ids, scores) + scores = no_repeat_proc(input_ids, scores) + scores = no_bad_words_dist_proc(input_ids, scores) + + # with processor list + processor = TFLogitsProcessorList( + [ + min_dist_proc, + rep_penalty_proc, + no_repeat_proc, + no_bad_words_dist_proc, + ] + ) + scores_comp = processor(input_ids, scores_comp) + + # remove inf + scores = set_tensor_by_indices_to_value(scores, tf.math.is_inf(scores), -1e9) + scores_comp = set_tensor_by_indices_to_value(scores_comp, tf.math.is_inf(scores_comp), -1e9) + + # scores should be equal + tf.debugging.assert_near(scores, scores_comp, atol=1e-3) + + # input_ids should never be changed + self.assertListEqual(input_ids.numpy().tolist(), input_ids_comp.numpy().tolist()) diff --git a/tests/test_modeling_tf_common.py b/tests/test_modeling_tf_common.py index 95c953a6e3..f293d8126f 100644 --- a/tests/test_modeling_tf_common.py +++ b/tests/test_modeling_tf_common.py @@ -955,7 +955,7 @@ 
class TFModelTesterMixin: # Models with non-text inputs won't work here; num_return_sequences = 1 self._check_generated_ids(model.generate(do_sample=True, max_length=5)) - with self.assertRaises(AssertionError): + with self.assertRaises(ValueError): # generating multiple sequences when no beam search generation # is not allowed as it would always generate the same sequences model.generate(input_ids, do_sample=False, num_return_sequences=2) diff --git a/tests/test_modeling_tf_gpt2.py b/tests/test_modeling_tf_gpt2.py index d653329a5e..4f66ec89f4 100644 --- a/tests/test_modeling_tf_gpt2.py +++ b/tests/test_modeling_tf_gpt2.py @@ -26,14 +26,15 @@ from .test_modeling_tf_core import TFCoreModelTesterMixin if is_tf_available(): import tensorflow as tf + from transformers import GPT2Tokenizer from transformers.models.gpt2.modeling_tf_gpt2 import ( TF_GPT2_PRETRAINED_MODEL_ARCHIVE_LIST, TFGPT2DoubleHeadsModel, TFGPT2ForSequenceClassification, TFGPT2LMHeadModel, TFGPT2Model, - shape_list, ) + from transformers.tf_utils import shape_list class TFGPT2ModelTester: @@ -428,60 +429,53 @@ class TFGPT2ModelTest(TFModelTesterMixin, TFCoreModelTesterMixin, unittest.TestC @require_tf class TFGPT2ModelLanguageGenerationTest(unittest.TestCase): @slow - def test_lm_generate_gpt2(self): - model = TFGPT2LMHeadModel.from_pretrained("gpt2") - input_ids = tf.convert_to_tensor([[464, 3290]], dtype=tf.int32) # The dog - expected_output_ids = [ - 464, - 3290, - 373, - 1043, - 287, - 257, - 2214, - 1474, - 262, - 16246, - 286, - 2688, - 290, - 2688, - 27262, - 13, - 198, - 198, - 464, - 3290, - ] # The dog was found in a field near the intersection of West and West Streets.\n\nThe dog + def test_lm_generate_distilgpt2(self): + model = TFGPT2LMHeadModel.from_pretrained("distilgpt2") + input_ids = tf.convert_to_tensor([[464, 1893]], dtype=tf.int32) # The president + + # The president of the United States, and the president of the United Kingdom, have been in the White + # fmt: off + 
expected_output_ids = [464, 1893, 286, 262, 1578, 1829, 11, 290, 262, 1893, 286, 262, 1578, 7526, 11, 423, 587, 287, 262, 2635] + # fmt: on + output_ids = model.generate(input_ids, do_sample=False) self.assertListEqual(output_ids[0].numpy().tolist(), expected_output_ids) @slow - def test_lm_generate_distilgpt2(self): + def test_lm_generate_distilgpt2_batch_special(self): model = TFGPT2LMHeadModel.from_pretrained("distilgpt2") - input_ids = tf.convert_to_tensor([[464, 1893]], dtype=tf.int32) # The president - expected_output_ids = [ - 464, - 1893, - 286, - 262, - 1578, - 1829, - 11, - 290, - 262, - 1893, - 286, - 262, - 1578, - 7526, - 11, - 423, - 587, - 287, - 262, - 2635, - ] # The president of the United States, and the president of the United Kingdom, have been in the White + tokenizer = GPT2Tokenizer.from_pretrained("distilgpt2") + tokenizer.pad_token = tokenizer.eos_token + tokenizer.padding_side = "left" + + sentences = ["Today is a beautiful day and", "Yesterday was"] + input_ids = tokenizer(sentences, return_tensors="tf", padding=True).input_ids + + generation_kwargs = { + "bad_words_ids": [tokenizer("is").input_ids, tokenizer("angry about").input_ids], + "no_repeat_ngram_size": 2, + "do_sample": False, + "repetition_penalty": 1.3, + } + + output_ids = model.generate(input_ids, **generation_kwargs) + + output_strings = tokenizer.batch_decode(output_ids, skip_special_tokens=True) + expected_output_string = [ + "Today is a beautiful day and I am so happy to be able take part in this amazing event.", + "Yesterday was a very busy day for the first time since I started writing this post", + ] + self.assertListEqual(output_strings, expected_output_string) + + @slow + def test_lm_generate_gpt2(self): + model = TFGPT2LMHeadModel.from_pretrained("gpt2") + input_ids = tf.convert_to_tensor([[464, 3290]], dtype=tf.int32) # The dog + + # The dog was found in a field near the intersection of West and West Streets.\n\nThe dog + # fmt: off + expected_output_ids = [464, 
3290, 373, 1043, 287, 257, 2214, 1474, 262, 16246, 286, 2688, 290, 2688, 27262, 13, 198, 198, 464, 3290] + # fmt: on output_ids = model.generate(input_ids, do_sample=False) self.assertListEqual(output_ids[0].numpy().tolist(), expected_output_ids) diff --git a/tests/test_modeling_tf_longformer.py b/tests/test_modeling_tf_longformer.py index b88437a137..be96de22af 100644 --- a/tests/test_modeling_tf_longformer.py +++ b/tests/test_modeling_tf_longformer.py @@ -36,14 +36,7 @@ if is_tf_available(): TFLongformerModel, TFLongformerSelfAttention, ) - - def shape_list(x): - """ - copied from transformers.modeling_tf_utils - """ - static = x.shape.as_list() - dynamic = tf.shape(x) - return [dynamic[i] if s is None else s for i, s in enumerate(static)] + from transformers.tf_utils import shape_list class TFLongformerModelTester: diff --git a/tests/test_modeling_tf_speech_to_text.py b/tests/test_modeling_tf_speech_to_text.py index e34892bf12..6253ccf953 100644 --- a/tests/test_modeling_tf_speech_to_text.py +++ b/tests/test_modeling_tf_speech_to_text.py @@ -474,7 +474,7 @@ class TFSpeech2TextModelTest(TFModelTesterMixin, unittest.TestCase): # num_return_sequences = 1 self._check_generated_ids(model.generate(input_features, do_sample=True)) - with self.assertRaises(AssertionError): + with self.assertRaises(ValueError): # generating multiple sequences when no beam search generation # is not allowed as it would always generate the same sequences model.generate(input_features, do_sample=False, num_return_sequences=2) diff --git a/tests/test_modeling_tf_t5.py b/tests/test_modeling_tf_t5.py index 67e780f24c..9a5a1de199 100644 --- a/tests/test_modeling_tf_t5.py +++ b/tests/test_modeling_tf_t5.py @@ -453,6 +453,34 @@ class TFT5EncoderOnlyModelTest(TFModelTesterMixin, unittest.TestCase): pass +@require_tf +@require_sentencepiece +@require_tokenizers +class TFT5GenerationIntegrationTests(unittest.TestCase): + @slow + def test_greedy_generate(self): + model = 
TFT5ForConditionalGeneration.from_pretrained("t5-small") + tokenizer = T5Tokenizer.from_pretrained("t5-small") + + sentences = ["Yesterday, my name was", "Today is a beautiful day and"] + input_ids = tokenizer(sentences, return_tensors="tf", padding=True).input_ids + + generation_kwargs = { + "bad_words_ids": [tokenizer("my").input_ids, tokenizer("ein schöner").input_ids], + "no_repeat_ngram_size": 3, + "do_sample": False, + "repetition_penalty": 2.2, + } + + output_ids = model.generate(input_ids, **generation_kwargs) + + output_strings = tokenizer.batch_decode(output_ids, skip_special_tokens=True) + + expected_output_string = ["Yesterday, my name was", "Heute ist ein schöne Tag und"] + + self.assertListEqual(expected_output_string, output_strings) + + @require_tf @require_sentencepiece @require_tokenizers