From 19d67bfecb62d49eaa2b9856192c63e673d66773 Mon Sep 17 00:00:00 2001 From: Joao Gante Date: Tue, 31 Jan 2023 15:49:34 +0000 Subject: [PATCH] Generate: fix TF XLA tests on models with `max_position_embeddings` or `max_target_positions` (#21389) --- tests/test_modeling_tf_common.py | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/tests/test_modeling_tf_common.py b/tests/test_modeling_tf_common.py index b1359142bb..eaf8f82e78 100644 --- a/tests/test_modeling_tf_common.py +++ b/tests/test_modeling_tf_common.py @@ -1865,6 +1865,17 @@ class TFModelTesterMixin: config.eos_token_id = None # Generate until max length config.do_sample = False + # fix config for models with additional sequence-length limiting settings + for var_name in ["max_position_embeddings", "max_target_positions"]: + attr = getattr(config, var_name, None) + if attr is not None and attr < generate_kwargs["max_new_tokens"]: + try: + setattr(config, var_name, generate_kwargs["max_new_tokens"]) + except NotImplementedError: + # xlnet will raise an exception when trying to set + # max_position_embeddings. + pass + model = model_class(config) if model.supports_xla_generation: