From 975dd2bbbcd4e8bdaf07c275c090d218d88c7c12 Mon Sep 17 00:00:00 2001 From: Joao Gante Date: Tue, 31 May 2022 14:06:44 +0100 Subject: [PATCH] TF: GPT-2 generation supports left-padding (#17426) * TF GPT-2 now properly works with left padding * throw a warning when eos token == pad token and there is no attention mask --- src/transformers/generation_tf_utils.py | 114 +++++++++--------- src/transformers/generation_utils.py | 10 +- .../models/gpt2/modeling_tf_gpt2.py | 27 ++--- tests/models/gpt2/test_modeling_tf_gpt2.py | 81 +++++++++---- 4 files changed, 136 insertions(+), 96 deletions(-) diff --git a/src/transformers/generation_tf_utils.py b/src/transformers/generation_tf_utils.py index 04ae9cc31d..f27a772c08 100644 --- a/src/transformers/generation_tf_utils.py +++ b/src/transformers/generation_tf_utils.py @@ -1498,8 +1498,14 @@ class TFGenerationMixin: ) if pad_token_id is None and eos_token_id is not None: + if attention_mask is None: + logger.warning( + "The attention mask and the pad token id were not set. As a consequence, you may observe " + "unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results." + ) logger.warning(f"Setting `pad_token_id` to {eos_token_id} (first `eos_token_id`) to generate sequence") pad_token_id = eos_token_id + if min_length is not None and min_length > max_length: raise ValueError( f"Unfeasable length constraints: the minimum length ({min_length}) is larger than the maximum " @@ -1525,7 +1531,9 @@ class TFGenerationMixin: requires_attention_mask = "encoder_outputs" not in model_kwargs if model_kwargs.get("attention_mask", None) is None and requires_attention_mask and accepts_attention_mask: - model_kwargs["attention_mask"] = self._prepare_attention_mask_for_generation(input_ids, pad_token_id) + model_kwargs["attention_mask"] = self._prepare_attention_mask_for_generation( + input_ids, pad_token_id, eos_token_id + ) # 4. Prepare model inputs which will be used for auto-regressive generation if self.config.is_encoder_decoder: @@ -1653,12 +1661,17 @@ class TFGenerationMixin: def _prepare_attention_mask_for_generation( self, inputs: tf.Tensor, - pad_token_id: int, + pad_token_id: Optional[int], + eos_token_id: Optional[int], ) -> tf.Tensor: is_input_ids = len(inputs.shape) == 2 and inputs.dtype in (tf.int32, tf.int64) is_pad_token_in_inputs = (pad_token_id is not None) and tf.math.reduce_any(inputs == pad_token_id) + is_pad_token_not_equal_to_eos_token_id = (eos_token_id is None) or ( + (eos_token_id is not None) and (pad_token_id != eos_token_id) + ) + # Check if input is input_ids and padded -> only then is attention_mask defined - if is_input_ids and is_pad_token_in_inputs: + if is_input_ids and is_pad_token_in_inputs and is_pad_token_not_equal_to_eos_token_id: return tf.cast(tf.math.not_equal(inputs, pad_token_id), dtype=tf.int32) else: return tf.ones(inputs.shape[:2], dtype=tf.int32) @@ -1954,6 +1967,7 @@ class TFGenerationMixin: # 1. init greedy_search values logits_processor = logits_processor if logits_processor is not None else TFLogitsProcessorList() + max_length = max_length if max_length is not None else self.config.max_length pad_token_id = pad_token_id if pad_token_id is not None else self.config.pad_token_id eos_token_id = eos_token_id if eos_token_id is not None else self.config.eos_token_id output_scores = output_scores if output_scores is not None else self.config.output_scores @@ -1973,10 +1987,9 @@ class TFGenerationMixin: decoder_hidden_states = [] if (return_dict_in_generate and output_hidden_states) else None # 3. init tensors to use for "xla-compileable" generate function - # define bsz, seq_length - batch_size, seq_length = input_ids.shape + batch_size, cur_len = input_ids.shape - # initialize `generated`, `finished_sequences`, and `current_pos` + # initialize `generated` (pre-populated with `pad_token_id`), `finished_sequences` generated = tf.TensorArray( element_shape=(batch_size,), dtype=tf.int32, @@ -1984,25 +1997,26 @@ class TFGenerationMixin: size=max_length, clear_after_read=False, ) + if pad_token_id: # ignores the cases when it is 0 or None + for i in range(max_length): + generated = generated.write(i, tf.broadcast_to(pad_token_id, (batch_size,))) # write prompt to generated - for i in range(seq_length): + for i in range(cur_len): generated = generated.write(i, input_ids[:, i]) finished_sequences = tf.zeros((batch_size,), dtype=tf.bool) - current_pos = tf.ones(shape=(1,), dtype=tf.int32) * seq_length # 4. define "xla-compile-able" stop-condition and auto-regressive function # define condition fn - def greedy_search_cond_fn(generated, finished_sequences, next_tokens, current_pos, model_kwargs): + def greedy_search_cond_fn(generated, finished_sequences, next_tokens, cur_len, model_kwargs): """state termination condition fn.""" return ~tf.reduce_all(finished_sequences) # define condition fn - def greedy_search_body_fn(generated, finished_sequences, next_tokens, current_pos, model_kwargs): + def greedy_search_body_fn(generated, finished_sequences, next_tokens, cur_len, model_kwargs): """state update fn.""" - # TODO(pvp, Joao) - `use_xla` can be removed here as soon as `position_ids` are corrected for the non-xla case in gpt2's `prepare_inputs_for_generation`. - model_inputs = self.prepare_inputs_for_generation(next_tokens, use_xla=use_xla, **model_kwargs) + model_inputs = self.prepare_inputs_for_generation(next_tokens, **model_kwargs) # forward pass to get next token logits outputs = self( **model_inputs, @@ -2029,13 +2043,8 @@ class TFGenerationMixin: decoder_hidden_states.append(outputs.hidden_states) # pre-process distribution - # TODO(pvp, joao, matt) - all the logits processors need to be adapted - # to be XLA compatible - input_ids = None - if not use_xla: - input_ids = tf.reshape(generated.concat(), (-1, batch_size)) - input_ids = tf.transpose(input_ids[: current_pos[0]]) - next_tokens_scores = logits_processor(input_ids, next_token_logits, current_pos[0]) + input_ids = tf.transpose(tf.reshape(generated.concat(), (-1, batch_size))) + next_tokens_scores = logits_processor(input_ids, next_token_logits, cur_len) # argmax next_tokens = tf.argmax(next_tokens_scores, axis=-1, output_type=tf.int32) @@ -2047,16 +2056,14 @@ class TFGenerationMixin: next_tokens = next_tokens * unfinished_seq + pad_token_id * (1 - unfinished_seq) finished_sequences = finished_sequences | (next_tokens == eos_token_id) - # update `generated` and `current_pos` - generated = generated.write(current_pos[0], next_tokens) + # update `generated` and `cur_len` + generated = generated.write(cur_len, next_tokens) next_tokens = next_tokens[:, None] - current_pos += 1 + cur_len += 1 # update model_kwargs if use_xla: - model_kwargs = self._update_model_kwargs_for_xla_generation( - outputs, model_kwargs, current_pos, max_length - ) + model_kwargs = self._update_model_kwargs_for_xla_generation(outputs, model_kwargs, cur_len, max_length) else: model_kwargs = self._update_model_kwargs_for_generation( outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder @@ -2067,24 +2074,24 @@ class TFGenerationMixin: model_kwargs.pop("past", None) next_tokens = tf.reshape(generated.concat(), (-1, batch_size)) - next_tokens = tf.transpose(next_tokens[: current_pos[0]]) + next_tokens = tf.transpose(next_tokens[:cur_len]) - return generated, finished_sequences, next_tokens, current_pos, model_kwargs + return generated, finished_sequences, next_tokens, cur_len, model_kwargs # 5. run generation # 1st generation step has to be run before to initialize `past` - generated, finished_sequences, next_tokens, current_pos, model_kwargs = greedy_search_body_fn( - generated, finished_sequences, input_ids, current_pos, model_kwargs + generated, finished_sequences, next_tokens, cur_len, model_kwargs = greedy_search_body_fn( + generated, finished_sequences, input_ids, cur_len, model_kwargs ) # 2-to-n generation steps can then be run in autoregressive fashion # only in case 1st generation step does NOT yield EOS token though - if greedy_search_cond_fn(generated, finished_sequences, next_tokens, current_pos, model_kwargs): - maximum_iterations = max_length - seq_length - 1 - generated, _, _, current_pos, _ = tf.while_loop( + if greedy_search_cond_fn(generated, finished_sequences, next_tokens, cur_len, model_kwargs): + maximum_iterations = max_length - cur_len + generated, _, _, cur_len, _ = tf.while_loop( greedy_search_cond_fn, greedy_search_body_fn, - (generated, finished_sequences, next_tokens, current_pos, model_kwargs), + (generated, finished_sequences, next_tokens, cur_len, model_kwargs), maximum_iterations=maximum_iterations, ) @@ -2093,7 +2100,7 @@ class TFGenerationMixin: if not use_xla: # cut for backward compatibility - output_ids = output_ids[:, : current_pos[0]] + output_ids = output_ids[:, :cur_len] if return_dict_in_generate: if self.config.is_encoder_decoder: @@ -2231,6 +2238,7 @@ class TFGenerationMixin: logits_processor = logits_processor if logits_processor is not None else TFLogitsProcessorList() logits_warper = logits_warper if logits_warper is not None else TFLogitsProcessorList() + max_length = max_length if max_length is not None else self.config.max_length pad_token_id = pad_token_id if pad_token_id is not None else self.config.pad_token_id eos_token_id = eos_token_id if eos_token_id is not None else self.config.eos_token_id output_scores = output_scores if output_scores is not None else self.config.output_scores @@ -2250,10 +2258,9 @@ class TFGenerationMixin: decoder_hidden_states = [] if (return_dict_in_generate and output_hidden_states) else None # 3. init tensors to use for "xla-compileable" generate function - # define bsz, seq_length batch_size, cur_len = input_ids.shape - # initialize `generated`, `finished_sequences` + # initialize `generated` (pre-populated with `pad_token_id`), `finished_sequences` generated = tf.TensorArray( element_shape=(batch_size,), dtype=tf.int32, @@ -2261,19 +2268,22 @@ class TFGenerationMixin: size=max_length, clear_after_read=False, ) - finished_sequences = tf.zeros((batch_size,), dtype=tf.bool) + if pad_token_id: # ignores the cases when it is 0 or None + for i in range(max_length): + generated = generated.write(i, tf.broadcast_to(pad_token_id, (batch_size,))) # write prompt to generated for i in range(cur_len): generated = generated.write(i, input_ids[:, i]) + finished_sequences = tf.zeros((batch_size,), dtype=tf.bool) + # 4. define "xla-compile-able" stop-condition and auto-regressive function def sample_cond_fn(generated, finished_sequences, next_tokens, cur_len, model_kwargs): return ~tf.reduce_all(finished_sequences) def sample_body_fn(generated, finished_sequences, next_tokens, cur_len, model_kwargs): - # TODO(pvp, Joao) - `use_xla` can be removed here as soon as `position_ids` are corrected for the non-xla case in gpt2's `prepare_inputs_for_generation`. - model_inputs = self.prepare_inputs_for_generation(next_tokens, use_xla=use_xla, **model_kwargs) + model_inputs = self.prepare_inputs_for_generation(next_tokens, **model_kwargs) # forward pass to get next token logits outputs = self( **model_inputs, @@ -2300,12 +2310,7 @@ class TFGenerationMixin: decoder_hidden_states.append(outputs.hidden_states) # pre-process distribution - # TODO(pvp, joao, matt) - all the logits processors/wrappers need to be adapted - # to be XLA compatible - input_ids = None - if not use_xla: - input_ids = tf.reshape(generated.concat(), (-1, batch_size)) - input_ids = tf.transpose(input_ids[:cur_len]) + input_ids = tf.transpose(tf.reshape(generated.concat(), (-1, batch_size))) next_tokens_scores = logits_processor(input_ids, next_token_logits, cur_len) next_tokens_scores = logits_warper(input_ids, next_tokens_scores, cur_len) @@ -2359,7 +2364,7 @@ class TFGenerationMixin: # 2-to-n generation steps can then be run in autoregressive fashion # only in case 1st generation step does NOT yield EOS token though if sample_cond_fn(generated, finished_sequences, next_tokens, cur_len, model_kwargs): - maximum_iterations = max_length - cur_len - 1 + maximum_iterations = max_length - cur_len generated, _, _, cur_len, _ = tf.while_loop( sample_cond_fn, sample_body_fn, @@ -2613,12 +2618,13 @@ class TFGenerationMixin: size=max_length, clear_after_read=False, ) - for i in range(max_length): - sequences = sequences.write(i, tf.broadcast_to(pad_token_id, (batch_size, num_beams))) - running_sequences = running_sequences.write(i, tf.broadcast_to(pad_token_id, (batch_size, num_beams))) - intermediary_running_sequences = intermediary_running_sequences.write( - i, tf.broadcast_to(pad_token_id, (batch_size, num_beams * 2)) - ) + if pad_token_id: # ignores the cases when it is 0 or None + for i in range(max_length): + sequences = sequences.write(i, tf.broadcast_to(pad_token_id, (batch_size, num_beams))) + running_sequences = running_sequences.write(i, tf.broadcast_to(pad_token_id, (batch_size, num_beams))) + intermediary_running_sequences = intermediary_running_sequences.write( + i, tf.broadcast_to(pad_token_id, (batch_size, num_beams * 2)) + ) # write prompt to running_sequences for i in range(cur_len): @@ -2699,9 +2705,7 @@ class TFGenerationMixin: (0, 0, cur_len - input_ids_length), (batch_size, num_beams, input_ids_length), ) - model_inputs = self.prepare_inputs_for_generation( - flatten_beam_dim(input_token), use_xla=use_xla, **model_kwargs - ) + model_inputs = self.prepare_inputs_for_generation(flatten_beam_dim(input_token), **model_kwargs) model_outputs = self( **model_inputs, return_dict=True, diff --git a/src/transformers/generation_utils.py b/src/transformers/generation_utils.py index 7b9968de12..be34b742fc 100644 --- a/src/transformers/generation_utils.py +++ b/src/transformers/generation_utils.py @@ -490,8 +490,8 @@ class GenerationMixin: def _prepare_attention_mask_for_generation( self, inputs: torch.Tensor, - pad_token_id: int, - eos_token_id: int, + pad_token_id: Optional[int], + eos_token_id: Optional[int], ) -> torch.LongTensor: is_input_ids = len(inputs.shape) == 2 and inputs.dtype in [torch.int, torch.long] is_pad_token_in_inputs = (pad_token_id is not None) and (pad_token_id in inputs) @@ -1137,7 +1137,11 @@ class GenerationMixin: eos_token_id = self.config.decoder.eos_token_id if pad_token_id is None and eos_token_id is not None: - # special case if pad_token_id is not defined + if model_kwargs.get("attention_mask", None) is None: + logger.warning( + "The attention mask and the pad token id were not set. As a consequence, you may observe " + "unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results." + ) logger.warning(f"Setting `pad_token_id` to `eos_token_id`:{eos_token_id} for open-end generation.") pad_token_id = eos_token_id diff --git a/src/transformers/models/gpt2/modeling_tf_gpt2.py b/src/transformers/models/gpt2/modeling_tf_gpt2.py index 2422af5ebc..b3d1ad0484 100644 --- a/src/transformers/models/gpt2/modeling_tf_gpt2.py +++ b/src/transformers/models/gpt2/modeling_tf_gpt2.py @@ -813,25 +813,21 @@ class TFGPT2LMHeadModel(TFGPT2PreTrainedModel, TFCausalLanguageModelingLoss): def set_output_embeddings(self, value): self.set_input_embeddings(value) - def prepare_inputs_for_generation(self, inputs, past=None, use_cache=None, use_xla=False, **kwargs): - # TODO: (Joao) after the TF generator is complete, update GPT2 TF generation to match PT's. NB -- some GPT2 - # tests will need to be fixed after the change - + def prepare_inputs_for_generation(self, inputs, past=None, use_cache=None, **kwargs): + token_type_ids = kwargs.get("token_type_ids", None) # only last token for inputs_ids if past is defined in kwargs if past: inputs = tf.expand_dims(inputs[:, -1], -1) + if token_type_ids is not None: + token_type_ids = tf.expand_dims(token_type_ids[:, -1], -1) - # TODO(pvp, Joao) - this `if use_xla` statement can be removed, but is left - # for a future PR to not change too many things for now. - # All statements in this if case apply for both xla and non-xla (as they already do in PyTorch) - position_ids = None - attention_mask = None - if use_xla: - attention_mask = kwargs.get("attention_mask", None) - if past is not None and attention_mask is not None: - position_ids = tf.reduce_sum(attention_mask, axis=1, keepdims=True) - 1 - elif attention_mask is not None: - position_ids = tf.math.cumsum(attention_mask, axis=1, exclusive=True) + position_ids = kwargs.get("position_ids", None) + attention_mask = kwargs.get("attention_mask", None) + + if attention_mask is not None and position_ids is None: + position_ids = tf.math.cumsum(attention_mask, axis=-1, exclusive=True) + if past: + position_ids = tf.expand_dims(position_ids[:, -1], -1) return { "input_ids": inputs, @@ -839,6 +835,7 @@ class TFGPT2LMHeadModel(TFGPT2PreTrainedModel, TFCausalLanguageModelingLoss): "position_ids": position_ids, "past": past, "use_cache": use_cache, + "token_type_ids": token_type_ids, } def _update_model_kwargs_for_xla_generation(self, outputs, model_kwargs, current_pos, max_length): diff --git a/tests/models/gpt2/test_modeling_tf_gpt2.py b/tests/models/gpt2/test_modeling_tf_gpt2.py index a032e33500..93b48ce8f2 100644 --- a/tests/models/gpt2/test_modeling_tf_gpt2.py +++ b/tests/models/gpt2/test_modeling_tf_gpt2.py @@ -456,7 +456,7 @@ class TFGPT2ModelLanguageGenerationTest(unittest.TestCase): tokenizer.padding_side = "left" sentences = ["Today is a beautiful day and", "Yesterday was"] - input_ids = tokenizer(sentences, return_tensors="tf", padding=True).input_ids + input_ids = tokenizer(sentences, return_tensors="tf", padding=True) generation_kwargs = { "bad_words_ids": [tokenizer("is").input_ids, tokenizer("angry about").input_ids], @@ -465,12 +465,12 @@ class TFGPT2ModelLanguageGenerationTest(unittest.TestCase): "repetition_penalty": 1.3, } - output_ids = model.generate(input_ids, **generation_kwargs) + output_ids = model.generate(**input_ids, **generation_kwargs) output_strings = tokenizer.batch_decode(output_ids, skip_special_tokens=True) expected_output_string = [ "Today is a beautiful day and I am so happy to be able take part in this amazing event.", - "Yesterday was a very busy day for the first time since I started writing this post", + "Yesterday was a very interesting time for the world to see how much of this is", ] self.assertListEqual(output_strings, expected_output_string) @@ -483,7 +483,7 @@ class TFGPT2ModelLanguageGenerationTest(unittest.TestCase): tokenizer.padding_side = "left" sentences = ["Today is a beautiful day and", "Yesterday was"] - input_ids = tokenizer(sentences, return_tensors="tf", padding=True).input_ids + input_ids = tokenizer(sentences, return_tensors="tf", padding=True) generation_kwargs = { "do_sample": True, @@ -498,13 +498,13 @@ class TFGPT2ModelLanguageGenerationTest(unittest.TestCase): # forces the generation to happen on CPU, to avoid GPU-related quirks with tf.device(":/CPU:0"): - output_ids = model.generate(input_ids, **generation_kwargs) + output_ids = model.generate(**input_ids, **generation_kwargs) output_strings = tokenizer.batch_decode(output_ids, skip_special_tokens=True) expected_output_string = [ - "Today is a beautiful day and we will make you feel very hot/terrific in all", - "Yesterday was another solid success as news coverage became standard American domestic television hit.", + "Today is a beautiful day and we will make you feel very hot/terrific in all your", + "Yesterday was known by national television networks as Le Big Show or Wild Dog Jeopard", ] self.assertListEqual(output_strings, expected_output_string) @@ -517,7 +517,7 @@ class TFGPT2ModelLanguageGenerationTest(unittest.TestCase): tokenizer.padding_side = "left" sentences = ["Today is a beautiful day and", "Yesterday was"] - input_ids = tokenizer(sentences, return_tensors="tf", padding=True).input_ids + input_ids = tokenizer(sentences, return_tensors="tf", padding=True) generation_kwargs = { "bad_words_ids": [tokenizer("is").input_ids, tokenizer("angry about").input_ids], @@ -526,37 +526,69 @@ class TFGPT2ModelLanguageGenerationTest(unittest.TestCase): "num_beams": 2, } - output_ids = model.generate(input_ids, **generation_kwargs) + output_ids = model.generate(**input_ids, **generation_kwargs) output_strings = tokenizer.batch_decode(output_ids, skip_special_tokens=True) expected_output_string = [ "Today is a beautiful day and a great day for all of us.\n\nI’m", - "Yesterday was the first day of the year for the second time in a row,", + "Yesterday was the first time that a person has been arrested in the United States for", ] self.assertListEqual(output_strings, expected_output_string) + @slow + def test_lm_generate_distilgpt2_left_padding(self): + """Tests that the generated text is the same, regarless of left padding""" + model = TFGPT2LMHeadModel.from_pretrained("distilgpt2") + tokenizer = GPT2Tokenizer.from_pretrained("distilgpt2") + + tokenizer.pad_token = tokenizer.eos_token + tokenizer.padding_side = "left" + + generation_kwargs = { + "bad_words_ids": [tokenizer("is").input_ids, tokenizer("angry about").input_ids], + "no_repeat_ngram_size": 2, + "do_sample": False, + "repetition_penalty": 1.3, + } + expected_output_string = ( + "Today is a beautiful day and I am so happy to be able take part in this amazing event." + ) + + sentences = ["Today is a beautiful day and"] + input_ids = tokenizer(sentences, return_tensors="tf", padding=True) + # using default length + output_ids = model.generate(**input_ids, **generation_kwargs) + output_strings = tokenizer.batch_decode(output_ids, skip_special_tokens=True) + self.assertEqual(output_strings[0], expected_output_string) + + sentences = ["Today is a beautiful day and", "This is a very long input that we absolutely don't care about"] + input_ids = tokenizer(sentences, return_tensors="tf", padding=True) + # longer max length to capture the full length (remember: it is left padded) + output_ids = model.generate(**input_ids, **generation_kwargs, max_length=27) + output_strings = tokenizer.batch_decode(output_ids, skip_special_tokens=True) + self.assertEqual(output_strings[0], expected_output_string) + @slow def test_lm_generate_gpt2_greedy_xla(self): - # TODO (Joao): convert this to an example with a batch size>1 with different input lengths that works (and fix - # the underlying problem) model = TFGPT2LMHeadModel.from_pretrained("gpt2") tokenizer = GPT2Tokenizer.from_pretrained("gpt2") tokenizer.pad_token = tokenizer.eos_token tokenizer.padding_side = "left" - sentences = ["The dog"] + sentences = ["The dog", "The flying machine"] expected_output_strings = [ - "The dog was found in a field near the intersection of West and West Streets.\n\nThe dog", + "The dog was found in a field near the intersection of West and West Streets.\n\nThe", + "The flying machine is a small, lightweight, and lightweight aircraft that can be used for any type of", ] - input_ids = tokenizer(sentences, return_tensors="tf", padding=True).input_ids + input_ids = tokenizer(sentences, return_tensors="tf", padding=True) - output_ids = model.generate(input_ids, do_sample=False) + output_ids = model.generate(**input_ids, do_sample=False) output_strings = tokenizer.batch_decode(output_ids, skip_special_tokens=True) self.assertListEqual(output_strings, expected_output_strings) xla_generate = tf.function(model.generate, jit_compile=True) - output_ids = xla_generate(input_ids, do_sample=False) + output_ids = xla_generate(**input_ids, do_sample=False) output_strings = tokenizer.batch_decode(output_ids, skip_special_tokens=True) self.assertListEqual(output_strings, expected_output_strings) @@ -574,21 +606,24 @@ class TFGPT2ModelLanguageGenerationTest(unittest.TestCase): tokenizer.pad_token = tokenizer.eos_token tokenizer.padding_side = "left" - sentence = ["The dog"] + sentence = ["The dog", "The flying machine"] expected_output_string = [ "The dog owner asked why did our vet decide there needed to be extra ventilation inside because most" - " puppies" + " puppies", + "The flying machine was made by an artist who found it difficult to control it as it did not use", ] expected_output_string_xla = [ - "The dog has been named in connection with the murder of a 20-year-old man in!" + "The dog has been named in connection with the murder of a 20-year-old man in", + "The flying machine is a new and improved system to operate and operate a new system and system " + "system system", ] - input_ids = tokenizer(sentence, return_tensors="tf", padding=True).input_ids + input_ids = tokenizer(sentence, return_tensors="tf", padding=True) - output_ids = model.generate(input_ids, do_sample=True, seed=[7, 0]) + output_ids = model.generate(**input_ids, do_sample=True, seed=[7, 0]) output_strings = tokenizer.batch_decode(output_ids, skip_special_tokens=True) self.assertListEqual(output_strings, expected_output_string) xla_generate = tf.function(model.generate, jit_compile=True) - output_ids = xla_generate(input_ids, do_sample=True, seed=[7, 0]) + output_ids = xla_generate(**input_ids, do_sample=True, seed=[7, 0]) output_strings = tokenizer.batch_decode(output_ids, skip_special_tokens=True) self.assertListEqual(output_strings, expected_output_string_xla)