Generate: TF can now accept custom logits processors (#21454)

This commit is contained in:
Joao Gante
2023-02-06 15:44:47 +00:00
committed by GitHub
parent e215e6ded2
commit 4943331015
5 changed files with 81 additions and 19 deletions

View File

@@ -12,6 +12,8 @@ class GenerationIntegrationTestsMixin:
# To be populated by the child classes
framework_dependent_parameters = {
"AutoModelForSeq2SeqLM": None,
"LogitsProcessorList": None,
"MinLengthLogitsProcessor": None,
"create_tensor_fn": None,
"return_tensors": None,
}
@@ -39,3 +41,23 @@ class GenerationIntegrationTestsMixin:
# however, valid model_kwargs are accepted
valid_model_kwargs = {"attention_mask": create_tensor_fn(np.zeros_like(input_ids))}
model.generate(input_ids, **valid_model_kwargs)
def test_custom_logits_processor(self):
model_cls = self.framework_dependent_parameters["AutoModelForSeq2SeqLM"]
logits_processor_list_cls = self.framework_dependent_parameters["LogitsProcessorList"]
min_length_logits_processor_cls = self.framework_dependent_parameters["MinLengthLogitsProcessor"]
return_tensors = self.framework_dependent_parameters["return_tensors"]
bart_tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-bart")
article = """Justin Timberlake and Jessica Biel, welcome to parenthood."""
bart_model = model_cls.from_pretrained("hf-internal-testing/tiny-random-bart", min_length=1)
input_ids = bart_tokenizer(article, return_tensors=return_tensors).input_ids
logits_processor = logits_processor_list_cls()
logits_processor.append(min_length_logits_processor_cls(min_length=10, eos_token_id=0))
# it should not be allowed to both define `min_length` via config and `logits_processor` list
with self.assertRaises(ValueError):
bart_model.generate(input_ids, logits_processor=logits_processor)
bart_model.config.min_length = None
bart_model.generate(input_ids, logits_processor=logits_processor)