From 5ef432e4742cc505f610f8e54ac1cd2e1dfd265e Mon Sep 17 00:00:00 2001 From: Anton Vlasjuk <73884904+vasqu@users.noreply.github.com> Date: Sat, 5 Oct 2024 16:20:50 +0200 Subject: [PATCH] [`TF`] Fix Tensorflow XLA Generation on limited seq_len models (#33903) * fix tf xla generation on limited seq_len models * [run-slow] opt * [run-slow] opt --- tests/test_modeling_tf_common.py | 27 +++++++++++++++------------ 1 file changed, 15 insertions(+), 12 deletions(-) diff --git a/tests/test_modeling_tf_common.py b/tests/test_modeling_tf_common.py index 2cf272f4aa..eb328d83e9 100644 --- a/tests/test_modeling_tf_common.py +++ b/tests/test_modeling_tf_common.py @@ -1715,10 +1715,9 @@ class TFModelTesterMixin: model.train_on_batch(test_batch, test_batch_labels) def _test_xla_generate(self, **generate_kwargs): - def _generate_and_check_results(model, inputs_dict): - if "input_ids" in inputs_dict: - inputs = inputs_dict["input_ids"] - # make sure there are no pad tokens in prompt, which may trigger unwanted behavior + def _generate_and_check_results(model, inputs, is_input_ids): + # make sure there are no pad tokens in prompt, which may trigger unwanted behavior + if is_input_ids: if model.generation_config.pad_token_id is not None: if config.pad_token_id == 0: new_pad_token = model.generation_config.pad_token_id + 1 @@ -1727,10 +1726,6 @@ class TFModelTesterMixin: else: new_pad_token = None inputs = tf.where(inputs != model.generation_config.pad_token_id, inputs, new_pad_token) - elif "input_features" in inputs_dict: - inputs = inputs_dict["input_features"] - else: - raise ValueError("No valid generate input found in inputs_dict") generated = model.generate(inputs, **generate_kwargs).numpy() generate_xla = tf.function(model.generate, jit_compile=True) @@ -1753,12 +1748,20 @@ class TFModelTesterMixin: config.eos_token_id = None # Generate until max length config.do_sample = False + # extract the input to the model + is_input_ids = "input_ids" in inputs_dict + is_input_features = "input_features" in inputs_dict + if not (is_input_ids or is_input_features): + raise ValueError("No valid generate input found in inputs_dict") + inputs = inputs_dict["input_ids"] if is_input_ids else inputs_dict["input_features"] + # fix config for models with additional sequence-length limiting settings + seq_len = inputs.get_shape()[1] for var_name in ["max_position_embeddings", "max_target_positions"]: attr = getattr(config, var_name, None) - if attr is not None and attr < generate_kwargs["max_new_tokens"]: + if attr is not None and attr < seq_len + generate_kwargs["max_new_tokens"]: try: - setattr(config, var_name, generate_kwargs["max_new_tokens"]) + setattr(config, var_name, seq_len + generate_kwargs["max_new_tokens"]) except NotImplementedError: # xlnet will raise an exception when trying to set # max_position_embeddings. @@ -1767,10 +1770,10 @@ class TFModelTesterMixin: model = model_class(config) if model.supports_xla_generation: - _generate_and_check_results(model, inputs_dict) + _generate_and_check_results(model, inputs, is_input_ids) else: with self.assertRaises(ValueError): - _generate_and_check_results(model, inputs_dict) + _generate_and_check_results(model, inputs, is_input_ids) def test_xla_generate_fast(self): """