From 2edf9a857be94677cc14870bdb646ddbfdf2a137 Mon Sep 17 00:00:00 2001 From: Joao Gante Date: Thu, 9 Feb 2023 12:52:30 +0000 Subject: [PATCH] Generate: TF `.generate()` can now be exported with dynamic length (#21474) --- src/transformers/generation/tf_utils.py | 31 ++++++----- .../models/gpt2/modeling_tf_gpt2.py | 2 +- tests/generation/test_tf_utils.py | 55 +++++++++++++++++-- 3 files changed, 68 insertions(+), 20 deletions(-) diff --git a/src/transformers/generation/tf_utils.py b/src/transformers/generation/tf_utils.py index f1d8369eb9..3e9b5677b0 100644 --- a/src/transformers/generation/tf_utils.py +++ b/src/transformers/generation/tf_utils.py @@ -849,7 +849,7 @@ class TFGenerationMixin: input_ids = inputs_tensor # 7. Prepare `max_length` depending on other stopping criteria. - input_ids_seq_length = input_ids.shape[-1] + input_ids_seq_length = shape_list(input_ids)[-1] has_default_max_length = kwargs.get("max_length") is None and generation_config.max_length is not None if has_default_max_length and generation_config.max_new_tokens is None: warnings.warn( @@ -869,18 +869,23 @@ class TFGenerationMixin: UserWarning, ) - if generation_config.min_length is not None and generation_config.min_length > generation_config.max_length: - raise ValueError( - f"Unfeasable length constraints: the minimum length ({generation_config.min_length}) is larger than" - f" the maximum length ({generation_config.max_length})" - ) - if input_ids_seq_length >= generation_config.max_length: - input_ids_string = "decoder_input_ids" if self.config.is_encoder_decoder else "input_ids" - logger.warning( - f"Input length of {input_ids_string} is {input_ids_seq_length}, but `max_length` is set to" - f" {generation_config.max_length}. This can lead to unexpected behavior. You should consider" - " increasing`max_new_tokens`." - ) + # If the input length is a tensor (i.e. dynamic length), skip length checks + if not isinstance(input_ids_seq_length, tf.Tensor): + if ( + generation_config.min_length is not None + and generation_config.min_length > generation_config.max_length + ): + raise ValueError( + f"Unfeasable length constraints: the minimum length ({generation_config.min_length}) is larger" + f" than the maximum length ({generation_config.max_length})" + ) + if input_ids_seq_length >= generation_config.max_length: + input_ids_string = "decoder_input_ids" if self.config.is_encoder_decoder else "input_ids" + logger.warning( + f"Input length of {input_ids_string} is {input_ids_seq_length}, but `max_length` is set to" + f" {generation_config.max_length}. This can lead to unexpected behavior. You should consider" + " increasing`max_new_tokens`." + ) # 8. determine generation mode is_contrastive_search_gen_mode = ( diff --git a/src/transformers/models/gpt2/modeling_tf_gpt2.py b/src/transformers/models/gpt2/modeling_tf_gpt2.py index 62df70cce4..a80b2d4d33 100644 --- a/src/transformers/models/gpt2/modeling_tf_gpt2.py +++ b/src/transformers/models/gpt2/modeling_tf_gpt2.py @@ -182,7 +182,7 @@ class TFAttention(tf.keras.layers.Layer): key = self.split_heads(key) value = self.split_heads(value) if layer_past is not None: - past_key, past_value = tf.unstack(layer_past, axis=0) + past_key, past_value = tf.unstack(layer_past, axis=0, num=2) key = tf.concat([past_key, key], axis=-2) value = tf.concat([past_value, value], axis=-2) diff --git a/tests/generation/test_tf_utils.py b/tests/generation/test_tf_utils.py index 2e6e9c6246..d6a4d5280a 100644 --- a/tests/generation/test_tf_utils.py +++ b/tests/generation/test_tf_utils.py @@ -144,9 +144,10 @@ class TFGenerationIntegrationTests(unittest.TestCase, GenerationIntegrationTests } @slow - def test_generate_tf_function_export(self): + def test_generate_tf_function_export_fixed_input_length(self): test_model = TFAutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-gpt2") - max_length = 2 + input_length = 2 + max_new_tokens = 2 class DummyModel(tf.Module): def __init__(self, model): @@ -155,8 +156,8 @@ class TFGenerationIntegrationTests(unittest.TestCase, GenerationIntegrationTests @tf.function( input_signature=( - tf.TensorSpec((None, max_length), tf.int32, name="input_ids"), - tf.TensorSpec((None, max_length), tf.int32, name="attention_mask"), + tf.TensorSpec((None, input_length), tf.int32, name="input_ids"), + tf.TensorSpec((None, input_length), tf.int32, name="attention_mask"), ), jit_compile=True, ) @@ -164,7 +165,7 @@ class TFGenerationIntegrationTests(unittest.TestCase, GenerationIntegrationTests outputs = self.model.generate( input_ids=input_ids, attention_mask=attention_mask, - max_new_tokens=max_length, + max_new_tokens=max_new_tokens, return_dict_in_generate=True, ) return {"sequences": outputs["sequences"]} @@ -181,5 +182,47 @@ class TFGenerationIntegrationTests(unittest.TestCase, GenerationIntegrationTests "attention_mask": tf.constant(dummy_attention_masks[:batch_size]), } tf_func_outputs = serving_func(**inputs)["sequences"] - tf_model_outputs = test_model.generate(**inputs, max_new_tokens=max_length) + tf_model_outputs = test_model.generate(**inputs, max_new_tokens=max_new_tokens) + tf.debugging.assert_equal(tf_func_outputs, tf_model_outputs) + + @slow + def test_generate_tf_function_export_fixed_batch_size(self): + test_model = TFAutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-gpt2") + batch_size = 1 + max_new_tokens = 2 + + class DummyModel(tf.Module): + def __init__(self, model): + super(DummyModel, self).__init__() + self.model = model + + @tf.function( + input_signature=( + tf.TensorSpec((batch_size, None), tf.int32, name="input_ids"), + tf.TensorSpec((batch_size, None), tf.int32, name="attention_mask"), + ), + jit_compile=True, + ) + def serving(self, input_ids, attention_mask): + outputs = self.model.generate( + input_ids=input_ids, + attention_mask=attention_mask, + max_new_tokens=max_new_tokens, + return_dict_in_generate=True, + ) + return {"sequences": outputs["sequences"]} + + dummy_input_ids = [[2], [102, 103]] + dummy_attention_masks = [[1], [1, 1]] + dummy_model = DummyModel(model=test_model) + with tempfile.TemporaryDirectory() as tmp_dir: + tf.saved_model.save(dummy_model, tmp_dir, signatures={"serving_default": dummy_model.serving}) + serving_func = tf.saved_model.load(tmp_dir).signatures["serving_default"] + for input_row in range(len(dummy_input_ids)): + inputs = { + "input_ids": tf.constant([dummy_input_ids[input_row]]), + "attention_mask": tf.constant([dummy_attention_masks[input_row]]), + } + tf_func_outputs = serving_func(**inputs)["sequences"] + tf_model_outputs = test_model.generate(**inputs, max_new_tokens=max_new_tokens) tf.debugging.assert_equal(tf_func_outputs, tf_model_outputs)