Generate: TF uses GenerationConfig as the basis for .generate() parametrization (#20994)
This commit is contained in:
@@ -1824,18 +1824,18 @@ class TFModelTesterMixin:
|
||||
model.train_on_batch(test_batch, test_batch_labels)
|
||||
|
||||
def _test_xla_generate(self, **generate_kwargs):
|
||||
def _generate_and_check_results(model, config, inputs_dict):
|
||||
def _generate_and_check_results(model, inputs_dict):
|
||||
if "input_ids" in inputs_dict:
|
||||
inputs = inputs_dict["input_ids"]
|
||||
# make sure there are no pad tokens in prompt, which may trigger unwanted behavior
|
||||
if config.pad_token_id is not None:
|
||||
if model.generation_config.pad_token_id is not None:
|
||||
if config.pad_token_id == 0:
|
||||
new_pad_token = config.pad_token_id + 1
|
||||
new_pad_token = model.generation_config.pad_token_id + 1
|
||||
else:
|
||||
new_pad_token = config.pad_token_id - 1
|
||||
new_pad_token = model.generation_config.pad_token_id - 1
|
||||
else:
|
||||
new_pad_token = None
|
||||
inputs = tf.where(inputs != config.pad_token_id, inputs, new_pad_token)
|
||||
inputs = tf.where(inputs != model.generation_config.pad_token_id, inputs, new_pad_token)
|
||||
elif "input_features" in inputs_dict:
|
||||
inputs = inputs_dict["input_features"]
|
||||
else:
|
||||
@@ -1854,10 +1854,10 @@ class TFModelTesterMixin:
|
||||
model = model_class(config)
|
||||
|
||||
if model.supports_xla_generation:
|
||||
_generate_and_check_results(model, config, inputs_dict)
|
||||
_generate_and_check_results(model, inputs_dict)
|
||||
else:
|
||||
with self.assertRaises(ValueError):
|
||||
_generate_and_check_results(model, config, inputs_dict)
|
||||
_generate_and_check_results(model, inputs_dict)
|
||||
|
||||
def test_xla_generate_fast(self):
|
||||
"""
|
||||
|
||||
Reference in New Issue
Block a user