Generate: TF contrastive search with XLA support (#20050)
* Add contrastive search
This commit is contained in:
@@ -1783,7 +1783,7 @@ class TFModelTesterMixin:
|
||||
model.compile(optimizer="sgd", run_eagerly=True)
|
||||
model.train_on_batch(test_batch, test_batch_labels)
|
||||
|
||||
def _test_xla_generate(self, num_beams, num_return_sequences, max_length):
|
||||
def _test_xla_generate(self, num_beams, num_return_sequences, max_length, **generate_kwargs):
|
||||
def _generate_and_check_results(model, config, inputs_dict):
|
||||
if "input_ids" in inputs_dict:
|
||||
inputs = inputs_dict["input_ids"]
|
||||
@@ -1801,9 +1801,9 @@ class TFModelTesterMixin:
|
||||
else:
|
||||
raise ValueError("No valid generate input found in inputs_dict")
|
||||
|
||||
generated = model.generate(inputs).numpy()
|
||||
generated = model.generate(inputs, **generate_kwargs).numpy()
|
||||
generate_xla = tf.function(model.generate, jit_compile=True)
|
||||
generated_xla = generate_xla(inputs).numpy()
|
||||
generated_xla = generate_xla(inputs, **generate_kwargs).numpy()
|
||||
self.assertListEqual(generated.tolist(), generated_xla.tolist())
|
||||
|
||||
for model_class in self.all_generative_model_classes:
|
||||
@@ -1844,6 +1844,19 @@ class TFModelTesterMixin:
|
||||
max_length = 10
|
||||
self._test_xla_generate(num_beams, num_return_sequences, max_length)
|
||||
|
||||
def test_xla_generate_contrastive(self):
|
||||
"""
|
||||
Similar to `test_xla_generate_fast`, but for contrastive search -- contrastive search directly manipulates the
|
||||
model cache and other outputs, and this test ensures that they are in a valid format that is also supported
|
||||
by XLA.
|
||||
|
||||
Either the model supports XLA generation and passes the inner test, or it raises an appropriate exception
|
||||
"""
|
||||
num_beams = 1
|
||||
num_return_sequences = 1
|
||||
max_length = 10
|
||||
self._test_xla_generate(num_beams, num_return_sequences, max_length, penalty_alpha=0.5, top_k=5)
|
||||
|
||||
@slow
|
||||
def test_xla_generate_slow(self):
|
||||
"""
|
||||
|
||||
Reference in New Issue
Block a user