Generate: general TF XLA constrastive search are now slow tests (#20277)

* move contrastive search test to slow
This commit is contained in:
Joao Gante
2022-11-17 12:34:46 +00:00
committed by GitHub
parent 2062c28552
commit 0f78529f98

View File

@@ -1800,7 +1800,7 @@ class TFModelTesterMixin:
model.compile(optimizer="sgd", run_eagerly=True) model.compile(optimizer="sgd", run_eagerly=True)
model.train_on_batch(test_batch, test_batch_labels) model.train_on_batch(test_batch, test_batch_labels)
def _test_xla_generate(self, num_beams, num_return_sequences, max_length, **generate_kwargs): def _test_xla_generate(self, **generate_kwargs):
def _generate_and_check_results(model, config, inputs_dict): def _generate_and_check_results(model, config, inputs_dict):
if "input_ids" in inputs_dict: if "input_ids" in inputs_dict:
inputs = inputs_dict["input_ids"] inputs = inputs_dict["input_ids"]
@@ -1826,20 +1826,7 @@ class TFModelTesterMixin:
for model_class in self.all_generative_model_classes: for model_class in self.all_generative_model_classes:
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
config.eos_token_id = None # Generate until max length config.eos_token_id = None # Generate until max length
config.max_length = max_length
config.do_sample = False config.do_sample = False
config.num_beams = num_beams
config.num_return_sequences = num_return_sequences
# fix config for models with additional sequence-length limiting settings
for var_name in ["max_position_embeddings", "max_target_positions"]:
if hasattr(config, var_name):
try:
setattr(config, var_name, max_length)
except NotImplementedError:
# xlnet will raise an exception when trying to set
# max_position_embeddings.
pass
model = model_class(config) model = model_class(config)
@@ -1856,23 +1843,18 @@ class TFModelTesterMixin:
Either the model supports XLA generation and passes the inner test, or it raises an appropriate exception Either the model supports XLA generation and passes the inner test, or it raises an appropriate exception
""" """
num_beams = 1 self._test_xla_generate(num_beams=1, num_return_sequences=1, max_new_tokens=3)
num_return_sequences = 1
max_length = 10
self._test_xla_generate(num_beams, num_return_sequences, max_length)
@slow
def test_xla_generate_contrastive(self): def test_xla_generate_contrastive(self):
""" """
Similar to `test_xla_generate_fast`, but for contrastive search -- contrastive search directly manipulates the Slow and challenging version of `test_xla_generate_fast` for contrastive search -- contrastive search directly
model cache and other outputs, and this test ensures that they are in a valid format that is also supported manipulates the model cache and other outputs, and this test ensures that they are in a valid format that is
by XLA. also supported by XLA.
Either the model supports XLA generation and passes the inner test, or it raises an appropriate exception Either the model supports XLA generation and passes the inner test, or it raises an appropriate exception
""" """
num_beams = 1 self._test_xla_generate(num_beams=1, num_return_sequences=1, max_new_tokens=64, penalty_alpha=0.5, top_k=4)
num_return_sequences = 1
max_length = 10
self._test_xla_generate(num_beams, num_return_sequences, max_length, penalty_alpha=0.5, top_k=5)
@slow @slow
def test_xla_generate_slow(self): def test_xla_generate_slow(self):
@@ -1883,10 +1865,7 @@ class TFModelTesterMixin:
Either the model supports XLA generation and passes the inner test, or it raises an appropriate exception Either the model supports XLA generation and passes the inner test, or it raises an appropriate exception
""" """
num_beams = 8 self._test_xla_generate(num_beams=8, num_return_sequences=2, max_new_tokens=128)
num_return_sequences = 2
max_length = 128
self._test_xla_generate(num_beams, num_return_sequences, max_length)
def _generate_random_bad_tokens(self, num_bad_tokens, model): def _generate_random_bad_tokens(self, num_bad_tokens, model):
# special tokens cannot be bad tokens # special tokens cannot be bad tokens