From b4ddd2677c9072f39267c8b3bf9b33a7a52d108f Mon Sep 17 00:00:00 2001 From: Joao Gante Date: Mon, 18 Apr 2022 10:58:24 +0100 Subject: [PATCH] TF generate refactor - XLA sample (#16713) --- src/transformers/generation_tf_utils.py | 208 ++++++++++++++++-------- tests/gpt2/test_modeling_tf_gpt2.py | 34 ++-- tests/t5/test_modeling_tf_t5.py | 33 +++- 3 files changed, 187 insertions(+), 88 deletions(-) diff --git a/src/transformers/generation_tf_utils.py b/src/transformers/generation_tf_utils.py index 918f9169ab..a5043c0ec5 100644 --- a/src/transformers/generation_tf_utils.py +++ b/src/transformers/generation_tf_utils.py @@ -346,6 +346,8 @@ class TFGenerationMixin: A class containing all of the functions supporting generation, to be used as a mixin in [`TFPreTrainedModel`]. """ + seed_generator = tf.random.Generator.from_non_deterministic_state() + def prepare_inputs_for_generation(self, inputs, **kwargs): """ Implement in subclasses of [`TFPreTrainedModel`] for custom behavior to prepare inputs in the generate method. @@ -585,6 +587,7 @@ class TFGenerationMixin: attention_mask=attention_mask, decoder_start_token_id=decoder_start_token_id, use_cache=use_cache, + seed=model_kwargs.pop("seed", None), output_scores=output_scores, output_attentions=output_attentions, output_hidden_states=output_hidden_states, @@ -1288,6 +1291,7 @@ class TFGenerationMixin: attention_mask=None, decoder_start_token_id=None, use_cache=None, + seed=None, output_scores=None, output_attentions=None, output_hidden_states=None, @@ -1365,6 +1369,9 @@ class TFGenerationMixin: use_cache (`bool`, *optional*, defaults to `True`): Whether or not the model should use the past last key/values attentions (if applicable to the model) to speed up decoding. + seed (`List[int]`, *optional*): + Random seed to control sampling, containing two integers, used when `do_sample` is `True`. See the + `seed` argument from stateless functions in `tf.random`. output_attentions (`bool`, *optional*, defaults to `False`): Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned tensors for more details. @@ -1590,6 +1597,7 @@ class TFGenerationMixin: max_length=max_length, pad_token_id=pad_token_id, eos_token_id=eos_token_id, + seed=seed, output_scores=output_scores, return_dict_in_generate=return_dict_in_generate, **model_kwargs, @@ -1723,7 +1731,7 @@ class TFGenerationMixin: **model_kwargs, ) -> Tuple[tf.Tensor, Dict[str, Any]]: expanded_return_idx = tf.reshape( - tf.tile(tf.reshape(tf.range(input_ids.shape[0]), (-1, 1)), (1, expand_size)), (-1) + tf.tile(tf.reshape(tf.range(input_ids.shape[0]), (-1, 1)), (1, expand_size)), (-1,) ) input_ids = tf.gather(input_ids, expanded_return_idx, axis=0) @@ -2123,6 +2131,7 @@ class TFGenerationMixin: max_length: Optional[int] = None, pad_token_id: Optional[int] = None, eos_token_id: Optional[int] = None, + seed: Optional[Tuple[int, int]] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, output_scores: Optional[bool] = None, @@ -2149,6 +2158,9 @@ class TFGenerationMixin: The id of the *padding* token. eos_token_id (`int`, *optional*): The id of the *end-of-sequence* token. + seed (`List[int]`, *optional*): + Random seed to control sampling, containing two integers, used when `do_sample` is `True`. See the + `seed` argument from stateless functions in `tf.random`. output_attentions (`bool`, *optional*, defaults to `False`): Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned tensors for more details. @@ -2210,7 +2222,7 @@ class TFGenerationMixin: >>> print("Generated:", tokenizer.batch_decode(outputs, skip_special_tokens=True)) ```""" - # init values + # 1. init greedy_search values logits_processor = logits_processor if logits_processor is not None else TFLogitsProcessorList() logits_warper = logits_warper if logits_warper is not None else TFLogitsProcessorList() @@ -2224,97 +2236,155 @@ class TFGenerationMixin: return_dict_in_generate = ( return_dict_in_generate if return_dict_in_generate is not None else self.config.return_dict_in_generate ) + use_xla = not tf.executing_eagerly() - # init attention / hidden states / scores tuples - scores = () if (return_dict_in_generate and output_scores) else None - decoder_attentions = () if (return_dict_in_generate and output_attentions) else None - cross_attentions = () if (return_dict_in_generate and output_attentions) else None - decoder_hidden_states = () if (return_dict_in_generate and output_hidden_states) else None + # 2. init `attentions`, `hidden_states`, and `scores` tuples + scores = [] if (return_dict_in_generate and output_scores) else None + decoder_attentions = [] if (return_dict_in_generate and output_attentions) else None + cross_attentions = [] if (return_dict_in_generate and output_attentions) else None + decoder_hidden_states = [] if (return_dict_in_generate and output_hidden_states) else None - # if model is an encoder-decoder, retrieve encoder attention weights and hidden states - if return_dict_in_generate and self.config.is_encoder_decoder: - encoder_attentions = model_kwargs["encoder_outputs"].get("attentions") if output_attentions else None - encoder_hidden_states = ( - model_kwargs["encoder_outputs"].get("hidden_states") if output_hidden_states else None - ) + # 3. init tensors to use for "xla-compileable" generate function + # define bsz, seq_length + batch_size, cur_len = input_ids.shape - # keep track of which sequences are already finished - unfinished_sequences = tf.ones_like(input_ids[:, 0]) - cur_len = input_ids.shape[-1] + # initialize `generated`, `finished_sequences` + generated = tf.TensorArray( + element_shape=(batch_size,), + dtype=tf.int32, + dynamic_size=False, + size=max_length, + clear_after_read=False, + ) + finished_sequences = tf.zeros((batch_size,), dtype=tf.bool) - while cur_len < max_length: - # prepare model inputs - model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs) + # write prompt to generated + for i in range(cur_len): + generated = generated.write(i, input_ids[:, i]) - # forward pass to get next token + # 4. define "xla-compile-able" stop-condition and auto-regressive function + def sample_cond_fn(generated, finished_sequences, next_tokens, cur_len, model_kwargs): + return ~tf.reduce_all(finished_sequences) + + def sample_body_fn(generated, finished_sequences, next_tokens, cur_len, model_kwargs): + # TODO(pvp, Joao) - `use_xla` can be removed here as soon as `position_ids` are corrected for the non-xla case in gpt2's `prepare_inputs_for_generation`. + model_inputs = self.prepare_inputs_for_generation(next_tokens, use_xla=use_xla, **model_kwargs) + # forward pass to get next token logits outputs = self( **model_inputs, return_dict=True, output_attentions=output_attentions, output_hidden_states=output_hidden_states, ) - - next_token_logits = outputs.logits[:, -1, :] - - # pre-process distribution - next_token_scores = logits_processor(input_ids, next_token_logits, cur_len=cur_len) - next_token_scores = logits_warper(input_ids, next_token_scores) + next_token_logits = outputs.logits[:, -1] # Store scores, attentions and hidden_states when required - if return_dict_in_generate: + if not use_xla and return_dict_in_generate: if output_scores: - scores += (next_token_scores,) - if output_attentions: - decoder_attentions += ( - (outputs.decoder_attentions,) if self.config.is_encoder_decoder else (outputs.attentions,) - ) + scores.append(next_token_logits) + if output_attentions and self.config.is_encoder_decoder: + decoder_attentions.append(outputs.decoder_attentions) + elif output_attentions and not self.config.is_encoder_decoder: + decoder_attentions.append(outputs.attentions) if self.config.is_encoder_decoder: - cross_attentions += (outputs.cross_attentions,) + cross_attentions.append(outputs.cross_attentions) - if output_hidden_states: - decoder_hidden_states += ( - (outputs.decoder_hidden_states,) - if self.config.is_encoder_decoder - else (outputs.hidden_states,) - ) + if output_hidden_states and self.config.is_encoder_decoder: + decoder_hidden_states.append(outputs.decoder_hidden_states) + elif output_hidden_states and self.config.is_encoder_decoder: + decoder_hidden_states.append(outputs.hidden_states) + + # pre-process distribution + # TODO(pvp, joao, matt) - all the logits processors/wrappers need to be adapted + # to be XLA compatible + input_ids = None + if not use_xla: + input_ids = tf.reshape(generated.concat(), (-1, batch_size)) + input_ids = tf.transpose(input_ids[:cur_len]) + next_tokens_scores = logits_processor(input_ids, next_token_logits, cur_len=cur_len) + next_tokens_scores = logits_warper(input_ids, next_tokens_scores) # sample + if seed is not None: + sample_seed = seed + else: + sample_seed = tf.cast(self.seed_generator.make_seeds(count=1)[:, 0], dtype=tf.int32) next_tokens = tf.squeeze( - tf.random.categorical(logits=next_token_scores, num_samples=1, dtype=tf.int32), axis=1 + tf.random.stateless_categorical( + logits=next_tokens_scores, num_samples=1, seed=sample_seed, dtype=tf.int32 + ), + axis=1, ) - # finished sentences should have their next token be a padding token if eos_token_id is not None: if pad_token_id is None: raise ValueError("If `eos_token_id` is defined, make sure that `pad_token_id` is defined.") - next_tokens = next_tokens * unfinished_sequences + pad_token_id * (1 - unfinished_sequences) + unfinished_seq = 1 - tf.cast(finished_sequences, tf.int32) + next_tokens = next_tokens * unfinished_seq + pad_token_id * (1 - unfinished_seq) + finished_sequences = finished_sequences | (next_tokens == eos_token_id) - # update generated ids, model inputs, and length for next step - input_ids = tf.concat([input_ids, next_tokens[:, None]], axis=-1) - model_kwargs = self._update_model_kwargs_for_generation( - outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder - ) - cur_len = cur_len + 1 + # update `generated` and `cur_len` + generated = generated.write(cur_len, next_tokens) + next_tokens = next_tokens[:, None] + cur_len += 1 - # if eos_token was found in one sentence, set sentence to finished - if eos_token_id is not None: - eos_in_sents = next_tokens == eos_token_id - # if sentence is unfinished and the token to add is eos - is_sents_unfinished_and_token_to_add_is_eos = tf.math.multiply( - unfinished_sequences, tf.cast(eos_in_sents, tf.int32) + # update model_kwargs + if use_xla: + model_kwargs = self._update_model_kwargs_for_xla_generation(outputs, model_kwargs, cur_len, max_length) + else: + model_kwargs = self._update_model_kwargs_for_generation( + outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder ) + # if we don't cache past key values we need the whole input + if model_kwargs.get("past", None) is None: + # let's throw out `past` since we don't want `None` tensors + model_kwargs.pop("past", None) - # unfinished_sequences is set to zero if eos in sentence - unfinished_sequences -= is_sents_unfinished_and_token_to_add_is_eos + next_tokens = tf.reshape(generated.concat(), (-1, batch_size)) + next_tokens = tf.transpose(next_tokens[:cur_len]) - # stop when each sentence is finished, or if we exceed the maximum length - if tf.math.reduce_max(unfinished_sequences) == 0: - break + return generated, finished_sequences, next_tokens, cur_len, model_kwargs + + # 5. run generation + # 1st generation step has to be run before to initialize `past` + generated, finished_sequences, next_tokens, cur_len, model_kwargs = sample_body_fn( + generated, finished_sequences, input_ids, cur_len, model_kwargs + ) + + # 2-to-n generation steps can then be run in autoregressive fashion + # only in case 1st generation step does NOT yield EOS token though + if sample_cond_fn(generated, finished_sequences, next_tokens, cur_len, model_kwargs): + maximum_iterations = max_length - cur_len - 1 + generated, _, _, cur_len, _ = tf.while_loop( + sample_cond_fn, + sample_body_fn, + (generated, finished_sequences, next_tokens, cur_len, model_kwargs), + maximum_iterations=maximum_iterations, + ) + + # 6. prepare outputs + output_ids = tf.transpose(tf.reshape(generated.concat(), (-1, batch_size))) + + if not use_xla: + # cut for backward compatibility + output_ids = output_ids[:, :cur_len] if return_dict_in_generate: if self.config.is_encoder_decoder: + # if model is an encoder-decoder, retrieve encoder attention weights + # and hidden states + encoder_attentions = model_kwargs["encoder_outputs"].get("attentions") if output_attentions else None + encoder_hidden_states = ( + model_kwargs["encoder_outputs"].get("hidden_states") if output_hidden_states else None + ) + + scores = tuple(scores) if scores is not None else None + decoder_attentions = tuple(decoder_attentions) if decoder_attentions is not None else None + cross_attentions = tuple(cross_attentions) if cross_attentions is not None else None + decoder_hidden_states = tuple(decoder_hidden_states) if decoder_hidden_states is not None else None + return TFSampleEncoderDecoderOutput( - sequences=input_ids, + sequences=output_ids, scores=scores, encoder_attentions=encoder_attentions, encoder_hidden_states=encoder_hidden_states, @@ -2324,13 +2394,13 @@ class TFGenerationMixin: ) else: return TFSampleDecoderOnlyOutput( - sequences=input_ids, + sequences=output_ids, scores=scores, attentions=decoder_attentions, hidden_states=decoder_hidden_states, ) else: - return input_ids + return output_ids def beam_search( self, @@ -2575,8 +2645,8 @@ class TFGenerationMixin: sequences, scores, is_sent_finished, - model_kwargs, input_ids_length, + model_kwargs, ): """ Beam Search termination condition function -- halts the generation loop if any of these conditions becomes @@ -2604,8 +2674,8 @@ class TFGenerationMixin: sequences, scores, is_sent_finished, - model_kwargs, input_ids_length, + model_kwargs, intermediary_running_sequences=None, ): """ @@ -2781,8 +2851,8 @@ class TFGenerationMixin: next_sequences, next_scores, next_is_sent_finished, - next_model_kwargs, next_input_ids_length, + next_model_kwargs, ) # 5. run generation @@ -2799,8 +2869,8 @@ class TFGenerationMixin: sequences, scores, is_sent_finished, - model_kwargs, input_ids_length, + model_kwargs, ) = beam_search_body_fn( cur_len, running_sequences, @@ -2808,8 +2878,8 @@ class TFGenerationMixin: sequences, scores, is_sent_finished, - model_kwargs, input_ids_length, + model_kwargs, ) # 2-to-n generation steps can then be run in autoregressive fashion (only in case 1st generation step does @@ -2821,8 +2891,8 @@ class TFGenerationMixin: sequences, scores, is_sent_finished, - model_kwargs, input_ids_length, + model_kwargs, ): maximum_iterations = max_length - cur_len cur_len, running_sequences, running_scores, sequences, scores, is_sent_finished, _, _ = tf.while_loop( @@ -2835,8 +2905,8 @@ class TFGenerationMixin: sequences, scores, is_sent_finished, - model_kwargs, input_ids_length, + model_kwargs, ), maximum_iterations=maximum_iterations, ) diff --git a/tests/gpt2/test_modeling_tf_gpt2.py b/tests/gpt2/test_modeling_tf_gpt2.py index 5b78bdf1bd..14429811fa 100644 --- a/tests/gpt2/test_modeling_tf_gpt2.py +++ b/tests/gpt2/test_modeling_tf_gpt2.py @@ -447,19 +447,6 @@ class TFGPT2ModelTest(TFModelTesterMixin, TFCoreModelTesterMixin, unittest.TestC @require_tf class TFGPT2ModelLanguageGenerationTest(unittest.TestCase): - @slow - def test_lm_generate_distilgpt2(self): - model = TFGPT2LMHeadModel.from_pretrained("distilgpt2") - input_ids = tf.convert_to_tensor([[464, 1893]], dtype=tf.int32) # The president - - # The president of the United States, and the president of the United Kingdom, have been in the White - # fmt: off - expected_output_ids = [464, 1893, 286, 262, 1578, 1829, 11, 290, 262, 1893, 286, 262, 1578, 7526, 11, 423, 587, 287, 262, 2635] - # fmt: on - - output_ids = model.generate(input_ids, do_sample=False) - self.assertListEqual(output_ids[0].numpy().tolist(), expected_output_ids) - @slow def test_lm_generate_greedy_distilgpt2_batch_special(self): model = TFGPT2LMHeadModel.from_pretrained("distilgpt2") @@ -506,18 +493,18 @@ class TFGPT2ModelLanguageGenerationTest(unittest.TestCase): "temperature": 1.5, "top_k": 500, "top_p": 0.9, + "seed": [42, 0], # seed set -> deterministic sampling sequence -> deterministic generation } # forces the generation to happen on CPU, to avoid GPU-related quirks with tf.device(":/CPU:0"): - tf.random.set_seed(42) # deterministic sampling sequence -> deterministic generation output_ids = model.generate(input_ids, **generation_kwargs) output_strings = tokenizer.batch_decode(output_ids, skip_special_tokens=True) expected_output_string = [ - "Today is a beautiful day and this makes finding holiday travel easier for you to do other project\nOh", - "Yesterday was an enjoyable but especially great note though it certainly upset many Democrats who say", + "Today is a beautiful day and we will make you feel very hot/terrific in all", + "Yesterday was another solid success as news coverage became standard American domestic television hit.", ] self.assertListEqual(output_strings, expected_output_string) @@ -561,7 +548,7 @@ class TFGPT2ModelLanguageGenerationTest(unittest.TestCase): self.assertListEqual(output_ids[0].numpy().tolist(), expected_output_ids) @slow - def test_lm_generate_gpt2_xla(self): + def test_lm_generate_gpt2_xla_greedy(self): """This test gives the exact same results as the non-xla test above""" model = TFGPT2LMHeadModel.from_pretrained("gpt2") input_ids = tf.convert_to_tensor([[464, 3290]], dtype=tf.int32) # The dog @@ -574,3 +561,16 @@ class TFGPT2ModelLanguageGenerationTest(unittest.TestCase): output_ids = xla_generate(input_ids, do_sample=False) self.assertListEqual(output_ids[0].numpy().tolist(), expected_output_ids) + + @slow + def test_lm_generate_gpt2_xla_sample(self): + model = TFGPT2LMHeadModel.from_pretrained("gpt2") + input_ids = tf.convert_to_tensor([[464, 3290]], dtype=tf.int32) # The dog + + # fmt: off + expected_output_ids = [464, 3290, 550, 284, 307, 4376, 287, 281, 4044, 1363, 329, 734, 812, 878, 852, 4376, 757, 329, 2267, 0] + # fmt: on + xla_generate = tf.function(model.generate, jit_compile=True) + + output_ids = xla_generate(input_ids, do_sample=True, seed=[42, 0]) + self.assertListEqual(output_ids[0].numpy().tolist(), expected_output_ids) diff --git a/tests/t5/test_modeling_tf_t5.py b/tests/t5/test_modeling_tf_t5.py index 7445aae530..4f89c8ce8d 100644 --- a/tests/t5/test_modeling_tf_t5.py +++ b/tests/t5/test_modeling_tf_t5.py @@ -524,6 +524,35 @@ class TFT5GenerationIntegrationTests(unittest.TestCase): self.assertListEqual(expected_output_string, output_strings) + @slow + def test_sample_xla_generate_simple(self): + model = TFT5ForConditionalGeneration.from_pretrained("t5-small") + tokenizer = T5Tokenizer.from_pretrained("t5-small") + + sentence = "Translate English to German: Today is a beautiful day." + input_ids = tokenizer(sentence, return_tensors="tf", padding=True).input_ids + # XLA reorder ops, which causes operations like FP matmul to have slightly different results, causing + # divergences in generate -- especially with sampling. + expected_output_string = ["Heute ist ein schöner Tag."] + expected_output_string_xla = ["Heute ist ein schöne Tage."] + # However, notice that the first tokens are the same, for the same seed + assert expected_output_string[0][:15] == expected_output_string_xla[0][:15] + + # forces the generation to happen on CPU, to avoid GPU-related quirks + with tf.device(":/CPU:0"): + # seed set -> deterministic sampling sequence -> deterministic generation + output_ids = model.generate(input_ids, do_sample=True, seed=[42, 0]) + output_strings = tokenizer.batch_decode(output_ids, skip_special_tokens=True) + self.assertListEqual(expected_output_string, output_strings) + + # forces the generation to happen on CPU, to avoid GPU-related quirks + with tf.device(":/CPU:0"): + xla_generate = tf.function(model.generate, jit_compile=True) + # seed set -> deterministic sampling sequence -> deterministic generation + output_ids_xla = xla_generate(input_ids, do_sample=True, seed=[42, 0]) + output_strings_xla = tokenizer.batch_decode(output_ids_xla, skip_special_tokens=True) + self.assertListEqual(expected_output_string_xla, output_strings_xla) + @slow def test_sample_generate(self): model = TFT5ForConditionalGeneration.from_pretrained("t5-small") @@ -540,16 +569,16 @@ class TFT5GenerationIntegrationTests(unittest.TestCase): "temperature": 0.8, "top_k": 500, "top_p": 0.9, + "seed": [20, 0], # seed set -> deterministic sampling sequence -> deterministic generation } # forces the generation to happen on CPU, to avoid GPU-related quirks with tf.device(":/CPU:0"): - tf.random.set_seed(42) # deterministic sampling sequence -> deterministic generation output_ids = model.generate(input_ids, **generation_kwargs) output_strings = tokenizer.batch_decode(output_ids, skip_special_tokens=True) - expected_output_string = ["i love her I really love my heart", "die Transformatoren sind wirklich erstaunlich"] + expected_output_string = ["- I really love my way of this.", "die Transformatoren sind wirklich erstaunlich"] self.assertListEqual(expected_output_string, output_strings)