From 0f78529f982eceb79c5855d0466c287ec8a18df1 Mon Sep 17 00:00:00 2001 From: Joao Gante Date: Thu, 17 Nov 2022 12:34:46 +0000 Subject: [PATCH] Generate: general TF XLA constrastive search are now slow tests (#20277) * move contrastive search test to slow --- tests/test_modeling_tf_common.py | 37 +++++++------------------------- 1 file changed, 8 insertions(+), 29 deletions(-) diff --git a/tests/test_modeling_tf_common.py b/tests/test_modeling_tf_common.py index 66728e095a..c888ec7c8c 100644 --- a/tests/test_modeling_tf_common.py +++ b/tests/test_modeling_tf_common.py @@ -1800,7 +1800,7 @@ class TFModelTesterMixin: model.compile(optimizer="sgd", run_eagerly=True) model.train_on_batch(test_batch, test_batch_labels) - def _test_xla_generate(self, num_beams, num_return_sequences, max_length, **generate_kwargs): + def _test_xla_generate(self, **generate_kwargs): def _generate_and_check_results(model, config, inputs_dict): if "input_ids" in inputs_dict: inputs = inputs_dict["input_ids"] @@ -1826,20 +1826,7 @@ class TFModelTesterMixin: for model_class in self.all_generative_model_classes: config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() config.eos_token_id = None # Generate until max length - config.max_length = max_length config.do_sample = False - config.num_beams = num_beams - config.num_return_sequences = num_return_sequences - - # fix config for models with additional sequence-length limiting settings - for var_name in ["max_position_embeddings", "max_target_positions"]: - if hasattr(config, var_name): - try: - setattr(config, var_name, max_length) - except NotImplementedError: - # xlnet will raise an exception when trying to set - # max_position_embeddings. - pass model = model_class(config) @@ -1856,23 +1843,18 @@ class TFModelTesterMixin: Either the model supports XLA generation and passes the inner test, or it raises an appropriate exception """ - num_beams = 1 - num_return_sequences = 1 - max_length = 10 - self._test_xla_generate(num_beams, num_return_sequences, max_length) + self._test_xla_generate(num_beams=1, num_return_sequences=1, max_new_tokens=3) + @slow def test_xla_generate_contrastive(self): """ - Similar to `test_xla_generate_fast`, but for contrastive search -- contrastive search directly manipulates the - model cache and other outputs, and this test ensures that they are in a valid format that is also supported - by XLA. + Slow and challenging version of `test_xla_generate_fast` for contrastive search -- contrastive search directly + manipulates the model cache and other outputs, and this test ensures that they are in a valid format that is + also supported by XLA. Either the model supports XLA generation and passes the inner test, or it raises an appropriate exception """ - num_beams = 1 - num_return_sequences = 1 - max_length = 10 - self._test_xla_generate(num_beams, num_return_sequences, max_length, penalty_alpha=0.5, top_k=5) + self._test_xla_generate(num_beams=1, num_return_sequences=1, max_new_tokens=64, penalty_alpha=0.5, top_k=4) @slow def test_xla_generate_slow(self): @@ -1883,10 +1865,7 @@ class TFModelTesterMixin: Either the model supports XLA generation and passes the inner test, or it raises an appropriate exception """ - num_beams = 8 - num_return_sequences = 2 - max_length = 128 - self._test_xla_generate(num_beams, num_return_sequences, max_length) + self._test_xla_generate(num_beams=8, num_return_sequences=2, max_new_tokens=128) def _generate_random_bad_tokens(self, num_bad_tokens, model): # special tokens cannot be bad tokens