From 4134100363e878693aa41f4a25a667ca46d80a9e Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Tue, 3 Mar 2020 15:42:15 +0100 Subject: [PATCH] Add generate() functionality to TF 2.0 (#3063) * add first copy past test to tf 2 generate * add tf top_k_top_p_filter fn * add generate function for TF * add generate function for TF * implemented generate for all models expect transfoXL * implemented generate for all models expect transfoXL * implemented generate for all models expect transfoXL * make style * change permission of test file to correct ones * delete ipdb * delete ipdb * fix bug and finish simple gpt2 integration test * clean test file * clean test file * make style * make style * make style * make style * change import style * change import style * make style * make style * add decorators * add decorators * fix tf ctrl bug dim => axis in TF * make style * make style * refactored test file * refactored test file * take out test_torch_tf_conversion if nothing is defined * take out test_torch_tf_conversion if nothing is defined * remove useless files * remove useless files * fix conflicts * fix conflicts * fix conflicts * fix conflicts * fix conflicts * solve conflicts * solve conflicts * fix conflicts * fix conflicts * merge conflicts * delete ipdb * exposed top_k_top_p_filtering fns * delete weirdly created w! file * add comment to test tf common modeling * fix conflicts * fix conflicts * make style * merge conflicts * make style * change tf.tensor.shape to shape_list(tensor) --- src/transformers/__init__.py | 10 +- src/transformers/modeling_ctrl.py | 8 +- src/transformers/modeling_gpt2.py | 8 +- src/transformers/modeling_tf_ctrl.py | 11 +- src/transformers/modeling_tf_gpt2.py | 7 + src/transformers/modeling_tf_transfo_xl.py | 9 + src/transformers/modeling_tf_utils.py | 426 +++++++++++++++++++++ src/transformers/modeling_tf_xlm.py | 14 + src/transformers/modeling_tf_xlnet.py | 26 ++ src/transformers/modeling_transfo_xl.py | 6 +- src/transformers/modeling_xlnet.py | 6 +- tests/test_modeling_common.py | 136 ++++++- tests/test_modeling_gpt2.py | 40 +- tests/test_modeling_tf_common.py | 143 ++++++- tests/test_modeling_tf_ctrl.py | 1 + tests/test_modeling_tf_gpt2.py | 77 +++- tests/test_modeling_tf_openai_gpt.py | 3 + tests/test_modeling_tf_transfo_xl.py | 5 + tests/test_modeling_tf_xlm.py | 6 + tests/test_modeling_tf_xlnet.py | 12 + 20 files changed, 892 insertions(+), 62 deletions(-) diff --git a/src/transformers/__init__.py b/src/transformers/__init__.py index 1959b254d3..b338ad8515 100755 --- a/src/transformers/__init__.py +++ b/src/transformers/__init__.py @@ -136,7 +136,7 @@ if is_sklearn_available(): # Modeling if is_torch_available(): - from .modeling_utils import PreTrainedModel, prune_layer, Conv1D + from .modeling_utils import PreTrainedModel, prune_layer, Conv1D, top_k_top_p_filtering from .modeling_auto import ( AutoModel, AutoModelForPreTraining, @@ -291,7 +291,13 @@ if is_torch_available(): # TensorFlow if is_tf_available(): - from .modeling_tf_utils import TFPreTrainedModel, TFSharedEmbeddings, TFSequenceSummary, shape_list + from .modeling_tf_utils import ( + TFPreTrainedModel, + TFSharedEmbeddings, + TFSequenceSummary, + shape_list, + tf_top_k_top_p_filtering, + ) from .modeling_tf_auto import ( TFAutoModel, TFAutoModelForPreTraining, diff --git a/src/transformers/modeling_ctrl.py b/src/transformers/modeling_ctrl.py index 40e076a498..f9c6202861 100644 --- a/src/transformers/modeling_ctrl.py +++ b/src/transformers/modeling_ctrl.py @@ -454,14 +454,12 @@ class CTRLLMHeadModel(CTRLPreTrainedModel): def get_output_embeddings(self): return self.lm_head - def prepare_inputs_for_generation(self, input_ids, **kwargs): + def prepare_inputs_for_generation(self, input_ids, past, **kwargs): # only last token for inputs_ids if past is defined in kwargs - if "past" in kwargs and kwargs["past"]: + if past: input_ids = input_ids[:, -1].unsqueeze(-1) - inputs = {"input_ids": input_ids} - inputs.update(kwargs) - return inputs + return {"input_ids": input_ids, "past": past} @add_start_docstrings_to_callable(CTRL_INPUTS_DOCSTRING) def forward( diff --git a/src/transformers/modeling_gpt2.py b/src/transformers/modeling_gpt2.py index 479f459d2c..b492d7fc37 100644 --- a/src/transformers/modeling_gpt2.py +++ b/src/transformers/modeling_gpt2.py @@ -525,14 +525,12 @@ class GPT2LMHeadModel(GPT2PreTrainedModel): def get_output_embeddings(self): return self.lm_head - def prepare_inputs_for_generation(self, input_ids, **kwargs): + def prepare_inputs_for_generation(self, input_ids, past, **kwargs): # only last token for inputs_ids if past is defined in kwargs - if "past" in kwargs and kwargs["past"]: + if past: input_ids = input_ids[:, -1].unsqueeze(-1) - inputs = {"input_ids": input_ids} - inputs.update(kwargs) - return inputs + return {"input_ids": input_ids, "past": past} @add_start_docstrings_to_callable(GPT2_INPUTS_DOCSTRING) def forward( diff --git a/src/transformers/modeling_tf_ctrl.py b/src/transformers/modeling_tf_ctrl.py index 78e0c1113a..335421979c 100644 --- a/src/transformers/modeling_tf_ctrl.py +++ b/src/transformers/modeling_tf_ctrl.py @@ -105,8 +105,8 @@ class TFMultiHeadAttention(tf.keras.layers.Layer): v = self.split_into_heads(v, batch_size) if layer_past is not None: past_key, past_value = tf.unstack(layer_past, axis=1) - k = tf.concat((past_key, k), dim=-2) - v = tf.concat((past_value, v), dim=-2) + k = tf.concat((past_key, k), axis=-2) + v = tf.concat((past_value, v), axis=-2) present = tf.stack((k, v), axis=1) output = scaled_dot_product_attention(q, k, v, mask, attention_mask, head_mask) @@ -505,6 +505,13 @@ class TFCTRLLMHeadModel(TFCTRLPreTrainedModel): def get_output_embeddings(self): return self.lm_head.input_embeddings + def prepare_inputs_for_generation(self, inputs, past, **kwargs): + # only last token for inputs_ids if past is defined in kwargs + if past: + inputs = tf.expand_dims(inputs[:, -1], -1) + + return {"inputs": inputs, "past": past} + @add_start_docstrings_to_callable(CTRL_INPUTS_DOCSTRING) def call(self, inputs, **kwargs): r""" diff --git a/src/transformers/modeling_tf_gpt2.py b/src/transformers/modeling_tf_gpt2.py index 96a064a332..7e9b102b6d 100644 --- a/src/transformers/modeling_tf_gpt2.py +++ b/src/transformers/modeling_tf_gpt2.py @@ -500,6 +500,13 @@ class TFGPT2LMHeadModel(TFGPT2PreTrainedModel): def get_output_embeddings(self): return self.transformer.wte + def prepare_inputs_for_generation(self, inputs, past, **kwargs): + # only last token for inputs_ids if past is defined in kwargs + if past: + inputs = tf.expand_dims(inputs[:, -1], -1) + + return {"inputs": inputs, "past": past} + @add_start_docstrings_to_callable(GPT2_INPUTS_DOCSTRING) def call(self, inputs, **kwargs): r""" diff --git a/src/transformers/modeling_tf_transfo_xl.py b/src/transformers/modeling_tf_transfo_xl.py index 659685388e..098a4c9143 100644 --- a/src/transformers/modeling_tf_transfo_xl.py +++ b/src/transformers/modeling_tf_transfo_xl.py @@ -826,3 +826,12 @@ class TFTransfoXLLMHeadModel(TFTransfoXLPreTrainedModel): outputs = [softmax_output] + outputs return outputs # logits, new_mems, (all hidden states), (all attentions) + + def prepare_inputs_for_generation(self, inputs, past, **model_kwargs): + inputs = {"inputs": inputs} + + # if past is defined in model kwargs then use it for faster decoding + if past: + inputs["mems"] = past + + return inputs diff --git a/src/transformers/modeling_tf_utils.py b/src/transformers/modeling_tf_utils.py index 452d377cd5..43abdd9499 100644 --- a/src/transformers/modeling_tf_utils.py +++ b/src/transformers/modeling_tf_utils.py @@ -384,6 +384,432 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin): return model + def prepare_inputs_for_generation(self, inputs, **kwargs): + return {"inputs": inputs} + + def _do_output_past(self, outputs): + has_output_past = hasattr(self.config, "output_past") and self.config.output_past + has_mem_len = hasattr(self.config, "mem_len") and self.config.mem_len + + if has_output_past and not has_mem_len and len(outputs) > 1: + return True + elif has_mem_len and self.config.mem_len > 0 and len(outputs) > 1: + return True + + return False + + def generate( + self, + input_ids=None, + max_length=None, + do_sample=True, + num_beams=None, + temperature=None, + top_k=None, + top_p=None, + repetition_penalty=None, + bos_token_id=None, + pad_token_id=None, + eos_token_ids=None, + length_penalty=None, + num_return_sequences=None, + ): + r""" Generates sequences for models with a LM head. The method currently supports greedy or penalized greedy decoding, sampling with top-k or nucleus sampling + and beam-search. + + Adapted in part from `Facebook's XLM beam search code`_. + + .. _`Facebook's XLM beam search code`: + https://github.com/facebookresearch/XLM/blob/9e6f6814d17be4fe5b15f2e6c43eb2b2d76daeb4/src/model/transformer.py#L529 + + + Parameters: + + input_ids: (`optional`) `torch.LongTensor` of shape `(batch_size, sequence_length)` + The sequence used as a prompt for the generation. If `None` the method initializes + it as an empty `torch.LongTensor` of shape `(1,)`. + + max_length: (`optional`) int + The max length of the sequence to be generated. Between 1 and infinity. Default to 20. + + do_sample: (`optional`) bool + If set to `False` greedy decoding is used. Otherwise sampling is used. Defaults to `True`. + + num_beams: (`optional`) int + Number of beams for beam search. Must be between 1 and infinity. 1 means no beam search. Default to 1. + + temperature: (`optional`) float + The value used to module the next token probabilities. Must be strictely positive. Default to 1.0. + + top_k: (`optional`) int + The number of highest probability vocabulary tokens to keep for top-k-filtering. Between 1 and infinity. Default to 50. + + top_p: (`optional`) float + The cumulative probability of parameter highest probability vocabulary tokens to keep for nucleus sampling. Must be between 0 and 1. Default to 1. + + repetition_penalty: (`optional`) float + The parameter for repetition penalty. Between 1.0 and infinity. 1.0 means no penalty. Default to 1.0. + + bos_token_id: (`optional`) int + Beginning of sentence token if no prompt is provided. Default to 0. + + eos_token_ids: (`optional`) int or list of int + End of sequence token or list of tokens to stop the generation. Default to 0. + length_penalty: (`optional`) float + Exponential penalty to the length. Default to 1. + + num_return_sequences: (`optional`) int + The number of independently computed returned sequences for each element in the batch. Default to 1. + + Return: + + output: `torch.LongTensor` of shape `(batch_size * num_return_sequences, sequence_length)` + sequence_length is either equal to max_length or shorter if all batches finished early due to the `eos_token_id` + + Examples:: + + tokenizer = AutoTokenizer.from_pretrained('distilgpt2') # Initialize tokenizer + model = AutoModelWithLMHead.from_pretrained('distilgpt2') # Download model and configuration from S3 and cache. + outputs = model.generate(max_length=40, bos_token_id=tokenizer.bos_token_id, eos_token_ids=tokenizer.eos_token_id, do_sample=False) # do greedy decoding + print('Generated: {}'.format(tokenizer.decode(outputs[0], skip_special_tokens=True))) + + tokenizer = AutoTokenizer.from_pretrained('openai-gpt') # Initialize tokenizer + model = AutoModelWithLMHead.from_pretrained('openai-gpt') # Download model and configuration from S3 and cache. + input_context = 'The dog' + input_ids = torch.tensor(tokenizer.encode(input_context)).unsqueeze(0) # encode input context + outputs = model.generate(input_ids=input_ids, num_beams=5, num_return_sequences=3, temperature=1.5) # generate 3 independent sequences using beam search decoding (5 beams) with sampling from initial context 'The dog' + for i in range(3): # 3 output sequences were generated + print('Generated {}: {}'.format(i, tokenizer.decode(outputs[i], skip_special_tokens=True))) + + tokenizer = AutoTokenizer.from_pretrained('distilgpt2') # Initialize tokenizer + model = AutoModelWithLMHead.from_pretrained('distilgpt2') # Download model and configuration from S3 and cache. + input_context = 'The dog' + input_ids = torch.tensor(tokenizer.encode(input_context)).unsqueeze(0) # encode input context + outputs = model.generate(input_ids=input_ids, max_length=40, temperature=0.7, bos_token_id=tokenizer.bos_token_id, pad_token_id=tokenizer.pad_token_id, eos_token_ids=tokenizer.eos_token_id, num_return_sequences=3) # 3 generate sequences using by sampling + for i in range(3): # 3 output sequences were generated + print('Generated {}: {}'.format(i, tokenizer.decode(outputs[i], skip_special_tokens=True))) + + tokenizer = AutoTokenizer.from_pretrained('ctrl') # Initialize tokenizer + model = AutoModelWithLMHead.from_pretrained('ctrl') # Download model and configuration from S3 and cache. + input_context = 'Legal My neighbor is' # "Legal" is one of the control codes for ctrl + input_ids = torch.tensor(tokenizer.encode(input_context)).unsqueeze(0) # encode input context + outputs = model.generate(input_ids=input_ids, max_length=50, temperature=0.7, repetition_penalty=1.2) # generate sequences + print('Generated: {}'.format(tokenizer.decode(outputs[0], skip_special_tokens=True))) + + """ + + # We cannot generate if the model does not have a LM head + if self.get_output_embeddings() is None: + raise AttributeError( + "You tried to generate sequences with a model that does not have a LM Head." + "Please use another model class (e.g. `OpenAIGPTLMHeadModel`, `XLNetLMHeadModel`, `GPT2LMHeadModel`, `CTRLLMHeadModel`, `T5WithLMHeadModel`, `TransfoXLLMHeadModel`)" + ) + + max_length = max_length if max_length is not None else self.config.max_length + do_sample = do_sample if do_sample is not None else self.config.do_sample + num_beams = num_beams if num_beams is not None else self.config.num_beams + temperature = temperature if temperature is not None else self.config.temperature + top_k = top_k if top_k is not None else self.config.top_k + top_p = top_p if top_p is not None else self.config.top_p + repetition_penalty = repetition_penalty if repetition_penalty is not None else self.config.repetition_penalty + bos_token_id = bos_token_id if bos_token_id is not None else self.config.bos_token_id + pad_token_id = pad_token_id if pad_token_id is not None else self.config.pad_token_id + eos_token_ids = eos_token_ids if eos_token_ids is not None else self.config.eos_token_ids + length_penalty = length_penalty if length_penalty is not None else self.config.length_penalty + num_return_sequences = ( + num_return_sequences if num_return_sequences is not None else self.config.num_return_sequences + ) + + if input_ids is not None: + batch_size = shape_list(input_ids)[0] # overriden by the input batch_size + else: + batch_size = 1 + if isinstance(eos_token_ids, int): + eos_token_ids = [eos_token_ids] + + assert isinstance(max_length, int) and max_length > 0, "`max_length` should be a strictely positive integer." + assert isinstance(do_sample, bool), "`do_sample` should be a boolean." + assert isinstance(num_beams, int) and num_beams > 0, "`num_beams` should be a strictely positive integer." + assert temperature > 0, "`temperature` should be strictely positive." + assert isinstance(top_k, int) and top_k >= 0, "`top_k` should be a positive integer." + assert 0 <= top_p <= 1, "`top_p` should be between 0 and 1." + assert repetition_penalty >= 1.0, "`repetition_penalty` should be >= 1." + assert input_ids is not None or ( + isinstance(bos_token_id, int) and bos_token_id >= 0 + ), "If input_ids is not defined, `bos_token_id` should be a positive integer." + assert pad_token_id is None or ( + isinstance(pad_token_id, int) and (pad_token_id >= 0) + ), "`pad_token_id` should be a positive integer." + assert (eos_token_ids is None) or ( + isinstance(eos_token_ids, (list, tuple)) and ((isinstance(e, int) and e >= 0) for e in eos_token_ids) + ), "`eos_token_ids` should be a positive integer or a list/tuple of positive integers." + assert length_penalty > 0, "`length_penalty` should be strictely positive." + assert ( + isinstance(num_return_sequences, int) and num_return_sequences > 0 + ), "`num_return_sequences` should be a strictely positive integer." + + if input_ids is None: + assert isinstance(bos_token_id, int) and bos_token_id >= 0, ( + "you should either supply a context to complete as `input_ids` input " + "or a `bos_token_id` (integer >= 0) as a first token to start the generation." + ) + input_ids = tf.fill((batch_size, 1), bos_token_id) + else: + assert len(shape_list(input_ids)) == 2, "Input prompt should be of shape (batch_size, sequence length)." + + if pad_token_id is None and eos_token_ids is not None: + logger.warning( + "Setting `pad_token_id` to {} (first `eos_token_id`) to generate sequence".format(eos_token_ids[0]) + ) + pad_token_id = eos_token_ids[0] + + # current position and vocab size + cur_len = shape_list(input_ids)[1] + vocab_size = self.config.vocab_size + + if num_return_sequences != 1: + # Expand input to num return sequences + input_ids = tf.broadcast_to(tf.expand_dims(input_ids, 1), (batch_size, num_return_sequences, cur_len)) + effective_batch_size = batch_size * num_return_sequences + input_ids = tf.reshape(input_ids, (effective_batch_size, cur_len)) + else: + effective_batch_size = batch_size + + if num_beams > 1: + output = self._generate_beam_search( + input_ids, + cur_len, + max_length, + do_sample, + temperature, + top_k, + top_p, + repetition_penalty, + pad_token_id, + eos_token_ids, + effective_batch_size, + length_penalty, + num_beams, + vocab_size, + ) + else: + output = self._generate_no_beam_search( + input_ids, + cur_len, + max_length, + do_sample, + temperature, + top_k, + top_p, + repetition_penalty, + pad_token_id, + eos_token_ids, + effective_batch_size, + ) + + return output + + def _generate_no_beam_search( + self, + input_ids, + cur_len, + max_length, + do_sample, + temperature, + top_k, + top_p, + repetition_penalty, + pad_token_id, + eos_token_ids, + batch_size, + ): + """ Generate sequences for each example without beam search (num_beams == 1). + All returned sequence are generated independantly. + """ + + def _create_next_token_logits_penalties(input_ids, logits): + # create logit penalties for already seen input_ids + token_penalties = np.ones(shape_list(logits)) + prev_input_ids = [np.unique(input_id) for input_id in input_ids.numpy()] + for i, prev_input_id in enumerate(prev_input_ids): + logit_penalized = logits[i].numpy()[prev_input_id] + # if previous logit score is < 0 then multiply repetition penalty else divide + logit_penalized[logit_penalized < 0] = repetition_penalty + logit_penalized[logit_penalized > 0] = 1 / repetition_penalty + np.put(token_penalties[i], prev_input_id, logit_penalized) + return tf.convert_to_tensor(token_penalties, dtype=tf.float32) + + # current position / max lengths / length of generated sentences / unfinished sentences + unfinished_sents = tf.ones_like(input_ids[:, 0]) + sent_lengths = tf.ones_like(input_ids[:, 0]) * max_length + + past = None + + while cur_len < max_length: + model_inputs = self.prepare_inputs_for_generation(input_ids, past=past) + outputs = self(**model_inputs) + next_token_logits = outputs[0][:, -1, :] + + # if model has past, then set the past variable to speed up decoding + if self._do_output_past(outputs): + past = outputs[1] + + # repetition penalty from CTRL paper (https://arxiv.org/abs/1909.05858) + if repetition_penalty != 1.0: + next_token_logits_penalties = _create_next_token_logits_penalties(input_ids, next_token_logits) + next_token_logits = tf.math.multiply(next_token_logits, next_token_logits_penalties) + + if do_sample: + # Temperature (higher temperature => more likely to sample low probability tokens) + if temperature != 1.0: + next_token_logits = next_token_logits / temperature + # Top-p/top-k filtering + next_token_logits = tf_top_k_top_p_filtering(next_token_logits, top_k=top_k, top_p=top_p) + # Sample + next_token = tf.squeeze( + tf.random.categorical(next_token_logits, dtype=tf.int32, num_samples=1), axis=1 + ) + else: + # Greedy decoding + next_token = tf.math.argmax(next_token_logits, axis=-1, output_type=tf.int32) + + # update generations and finished sentences + if eos_token_ids is not None: + # pad finished sentences if eos_token_ids exist + tokens_to_add = next_token * unfinished_sents + (pad_token_id) * (1 - unfinished_sents) + else: + tokens_to_add = next_token + + input_ids = tf.concat([input_ids, tf.expand_dims(tokens_to_add, -1)], 1) + + if eos_token_ids is not None: + for eos_token_id in eos_token_ids: + eos_in_sents = tokens_to_add == eos_token_id + # if sentence is unfinished and the token to add is eos, sent_lengths is filled with current length + is_sents_unfinished_and_token_to_add_is_eos = tf.math.multiply( + unfinished_sents, tf.cast(eos_in_sents, tf.int32) + ) + sent_lengths = ( + sent_lengths * (1 - is_sents_unfinished_and_token_to_add_is_eos) + + cur_len * is_sents_unfinished_and_token_to_add_is_eos + ) + + # unfinished_sents is set to zero if eos in sentence + unfinished_sents -= is_sents_unfinished_and_token_to_add_is_eos + + cur_len = cur_len + 1 + + # stop when there is a in each sentence, or if we exceed the maximul length + if tf.math.reduce_max(unfinished_sents) == 0: + break + + # if there are different sentences lengths in the batch, some batches have to be padded + min_sent_length = tf.math.reduce_min(sent_lengths) + max_sent_length = tf.math.reduce_max(sent_lengths) + if min_sent_length != max_sent_length: + assert pad_token_id is not None, "`Pad_token_id` has to be defined if batches have different lengths" + # finished sents are filled with pad_token + padding = tf.ones([batch_size, max_sent_length.numpy()], dtype=tf.int32) * pad_token_id + + # create length masks for tf.where operation + broad_casted_sent_lengths = tf.broadcast_to( + tf.expand_dims(sent_lengths, -1), [batch_size, max_sent_length] + ) + broad_casted_range = tf.transpose( + tf.broadcast_to(tf.expand_dims(tf.range(max_length), -1), [max_length, batch_size]) + ) + + decoded = tf.where(broad_casted_range < broad_casted_sent_lengths, input_ids, padding) + else: + decoded = input_ids + + return decoded + + def _generate_beam_search( + self, + input_ids, + cur_len, + max_length, + do_sample, + temperature, + top_k, + top_p, + repetition_penalty, + pad_token_id, + eos_token_ids, + batch_size, + length_penalty, + num_beams, + vocab_size, + ): + pass + + +def tf_top_k_top_p_filtering(logits, top_k=0, top_p=1.0, filter_value=-float("Inf"), min_tokens_to_keep=1): + """ Filter a distribution of logits using top-k and/or nucleus (top-p) filtering + Args: + logits: logits distribution shape (batch size, vocabulary size) + if top_k > 0: keep only top k tokens with highest probability (top-k filtering). + if top_p < 1.0: keep the top tokens with cumulative probability >= top_p (nucleus filtering). + Nucleus filtering is described in Holtzman et al. (http://arxiv.org/abs/1904.09751) + Make sure we keep at least min_tokens_to_keep per batch example in the output + From: https://gist.github.com/thomwolf/1a5a29f6962089e871b94cbd09daf317 + """ + logits_shape = shape_list(logits) + + if top_k > 0: + top_k = min(max(top_k, min_tokens_to_keep), logits_shape[-1]) # Safety check + # Remove all tokens with a probability less than the last token of the top-k + indices_to_remove = logits < tf.math.top_k(logits, k=top_k)[0][..., -1, None] + logits = set_tensor_by_indices_to_value(logits, indices_to_remove, filter_value) + + if top_p < 1.0: + sorted_indices = tf.argsort(logits, direction="DESCENDING") + sorted_logits = tf.gather( + logits, sorted_indices, axis=-1, batch_dims=1 + ) # expects logits to be of dim (batch_size, vocab_size) + + cumulative_probs = tf.math.cumsum(tf.nn.softmax(sorted_logits, axis=-1), axis=-1) + + # Remove tokens with cumulative probability above the threshold (token with 0 are kept) + sorted_indices_to_remove = cumulative_probs > top_p + + if min_tokens_to_keep > 1: + # Keep at least min_tokens_to_keep (set to min_tokens_to_keep-1 because we add the first one below) + sorted_indices_to_remove = tf.concat( + [ + tf.zeros_like(sorted_indices_to_remove[:, :min_tokens_to_keep]), + sorted_indices_to_remove[:, min_tokens_to_keep:], + ], + -1, + ) + + # Shift the indices to the right to keep also the first token above the threshold + sorted_indices_to_remove = tf.roll(sorted_indices_to_remove, 1, axis=-1) + sorted_indices_to_remove = tf.concat( + [tf.zeros_like(sorted_indices_to_remove[:, :1]), sorted_indices_to_remove[:, 1:]], -1, + ) + # scatter sorted tensors to original indexing + indices_to_remove = scatter_values_on_batch_indices(sorted_indices_to_remove, sorted_indices) + logits = set_tensor_by_indices_to_value(logits, indices_to_remove, filter_value) + return logits + + +def scatter_values_on_batch_indices(values, batch_indices): + shape = shape_list(batch_indices) + # broadcast batch dim to shape + broad_casted_batch_dims = tf.reshape(tf.broadcast_to(tf.expand_dims(tf.range(shape[0]), axis=-1), shape), [1, -1]) + # transform batch_indices to pair_indices + pair_indices = tf.transpose(tf.concat([broad_casted_batch_dims, tf.reshape(batch_indices, [1, -1])], 0)) + # scatter values to pair indices + return tf.scatter_nd(pair_indices, tf.reshape(values, [-1]), shape) + + +def set_tensor_by_indices_to_value(tensor, indices, value): + # create value_tensor since tensor value assignment is not possible in TF + value_tensor = tf.zeros_like(tensor) + value + return tf.where(indices, value_tensor, tensor) + class TFConv1D(tf.keras.layers.Layer): def __init__(self, nf, nx, initializer_range=0.02, **kwargs): diff --git a/src/transformers/modeling_tf_xlm.py b/src/transformers/modeling_tf_xlm.py index 44b991d08c..6e94a7206e 100644 --- a/src/transformers/modeling_tf_xlm.py +++ b/src/transformers/modeling_tf_xlm.py @@ -657,6 +657,20 @@ class TFXLMWithLMHeadModel(TFXLMPreTrainedModel): def get_output_embeddings(self): return self.pred_layer.input_embeddings + def prepare_inputs_for_generation(self, inputs, **kwargs): + mask_token_id = self.config.mask_token_id + lang_id = self.config.lang_id + + effective_batch_size = inputs.shape[0] + mask_token = tf.ones((effective_batch_size, 1), dtype=tf.int32) * mask_token_id + inputs = tf.concat([inputs, mask_token], axis=1) + + if lang_id is not None: + langs = tf.ones_like(inputs) * lang_id + else: + langs = None + return {"inputs": inputs, "langs": langs} + @add_start_docstrings_to_callable(XLM_INPUTS_DOCSTRING) def call(self, inputs, **kwargs): r""" diff --git a/src/transformers/modeling_tf_xlnet.py b/src/transformers/modeling_tf_xlnet.py index d9ced75384..87ebe16858 100644 --- a/src/transformers/modeling_tf_xlnet.py +++ b/src/transformers/modeling_tf_xlnet.py @@ -837,6 +837,32 @@ class TFXLNetLMHeadModel(TFXLNetPreTrainedModel): def get_output_embeddings(self): return self.lm_loss.input_embeddings + def prepare_inputs_for_generation(self, inputs, past, **model_kwargs): + # Add dummy token at the end (no attention on this one) + + effective_batch_size = inputs.shape[0] + dummy_token = tf.zeros((effective_batch_size, 1), dtype=tf.int32) + inputs = tf.concat([inputs, dummy_token], axis=1) + + # Build permutation mask so that previous tokens don't see last token + sequence_length = inputs.shape[1] + perm_mask = tf.zeros((effective_batch_size, sequence_length, sequence_length - 1), dtype=tf.float32) + perm_mask_seq_end = tf.ones((effective_batch_size, sequence_length, 1), dtype=tf.float32) + perm_mask = tf.concat([perm_mask, perm_mask_seq_end], axis=-1) + + # We'll only predict the last token + target_mapping = tf.zeros((effective_batch_size, 1, sequence_length - 1), dtype=tf.float32) + target_mapping_seq_end = tf.ones((effective_batch_size, 1, 1), dtype=tf.float32) + target_mapping = tf.concat([target_mapping, target_mapping_seq_end], axis=-1) + + inputs = {"inputs": inputs, "perm_mask": perm_mask, "target_mapping": target_mapping} + + # if past is defined in model kwargs then use it for faster decoding + if past: + inputs["mems"] = past + + return inputs + @add_start_docstrings_to_callable(XLNET_INPUTS_DOCSTRING) def call(self, inputs, **kwargs): r""" diff --git a/src/transformers/modeling_transfo_xl.py b/src/transformers/modeling_transfo_xl.py index 3d95d6e70f..379b650bea 100644 --- a/src/transformers/modeling_transfo_xl.py +++ b/src/transformers/modeling_transfo_xl.py @@ -935,11 +935,11 @@ class TransfoXLLMHeadModel(TransfoXLPreTrainedModel): else: return self.crit.out_layers[-1] - def prepare_inputs_for_generation(self, input_ids, **model_kwargs): + def prepare_inputs_for_generation(self, input_ids, past, **model_kwargs): inputs = {"input_ids": input_ids} # if past is defined in model kwargs then use it for faster decoding - if "past" in model_kwargs and model_kwargs["past"]: - inputs["mems"] = model_kwargs["past"] + if past: + inputs["mems"] = past return inputs diff --git a/src/transformers/modeling_xlnet.py b/src/transformers/modeling_xlnet.py index 531b0f9a4c..7d34e7ef2b 100644 --- a/src/transformers/modeling_xlnet.py +++ b/src/transformers/modeling_xlnet.py @@ -935,7 +935,7 @@ class XLNetLMHeadModel(XLNetPreTrainedModel): def get_output_embeddings(self): return self.lm_loss - def prepare_inputs_for_generation(self, input_ids, **model_kwargs): + def prepare_inputs_for_generation(self, input_ids, past, **model_kwargs): # Add dummy token at the end (no attention on this one) effective_batch_size = input_ids.shape[0] @@ -958,8 +958,8 @@ class XLNetLMHeadModel(XLNetPreTrainedModel): inputs = {"input_ids": input_ids, "perm_mask": perm_mask, "target_mapping": target_mapping} # if past is defined in model kwargs then use it for faster decoding - if "past" in model_kwargs and model_kwargs["past"]: - inputs["mems"] = model_kwargs["past"] + if past: + inputs["mems"] = past return inputs diff --git a/tests/test_modeling_common.py b/tests/test_modeling_common.py index 5277864eca..9ba00d2421 100644 --- a/tests/test_modeling_common.py +++ b/tests/test_modeling_common.py @@ -36,6 +36,7 @@ if is_torch_available(): BertModel, BertConfig, BERT_PRETRAINED_MODEL_ARCHIVE_MAP, + top_k_top_p_filtering, ) @@ -263,7 +264,7 @@ class ModelTesterMixin: # Prepare head_mask # Set require_grad after having prepared the tensor to avoid error (leaf variable has been moved into the graph interior) head_mask = torch.ones( - self.model_tester.num_hidden_layers, self.model_tester.num_attention_heads, device=torch_device + self.model_tester.num_hidden_layers, self.model_tester.num_attention_heads, device=torch_device, ) head_mask[0, 0] = 0 head_mask[-1, :-1] = 0 @@ -303,7 +304,7 @@ class ModelTesterMixin: return for model_class in self.all_model_classes: - config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + (config, inputs_dict,) = self.model_tester.prepare_config_and_inputs_for_common() if "head_mask" in inputs_dict: del inputs_dict["head_mask"] @@ -313,7 +314,10 @@ class ModelTesterMixin: model = model_class(config=config) model.to(torch_device) model.eval() - heads_to_prune = {0: list(range(1, self.model_tester.num_attention_heads)), -1: [0]} + heads_to_prune = { + 0: list(range(1, self.model_tester.num_attention_heads)), + -1: [0], + } model.prune_heads(heads_to_prune) with torch.no_grad(): outputs = model(**inputs_dict) @@ -329,7 +333,7 @@ class ModelTesterMixin: return for model_class in self.all_model_classes: - config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + (config, inputs_dict,) = self.model_tester.prepare_config_and_inputs_for_common() if "head_mask" in inputs_dict: del inputs_dict["head_mask"] @@ -339,7 +343,10 @@ class ModelTesterMixin: model = model_class(config=config) model.to(torch_device) model.eval() - heads_to_prune = {0: list(range(1, self.model_tester.num_attention_heads)), -1: [0]} + heads_to_prune = { + 0: list(range(1, self.model_tester.num_attention_heads)), + -1: [0], + } model.prune_heads(heads_to_prune) with tempfile.TemporaryDirectory() as temp_dir_name: @@ -359,7 +366,7 @@ class ModelTesterMixin: return for model_class in self.all_model_classes: - config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + (config, inputs_dict,) = self.model_tester.prepare_config_and_inputs_for_common() if "head_mask" in inputs_dict: del inputs_dict["head_mask"] @@ -367,7 +374,10 @@ class ModelTesterMixin: config.output_attentions = True config.output_hidden_states = False - heads_to_prune = {0: list(range(1, self.model_tester.num_attention_heads)), -1: [0]} + heads_to_prune = { + 0: list(range(1, self.model_tester.num_attention_heads)), + -1: [0], + } config.pruned_heads = heads_to_prune model = model_class(config=config) @@ -387,7 +397,7 @@ class ModelTesterMixin: return for model_class in self.all_model_classes: - config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + (config, inputs_dict,) = self.model_tester.prepare_config_and_inputs_for_common() if "head_mask" in inputs_dict: del inputs_dict["head_mask"] @@ -465,7 +475,7 @@ class ModelTesterMixin: ) def test_resize_tokens_embeddings(self): - original_config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + (original_config, inputs_dict,) = self.model_tester.prepare_config_and_inputs_for_common() if not self.test_resize_embeddings: return @@ -634,6 +644,7 @@ class ModelTesterMixin: self._check_generated_tokens(model.generate(input_ids, num_return_sequences=3)) # batch_size > 1, greedy self._check_generated_tokens(model.generate(input_ids, do_sample=False)) + # batch_size > 1, num_beams > 1, sample self._check_generated_tokens(model.generate(input_ids, num_beams=3, num_return_sequences=3,)) # batch_size > 1, num_beams > 1, greedy @@ -704,3 +715,110 @@ class ModelUtilsTest(unittest.TestCase): self.assertEqual(model.config.output_attentions, True) self.assertEqual(model.config.output_hidden_states, True) self.assertEqual(model.config, config) + + +@require_torch +class UtilsFunctionsTest(unittest.TestCase): + + # tests whether the top_k_top_p function behaves as expected + def test_top_k_top_p_filtering(self): + logits = torch.tensor( + [ + [ + 8.2220991, # 3rd highest value; idx. 0 + -0.5620044, + 5.23229752, + 4.0386393, + -6.8798378, + -0.54785802, + -3.2012153, + 2.92777176, + 1.88171953, + 7.35341276, # 5th highest value; idx. 9 + 8.43207833, # 2nd highest value; idx. 10 + -9.85711836, + -5.96209236, + -1.13039161, + -7.1115294, + -0.8369633, + -5.3186408, + 7.06427407, + 0.81369344, + -0.82023817, + -5.9179796, + 0.58813443, + -6.99778438, + 4.71551189, + -0.18771637, + 7.44020759, # 4th highest value; idx. 25 + 9.38450987, # 1st highest value; idx. 26 + 2.12662941, + -9.32562038, + 2.35652522, + ], # cummulative prob of 5 highest values <= 0.6 + [ + 0.58425518, + 4.53139238, + -5.57510464, + -6.28030699, + -7.19529503, + -4.02122551, + 1.39337037, + -6.06707057, + 1.59480517, + -9.643119, + 0.03907799, + 0.67231762, + -8.88206726, + 6.27115922, # 4th highest value; idx. 13 + 2.28520723, + 4.82767506, + 4.30421368, + 8.8275313, # 2nd highest value; idx. 17 + 5.44029958, # 5th highest value; idx. 18 + -4.4735794, + 7.38579536, # 3rd highest value; idx. 20 + -2.91051663, + 2.61946077, + -2.5674762, + -9.48959302, + -4.02922645, + -1.35416918, + 9.67702323, # 1st highest value; idx. 27 + -5.89478553, + 1.85370467, + ], # cummulative prob of 5 highest values <= 0.6 + ], + dtype=torch.float, + device=torch_device, + ) + + non_inf_expected_idx = torch.tensor( + [[0, 0], [0, 9], [0, 10], [0, 25], [0, 26], [1, 13], [1, 17], [1, 18], [1, 20], [1, 27]], + dtype=torch.long, + device=torch_device, + ) # expected non filtered idx as noted above + + non_inf_expected_output = torch.tensor( + [ + 8.2221, + 7.3534, + 8.4321, + 7.4402, + 9.3845, + 6.2712, + 8.8275, + 5.4403, + 7.3858, + 9.6770, + ], # expected non filtered values as noted above + dtype=torch.float, + device=torch_device, + ) + + output = top_k_top_p_filtering(logits, top_k=10, top_p=0.6, min_tokens_to_keep=4) + non_inf_output = output[output != -float("inf")].to(device=torch_device) + non_inf_idx = (output != -float("inf")).nonzero().to(device=torch_device) + + self.assertTrue(torch.allclose(non_inf_expected_output, non_inf_output, atol=1e-12)) + self.assertTrue(torch.all(torch.eq(non_inf_expected_idx, non_inf_idx))) diff --git a/tests/test_modeling_gpt2.py b/tests/test_modeling_gpt2.py index 21fc873234..e705b80f8b 100644 --- a/tests/test_modeling_gpt2.py +++ b/tests/test_modeling_gpt2.py @@ -386,33 +386,33 @@ class GPT2ModelLanguageGenerationTest(unittest.TestCase): @slow def test_lm_generate_distilgpt2(self): model = GPT2LMHeadModel.from_pretrained("distilgpt2") - input_ids = torch.Tensor([[464, 3290, 318, 13779]]).long() # The dog is cute + input_ids = torch.Tensor([[464, 1893]]).long() # The president expected_output_ids = [ 464, - 3290, - 318, - 13779, - 996, - 339, - 460, - 3360, - 655, - 2513, + 1893, + 286, + 262, + 1578, + 1829, + 11, + 290, + 262, + 1893, + 286, + 262, + 1578, + 7526, + 11, + 423, + 587, 287, 262, - 3952, - 13, - 632, - 318, - 407, - 845, - 3621, - 284, - ] # The dog is cute though he can sometimes just walk in the park. It is not very nice to - torch.manual_seed(0) + 2635, + ] # The president of the United States, and the president of the United Kingdom, have been in the White output_ids = model.generate( input_ids, + do_sample=False, bos_token_id=self.special_tokens["bos_token_id"], eos_token_ids=self.special_tokens["eos_token_id"], ) diff --git a/tests/test_modeling_tf_common.py b/tests/test_modeling_tf_common.py index e6f70d6bfa..8cd53dfe19 100644 --- a/tests/test_modeling_tf_common.py +++ b/tests/test_modeling_tf_common.py @@ -18,6 +18,7 @@ import copy import os import random import tempfile +import unittest from transformers import is_tf_available, is_torch_available @@ -28,6 +29,8 @@ if is_tf_available(): import tensorflow as tf import numpy as np + from transformers import tf_top_k_top_p_filtering + if _tf_gpu_memory_limit is not None: gpus = tf.config.list_physical_devices("GPU") for gpu in gpus: @@ -56,6 +59,7 @@ class TFModelTesterMixin: model_tester = None all_model_classes = () + all_generative_model_classes = () test_torchscript = True test_pruning = True test_resize_embeddings = True @@ -216,7 +220,7 @@ class TFModelTesterMixin: outputs_dict = model(inputs_dict) inputs_keywords = copy.deepcopy(inputs_dict) - input_ids = inputs_keywords.pop("input_ids" if not self.is_encoder_decoder else "decoder_input_ids", None) + input_ids = inputs_keywords.pop("input_ids" if not self.is_encoder_decoder else "decoder_input_ids", None,) outputs_keywords = model(input_ids, **inputs_keywords) output_dict = outputs_dict[0].numpy() @@ -299,7 +303,7 @@ class TFModelTesterMixin: self.assertEqual(model.config.output_hidden_states, True) self.assertEqual(len(hidden_states), self.model_tester.num_hidden_layers + 1) self.assertListEqual( - list(hidden_states[0].shape[-2:]), [self.model_tester.seq_length, self.model_tester.hidden_size] + list(hidden_states[0].shape[-2:]), [self.model_tester.seq_length, self.model_tester.hidden_size], ) def test_model_common_attributes(self): @@ -316,7 +320,10 @@ class TFModelTesterMixin: for model_class in self.all_model_classes: model = model_class(config) - first, second = model(inputs_dict, training=False)[0], model(inputs_dict, training=False)[0] + first, second = ( + model(inputs_dict, training=False)[0], + model(inputs_dict, training=False)[0], + ) out_1 = first.numpy() out_2 = second.numpy() out_1 = out_1[~np.isnan(out_1)] @@ -338,9 +345,9 @@ class TFModelTesterMixin: x = wte([input_ids, None, None, None], mode="embedding") except Exception: if hasattr(self.model_tester, "embedding_size"): - x = tf.ones(input_ids.shape + [self.model_tester.embedding_size], dtype=tf.dtypes.float32) + x = tf.ones(input_ids.shape + [self.model_tester.embedding_size], dtype=tf.dtypes.float32,) else: - x = tf.ones(input_ids.shape + [self.model_tester.hidden_size], dtype=tf.dtypes.float32) + x = tf.ones(input_ids.shape + [self.model_tester.hidden_size], dtype=tf.dtypes.float32,) return x def test_inputs_embeds(self): @@ -366,6 +373,37 @@ class TFModelTesterMixin: model(inputs_dict) + def test_lm_head_model_random_generate(self): + + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + input_ids = inputs_dict.get( + "input_ids", None + ) # TODO (PVP): ugly workaround to make code work for t5 for the moment - has to changed when t5 is fixed. + + for model_class in self.all_generative_model_classes: + # TODO (PVP): add beam search tests when beam search is implemented + model = model_class(config) + + if config.bos_token_id is None: + with self.assertRaises(AssertionError): + model.generate(max_length=5) + # batch_size = 1 + self._check_generated_tokens(model.generate(input_ids)) + else: + # batch_size = 1 + self._check_generated_tokens(model.generate(max_length=5)) + # batch_size = 1, num_beams > 1 + + # batch_size > 1, sample + self._check_generated_tokens(model.generate(input_ids, num_return_sequences=3)) + # batch_size > 1, greedy + self._check_generated_tokens(model.generate(input_ids, do_sample=False, num_return_sequences=3)) + + def _check_generated_tokens(self, output_ids): + for token_id in output_ids[0].numpy().tolist(): + self.assertGreaterEqual(token_id, 0) + self.assertLess(token_id, self.model_tester.vocab_size) + def ids_tensor(shape, vocab_size, rng=None, name=None, dtype=None): """Creates a random int32 tensor of the shape within the vocab size.""" @@ -383,3 +421,98 @@ def ids_tensor(shape, vocab_size, rng=None, name=None, dtype=None): output = tf.constant(values, shape=shape, dtype=dtype if dtype is not None else tf.int32) return output + + +@require_tf +class UtilsFunctionsTest(unittest.TestCase): + + # tests whether the top_k_top_p_filtering function behaves as expected + def test_top_k_top_p_filtering(self): + logits = tf.convert_to_tensor( + [ + [ + 8.2220991, # 3rd highest value; idx. 0 + -0.5620044, + 5.23229752, + 4.0386393, + -6.8798378, + -0.54785802, + -3.2012153, + 2.92777176, + 1.88171953, + 7.35341276, # 5th highest value; idx. 9 + 8.43207833, # 2nd highest value; idx. 10 + -9.85711836, + -5.96209236, + -1.13039161, + -7.1115294, + -0.8369633, + -5.3186408, + 7.06427407, + 0.81369344, + -0.82023817, + -5.9179796, + 0.58813443, + -6.99778438, + 4.71551189, + -0.18771637, + 7.44020759, # 4th highest value; idx. 25 + 9.38450987, # 1st highest value; idx. 26 + 2.12662941, + -9.32562038, + 2.35652522, + ], # cummulative prob of 5 highest values <= 0.6 + [ + 0.58425518, + 4.53139238, + -5.57510464, + -6.28030699, + -7.19529503, + -4.02122551, + 1.39337037, + -6.06707057, + 1.59480517, + -9.643119, + 0.03907799, + 0.67231762, + -8.88206726, + 6.27115922, # 4th highest value; idx. 13 + 2.28520723, + 4.82767506, + 4.30421368, + 8.8275313, # 2nd highest value; idx. 17 + 5.44029958, # 5th highest value; idx. 18 + -4.4735794, + 7.38579536, # 3rd highest value; idx. 20 + -2.91051663, + 2.61946077, + -2.5674762, + -9.48959302, + -4.02922645, + -1.35416918, + 9.67702323, # 1st highest value; idx. 27 + -5.89478553, + 1.85370467, + ], # cummulative prob of 5 highest values <= 0.6 + ], + dtype=tf.float32, + ) + + non_inf_expected_idx = tf.convert_to_tensor( + [[0, 0], [0, 9], [0, 10], [0, 25], [0, 26], [1, 13], [1, 17], [1, 18], [1, 20], [1, 27]], dtype=tf.int32, + ) # expected non filtered idx as noted above + + non_inf_expected_output = tf.convert_to_tensor( + [8.222099, 7.3534126, 8.432078, 7.4402075, 9.38451, 6.271159, 8.827531, 5.4402995, 7.3857956, 9.677023], + dtype=tf.float32, + ) # expected non filtered values as noted above + + output = tf_top_k_top_p_filtering(logits, top_k=10, top_p=0.6, min_tokens_to_keep=4) + + non_inf_output = output[output != -float("inf")] + non_inf_idx = tf.cast( + tf.where(tf.not_equal(output, tf.constant(-float("inf"), dtype=tf.float32))), dtype=tf.int32, + ) + + tf.debugging.assert_near(non_inf_output, non_inf_expected_output, rtol=1e-12) + tf.debugging.assert_equal(non_inf_idx, non_inf_expected_idx) diff --git a/tests/test_modeling_tf_ctrl.py b/tests/test_modeling_tf_ctrl.py index 4997c2a573..29a6eb5d43 100644 --- a/tests/test_modeling_tf_ctrl.py +++ b/tests/test_modeling_tf_ctrl.py @@ -31,6 +31,7 @@ if is_tf_available(): class TFCTRLModelTest(TFModelTesterMixin, unittest.TestCase): all_model_classes = (TFCTRLModel, TFCTRLLMHeadModel) if is_tf_available() else () + all_generative_model_classes = (TFCTRLLMHeadModel,) if is_tf_available() else () class TFCTRLModelTester(object): def __init__( diff --git a/tests/test_modeling_tf_gpt2.py b/tests/test_modeling_tf_gpt2.py index d7b0809964..362f9e3162 100644 --- a/tests/test_modeling_tf_gpt2.py +++ b/tests/test_modeling_tf_gpt2.py @@ -37,7 +37,7 @@ if is_tf_available(): class TFGPT2ModelTest(TFModelTesterMixin, unittest.TestCase): all_model_classes = (TFGPT2Model, TFGPT2LMHeadModel, TFGPT2DoubleHeadsModel) if is_tf_available() else () - # all_model_classes = (TFGPT2Model, TFGPT2LMHeadModel) if is_tf_available() else () + all_generative_model_classes = (TFGPT2LMHeadModel,) if is_tf_available() else () class TFGPT2ModelTester(object): def __init__( @@ -89,6 +89,8 @@ class TFGPT2ModelTest(TFModelTesterMixin, unittest.TestCase): self.num_labels = num_labels self.num_choices = num_choices self.scope = scope + self.bos_token_id = vocab_size - 1 + self.eos_token_id = vocab_size - 1 def prepare_config_and_inputs(self): input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size) @@ -123,9 +125,11 @@ class TFGPT2ModelTest(TFModelTesterMixin, unittest.TestCase): # hidden_dropout_prob=self.hidden_dropout_prob, # attention_probs_dropout_prob=self.attention_probs_dropout_prob, n_positions=self.max_position_embeddings, - n_ctx=self.max_position_embeddings + n_ctx=self.max_position_embeddings, # type_vocab_size=self.type_vocab_size, # initializer_range=self.initializer_range + bos_token_id=self.bos_token_id, + eos_token_ids=self.eos_token_id, ) head_mask = ids_tensor([self.num_hidden_layers, self.num_attention_heads], 2) @@ -144,7 +148,11 @@ class TFGPT2ModelTest(TFModelTesterMixin, unittest.TestCase): def create_and_check_gpt2_model(self, config, input_ids, input_mask, head_mask, token_type_ids, *args): model = TFGPT2Model(config=config) - inputs = {"input_ids": input_ids, "attention_mask": input_mask, "token_type_ids": token_type_ids} + inputs = { + "input_ids": input_ids, + "attention_mask": input_mask, + "token_type_ids": token_type_ids, + } sequence_output = model(inputs)[0] inputs = [input_ids, None, input_mask] # None is the input for 'past' @@ -156,18 +164,22 @@ class TFGPT2ModelTest(TFModelTesterMixin, unittest.TestCase): "sequence_output": sequence_output.numpy(), } self.parent.assertListEqual( - list(result["sequence_output"].shape), [self.batch_size, self.seq_length, self.hidden_size] + list(result["sequence_output"].shape), [self.batch_size, self.seq_length, self.hidden_size], ) def create_and_check_gpt2_lm_head(self, config, input_ids, input_mask, head_mask, token_type_ids, *args): model = TFGPT2LMHeadModel(config=config) - inputs = {"input_ids": input_ids, "attention_mask": input_mask, "token_type_ids": token_type_ids} + inputs = { + "input_ids": input_ids, + "attention_mask": input_mask, + "token_type_ids": token_type_ids, + } prediction_scores = model(inputs)[0] result = { "prediction_scores": prediction_scores.numpy(), } self.parent.assertListEqual( - list(result["prediction_scores"].shape), [self.batch_size, self.seq_length, self.vocab_size] + list(result["prediction_scores"].shape), [self.batch_size, self.seq_length, self.vocab_size], ) def create_and_check_gpt2_double_head( @@ -188,7 +200,7 @@ class TFGPT2ModelTest(TFModelTesterMixin, unittest.TestCase): lm_logits, mc_logits = model(inputs)[:2] result = {"lm_logits": lm_logits.numpy(), "mc_logits": mc_logits.numpy()} self.parent.assertListEqual( - list(result["lm_logits"].shape), [self.batch_size, self.num_choices, self.seq_length, self.vocab_size] + list(result["lm_logits"].shape), [self.batch_size, self.num_choices, self.seq_length, self.vocab_size], ) self.parent.assertListEqual(list(result["mc_logits"].shape), [self.batch_size, self.num_choices]) @@ -207,7 +219,11 @@ class TFGPT2ModelTest(TFModelTesterMixin, unittest.TestCase): choice_labels, ) = config_and_inputs - inputs_dict = {"input_ids": input_ids, "token_type_ids": token_type_ids, "attention_mask": input_mask} + inputs_dict = { + "input_ids": input_ids, + "token_type_ids": token_type_ids, + "attention_mask": input_mask, + } return config, inputs_dict def setUp(self): @@ -234,3 +250,48 @@ class TFGPT2ModelTest(TFModelTesterMixin, unittest.TestCase): for model_name in list(TF_GPT2_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]: model = TFGPT2Model.from_pretrained(model_name, cache_dir=CACHE_DIR) self.assertIsNotNone(model) + + +def prepare_generation_special_tokens(): + return {"bos_token_id": 50256, "eos_token_id": 50256} + + +class TFGPT2ModelLanguageGenerationTest(unittest.TestCase): + + special_tokens = prepare_generation_special_tokens() + + @slow + def test_lm_generate_distilgpt2(self): + model = TFGPT2LMHeadModel.from_pretrained("distilgpt2") + input_ids = tf.convert_to_tensor([[464, 1893]], dtype=tf.int32) # The president + expected_output_ids = [ + 464, + 1893, + 286, + 262, + 1578, + 1829, + 11, + 290, + 262, + 1893, + 286, + 262, + 1578, + 7526, + 11, + 423, + 587, + 287, + 262, + 2635, + ] # The president of the United States, and the president of the United Kingdom, have been in the White + + output_ids = model.generate( + input_ids, + do_sample=False, + bos_token_id=self.special_tokens["bos_token_id"], + eos_token_ids=self.special_tokens["eos_token_id"], + ) + + self.assertListEqual(output_ids[0].numpy().tolist(), expected_output_ids) diff --git a/tests/test_modeling_tf_openai_gpt.py b/tests/test_modeling_tf_openai_gpt.py index b825c94fca..b8bf74f88a 100644 --- a/tests/test_modeling_tf_openai_gpt.py +++ b/tests/test_modeling_tf_openai_gpt.py @@ -39,6 +39,9 @@ class TFOpenAIGPTModelTest(TFModelTesterMixin, unittest.TestCase): all_model_classes = ( (TFOpenAIGPTModel, TFOpenAIGPTLMHeadModel, TFOpenAIGPTDoubleHeadsModel) if is_tf_available() else () ) + all_generative_model_classes = ( + (TFOpenAIGPTLMHeadModel,) if is_tf_available() else () + ) # TODO (PVP): Add Double HeadsModel when generate() function is changed accordingly class TFOpenAIGPTModelTester(object): def __init__( diff --git a/tests/test_modeling_tf_transfo_xl.py b/tests/test_modeling_tf_transfo_xl.py index f94f2032a2..f2d8e58362 100644 --- a/tests/test_modeling_tf_transfo_xl.py +++ b/tests/test_modeling_tf_transfo_xl.py @@ -37,6 +37,8 @@ if is_tf_available(): class TFTransfoXLModelTest(TFModelTesterMixin, unittest.TestCase): all_model_classes = (TFTransfoXLModel, TFTransfoXLLMHeadModel) if is_tf_available() else () + all_generative_model_classes = () if is_tf_available() else () + # TODO: add this test when TFTransfoXLLMHead has a linear output layer implemented test_pruning = False test_torchscript = False test_resize_embeddings = False @@ -62,6 +64,7 @@ class TFTransfoXLModelTest(TFModelTesterMixin, unittest.TestCase): num_hidden_layers=5, scope=None, seed=1, + eos_token_id=0, ): self.parent = parent self.batch_size = batch_size @@ -82,6 +85,7 @@ class TFTransfoXLModelTest(TFModelTesterMixin, unittest.TestCase): self.num_hidden_layers = num_hidden_layers self.scope = scope self.seed = seed + self.eos_token_id = eos_token_id def prepare_config_and_inputs(self): input_ids_1 = ids_tensor([self.batch_size, self.seq_length], self.vocab_size) @@ -103,6 +107,7 @@ class TFTransfoXLModelTest(TFModelTesterMixin, unittest.TestCase): d_inner=self.d_inner, div_val=self.div_val, n_layer=self.num_hidden_layers, + eos_token_ids=self.eos_token_id, ) return (config, input_ids_1, input_ids_2, lm_labels) diff --git a/tests/test_modeling_tf_xlm.py b/tests/test_modeling_tf_xlm.py index 53719f63f4..ebadd074e6 100644 --- a/tests/test_modeling_tf_xlm.py +++ b/tests/test_modeling_tf_xlm.py @@ -43,6 +43,9 @@ class TFXLMModelTest(TFModelTesterMixin, unittest.TestCase): if is_tf_available() else () ) + all_generative_model_classes = ( + (TFXLMWithLMHeadModel,) if is_tf_available() else () + ) # TODO (PVP): Check other models whether language generation is also applicable class TFXLMModelTester(object): def __init__( @@ -75,6 +78,7 @@ class TFXLMModelTest(TFModelTesterMixin, unittest.TestCase): summary_type="last", use_proj=True, scope=None, + bos_token_id=0, ): self.parent = parent self.batch_size = batch_size @@ -105,6 +109,7 @@ class TFXLMModelTest(TFModelTesterMixin, unittest.TestCase): self.num_labels = num_labels self.num_choices = num_choices self.scope = scope + self.bos_token_id = bos_token_id def prepare_config_and_inputs(self): input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size) @@ -145,6 +150,7 @@ class TFXLMModelTest(TFModelTesterMixin, unittest.TestCase): initializer_range=self.initializer_range, summary_type=self.summary_type, use_proj=self.use_proj, + bos_token_id=self.bos_token_id, ) return ( diff --git a/tests/test_modeling_tf_xlnet.py b/tests/test_modeling_tf_xlnet.py index 65c83395e5..687fe01575 100644 --- a/tests/test_modeling_tf_xlnet.py +++ b/tests/test_modeling_tf_xlnet.py @@ -51,6 +51,9 @@ class TFXLNetModelTest(TFModelTesterMixin, unittest.TestCase): if is_tf_available() else () ) + all_generative_model_classes = ( + (TFXLNetLMHeadModel,) if is_tf_available() else () + ) # TODO (PVP): Check other models whether language generation is also applicable test_pruning = False class TFXLNetModelTester(object): @@ -77,6 +80,9 @@ class TFXLNetModelTest(TFModelTesterMixin, unittest.TestCase): initializer_range=0.05, seed=1, type_vocab_size=2, + bos_token_id=1, + eos_token_id=2, + pad_token_id=5, ): self.parent = parent self.batch_size = batch_size @@ -100,6 +106,9 @@ class TFXLNetModelTest(TFModelTesterMixin, unittest.TestCase): self.seed = seed self.type_vocab_size = type_vocab_size self.type_sequence_label_size = type_sequence_label_size + self.bos_token_id = bos_token_id + self.pad_token_id = pad_token_id + self.eos_token_id = eos_token_id def prepare_config_and_inputs(self): input_ids_1 = ids_tensor([self.batch_size, self.seq_length], self.vocab_size) @@ -139,6 +148,9 @@ class TFXLNetModelTest(TFModelTesterMixin, unittest.TestCase): bi_data=self.bi_data, initializer_range=self.initializer_range, num_labels=self.type_sequence_label_size, + bos_token_id=self.bos_token_id, + pad_token_id=self.pad_token_id, + eos_token_id=self.eos_token_id, ) return (