Generate: TF can now accept custom logits processors (#21454)
This commit is contained in:
@@ -25,7 +25,13 @@ from .test_framework_agnostic import GenerationIntegrationTestsMixin
|
||||
if is_tf_available():
|
||||
import tensorflow as tf
|
||||
|
||||
from transformers import TFAutoModelForCausalLM, TFAutoModelForSeq2SeqLM, tf_top_k_top_p_filtering
|
||||
from transformers import (
|
||||
TFAutoModelForCausalLM,
|
||||
TFAutoModelForSeq2SeqLM,
|
||||
TFLogitsProcessorList,
|
||||
TFMinLengthLogitsProcessor,
|
||||
tf_top_k_top_p_filtering,
|
||||
)
|
||||
|
||||
|
||||
@require_tf
|
||||
@@ -132,6 +138,8 @@ class TFGenerationIntegrationTests(unittest.TestCase, GenerationIntegrationTests
|
||||
if is_tf_available():
|
||||
framework_dependent_parameters = {
|
||||
"AutoModelForSeq2SeqLM": TFAutoModelForSeq2SeqLM,
|
||||
"LogitsProcessorList": TFLogitsProcessorList,
|
||||
"MinLengthLogitsProcessor": TFMinLengthLogitsProcessor,
|
||||
"create_tensor_fn": tf.convert_to_tensor,
|
||||
"return_tensors": "tf",
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user