Generate: general TF XLA constrastive search are now slow tests (#20277)
* move contrastive search test to slow
This commit is contained in:
@@ -1800,7 +1800,7 @@ class TFModelTesterMixin:
|
|||||||
model.compile(optimizer="sgd", run_eagerly=True)
|
model.compile(optimizer="sgd", run_eagerly=True)
|
||||||
model.train_on_batch(test_batch, test_batch_labels)
|
model.train_on_batch(test_batch, test_batch_labels)
|
||||||
|
|
||||||
def _test_xla_generate(self, num_beams, num_return_sequences, max_length, **generate_kwargs):
|
def _test_xla_generate(self, **generate_kwargs):
|
||||||
def _generate_and_check_results(model, config, inputs_dict):
|
def _generate_and_check_results(model, config, inputs_dict):
|
||||||
if "input_ids" in inputs_dict:
|
if "input_ids" in inputs_dict:
|
||||||
inputs = inputs_dict["input_ids"]
|
inputs = inputs_dict["input_ids"]
|
||||||
@@ -1826,20 +1826,7 @@ class TFModelTesterMixin:
|
|||||||
for model_class in self.all_generative_model_classes:
|
for model_class in self.all_generative_model_classes:
|
||||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||||
config.eos_token_id = None # Generate until max length
|
config.eos_token_id = None # Generate until max length
|
||||||
config.max_length = max_length
|
|
||||||
config.do_sample = False
|
config.do_sample = False
|
||||||
config.num_beams = num_beams
|
|
||||||
config.num_return_sequences = num_return_sequences
|
|
||||||
|
|
||||||
# fix config for models with additional sequence-length limiting settings
|
|
||||||
for var_name in ["max_position_embeddings", "max_target_positions"]:
|
|
||||||
if hasattr(config, var_name):
|
|
||||||
try:
|
|
||||||
setattr(config, var_name, max_length)
|
|
||||||
except NotImplementedError:
|
|
||||||
# xlnet will raise an exception when trying to set
|
|
||||||
# max_position_embeddings.
|
|
||||||
pass
|
|
||||||
|
|
||||||
model = model_class(config)
|
model = model_class(config)
|
||||||
|
|
||||||
@@ -1856,23 +1843,18 @@ class TFModelTesterMixin:
|
|||||||
|
|
||||||
Either the model supports XLA generation and passes the inner test, or it raises an appropriate exception
|
Either the model supports XLA generation and passes the inner test, or it raises an appropriate exception
|
||||||
"""
|
"""
|
||||||
num_beams = 1
|
self._test_xla_generate(num_beams=1, num_return_sequences=1, max_new_tokens=3)
|
||||||
num_return_sequences = 1
|
|
||||||
max_length = 10
|
|
||||||
self._test_xla_generate(num_beams, num_return_sequences, max_length)
|
|
||||||
|
|
||||||
|
@slow
|
||||||
def test_xla_generate_contrastive(self):
|
def test_xla_generate_contrastive(self):
|
||||||
"""
|
"""
|
||||||
Similar to `test_xla_generate_fast`, but for contrastive search -- contrastive search directly manipulates the
|
Slow and challenging version of `test_xla_generate_fast` for contrastive search -- contrastive search directly
|
||||||
model cache and other outputs, and this test ensures that they are in a valid format that is also supported
|
manipulates the model cache and other outputs, and this test ensures that they are in a valid format that is
|
||||||
by XLA.
|
also supported by XLA.
|
||||||
|
|
||||||
Either the model supports XLA generation and passes the inner test, or it raises an appropriate exception
|
Either the model supports XLA generation and passes the inner test, or it raises an appropriate exception
|
||||||
"""
|
"""
|
||||||
num_beams = 1
|
self._test_xla_generate(num_beams=1, num_return_sequences=1, max_new_tokens=64, penalty_alpha=0.5, top_k=4)
|
||||||
num_return_sequences = 1
|
|
||||||
max_length = 10
|
|
||||||
self._test_xla_generate(num_beams, num_return_sequences, max_length, penalty_alpha=0.5, top_k=5)
|
|
||||||
|
|
||||||
@slow
|
@slow
|
||||||
def test_xla_generate_slow(self):
|
def test_xla_generate_slow(self):
|
||||||
@@ -1883,10 +1865,7 @@ class TFModelTesterMixin:
|
|||||||
|
|
||||||
Either the model supports XLA generation and passes the inner test, or it raises an appropriate exception
|
Either the model supports XLA generation and passes the inner test, or it raises an appropriate exception
|
||||||
"""
|
"""
|
||||||
num_beams = 8
|
self._test_xla_generate(num_beams=8, num_return_sequences=2, max_new_tokens=128)
|
||||||
num_return_sequences = 2
|
|
||||||
max_length = 128
|
|
||||||
self._test_xla_generate(num_beams, num_return_sequences, max_length)
|
|
||||||
|
|
||||||
def _generate_random_bad_tokens(self, num_bad_tokens, model):
|
def _generate_random_bad_tokens(self, num_bad_tokens, model):
|
||||||
# special tokens cannot be bad tokens
|
# special tokens cannot be bad tokens
|
||||||
|
|||||||
Reference in New Issue
Block a user