Generate: TF can now accept custom logits processors (#21454)

This commit is contained in:
Joao Gante
2023-02-06 15:44:47 +00:00
committed by GitHub
parent e215e6ded2
commit 4943331015
5 changed files with 81 additions and 19 deletions

View File

@@ -25,7 +25,13 @@ from .test_framework_agnostic import GenerationIntegrationTestsMixin
if is_tf_available():
import tensorflow as tf
from transformers import TFAutoModelForCausalLM, TFAutoModelForSeq2SeqLM, tf_top_k_top_p_filtering
from transformers import (
TFAutoModelForCausalLM,
TFAutoModelForSeq2SeqLM,
TFLogitsProcessorList,
TFMinLengthLogitsProcessor,
tf_top_k_top_p_filtering,
)
@require_tf
@@ -132,6 +138,8 @@ class TFGenerationIntegrationTests(unittest.TestCase, GenerationIntegrationTests
if is_tf_available():
framework_dependent_parameters = {
"AutoModelForSeq2SeqLM": TFAutoModelForSeq2SeqLM,
"LogitsProcessorList": TFLogitsProcessorList,
"MinLengthLogitsProcessor": TFMinLengthLogitsProcessor,
"create_tensor_fn": tf.convert_to_tensor,
"return_tensors": "tf",
}