Generate: TF contrastive search with XLA support (#20050)

* Add contrastive search
This commit is contained in:
Joao Gante
2022-11-07 10:54:29 +00:00
committed by GitHub
parent 504db92e7d
commit a0f8674303
5 changed files with 770 additions and 46 deletions

View File

@@ -1783,7 +1783,7 @@ class TFModelTesterMixin:
model.compile(optimizer="sgd", run_eagerly=True)
model.train_on_batch(test_batch, test_batch_labels)
def _test_xla_generate(self, num_beams, num_return_sequences, max_length):
def _test_xla_generate(self, num_beams, num_return_sequences, max_length, **generate_kwargs):
def _generate_and_check_results(model, config, inputs_dict):
if "input_ids" in inputs_dict:
inputs = inputs_dict["input_ids"]
@@ -1801,9 +1801,9 @@ class TFModelTesterMixin:
else:
raise ValueError("No valid generate input found in inputs_dict")
generated = model.generate(inputs).numpy()
generated = model.generate(inputs, **generate_kwargs).numpy()
generate_xla = tf.function(model.generate, jit_compile=True)
generated_xla = generate_xla(inputs).numpy()
generated_xla = generate_xla(inputs, **generate_kwargs).numpy()
self.assertListEqual(generated.tolist(), generated_xla.tolist())
for model_class in self.all_generative_model_classes:
@@ -1844,6 +1844,19 @@ class TFModelTesterMixin:
max_length = 10
self._test_xla_generate(num_beams, num_return_sequences, max_length)
def test_xla_generate_contrastive(self):
"""
Similar to `test_xla_generate_fast`, but for contrastive search -- contrastive search directly manipulates the
model cache and other outputs, and this test ensures that they are in a valid format that is also supported
by XLA.
Either the model supports XLA generation and passes the inner test, or it raises an appropriate exception
"""
num_beams = 1
num_return_sequences = 1
max_length = 10
self._test_xla_generate(num_beams, num_return_sequences, max_length, penalty_alpha=0.5, top_k=5)
@slow
def test_xla_generate_slow(self):
"""