Generate: TF uses GenerationConfig as the basis for .generate() parametrization (#20994)

This commit is contained in:
Joao Gante
2023-01-04 18:23:20 +00:00
committed by GitHub
parent 3b309818e7
commit a6c850e4f4
4 changed files with 440 additions and 574 deletions

View File

@@ -1824,18 +1824,18 @@ class TFModelTesterMixin:
model.train_on_batch(test_batch, test_batch_labels)
def _test_xla_generate(self, **generate_kwargs):
def _generate_and_check_results(model, config, inputs_dict):
def _generate_and_check_results(model, inputs_dict):
if "input_ids" in inputs_dict:
inputs = inputs_dict["input_ids"]
# make sure there are no pad tokens in prompt, which may trigger unwanted behavior
if config.pad_token_id is not None:
if model.generation_config.pad_token_id is not None:
if config.pad_token_id == 0:
new_pad_token = config.pad_token_id + 1
new_pad_token = model.generation_config.pad_token_id + 1
else:
new_pad_token = config.pad_token_id - 1
new_pad_token = model.generation_config.pad_token_id - 1
else:
new_pad_token = None
inputs = tf.where(inputs != config.pad_token_id, inputs, new_pad_token)
inputs = tf.where(inputs != model.generation_config.pad_token_id, inputs, new_pad_token)
elif "input_features" in inputs_dict:
inputs = inputs_dict["input_features"]
else:
@@ -1854,10 +1854,10 @@ class TFModelTesterMixin:
model = model_class(config)
if model.supports_xla_generation:
_generate_and_check_results(model, config, inputs_dict)
_generate_and_check_results(model, inputs_dict)
else:
with self.assertRaises(ValueError):
_generate_and_check_results(model, config, inputs_dict)
_generate_and_check_results(model, inputs_dict)
def test_xla_generate_fast(self):
"""