Generate: TF can now accept custom logits processors (#21454)
This commit is contained in:
@@ -532,6 +532,7 @@ class TFGenerationMixin:
|
|||||||
self,
|
self,
|
||||||
input_ids: Optional[tf.Tensor] = None,
|
input_ids: Optional[tf.Tensor] = None,
|
||||||
generation_config: Optional[GenerationConfig] = None,
|
generation_config: Optional[GenerationConfig] = None,
|
||||||
|
logits_processor: Optional[TFLogitsProcessorList] = None,
|
||||||
seed=None,
|
seed=None,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
) -> Union[TFGenerateOutput, tf.Tensor]:
|
) -> Union[TFGenerateOutput, tf.Tensor]:
|
||||||
@@ -560,6 +561,10 @@ class TFGenerationMixin:
|
|||||||
priority: 1) from the `generation_config.json` model file, if it exists; 2) from the model
|
priority: 1) from the `generation_config.json` model file, if it exists; 2) from the model
|
||||||
configuration. Please note that unspecified parameters will inherit [`~generation.GenerationConfig`]'s
|
configuration. Please note that unspecified parameters will inherit [`~generation.GenerationConfig`]'s
|
||||||
default values, whose documentation should be checked to parameterize generation.
|
default values, whose documentation should be checked to parameterize generation.
|
||||||
|
logits_processor (`TFLogitsProcessorList`, *optional*):
|
||||||
|
Custom logits processors that complement the default logits processors built from arguments and
|
||||||
|
generation config. If a logit processor is passed that is already created with the arguments or a
|
||||||
|
generation config an error is thrown. This feature is intended for advanced users.
|
||||||
seed (`List[int]`, *optional*):
|
seed (`List[int]`, *optional*):
|
||||||
Random seed to control sampling, containing two integers, used when `do_sample` is `True`. See the
|
Random seed to control sampling, containing two integers, used when `do_sample` is `True`. See the
|
||||||
`seed` argument from stateless functions in `tf.random`.
|
`seed` argument from stateless functions in `tf.random`.
|
||||||
@@ -638,6 +643,8 @@ class TFGenerationMixin:
|
|||||||
model_kwargs["decoder_input_ids"] = tf.cast(model_kwargs["decoder_input_ids"], tf.int32)
|
model_kwargs["decoder_input_ids"] = tf.cast(model_kwargs["decoder_input_ids"], tf.int32)
|
||||||
|
|
||||||
# 3. Set generation parameters if not already defined
|
# 3. Set generation parameters if not already defined
|
||||||
|
logits_processor = logits_processor if logits_processor is not None else TFLogitsProcessorList()
|
||||||
|
|
||||||
if generation_config.pad_token_id is None and generation_config.eos_token_id is not None:
|
if generation_config.pad_token_id is None and generation_config.eos_token_id is not None:
|
||||||
if model_kwargs.get("attention_mask") is None:
|
if model_kwargs.get("attention_mask") is None:
|
||||||
logger.warning(
|
logger.warning(
|
||||||
@@ -755,6 +762,7 @@ class TFGenerationMixin:
|
|||||||
logits_processor = self._get_logits_processor(
|
logits_processor = self._get_logits_processor(
|
||||||
generation_config=generation_config,
|
generation_config=generation_config,
|
||||||
input_ids_seq_length=input_ids_seq_length,
|
input_ids_seq_length=input_ids_seq_length,
|
||||||
|
logits_processor=logits_processor,
|
||||||
)
|
)
|
||||||
|
|
||||||
# 10. go into different generation modes
|
# 10. go into different generation modes
|
||||||
@@ -1194,6 +1202,7 @@ class TFGenerationMixin:
|
|||||||
self,
|
self,
|
||||||
generation_config: GenerationConfig,
|
generation_config: GenerationConfig,
|
||||||
input_ids_seq_length: int,
|
input_ids_seq_length: int,
|
||||||
|
logits_processor: Optional[TFLogitsProcessorList],
|
||||||
) -> TFLogitsProcessorList:
|
) -> TFLogitsProcessorList:
|
||||||
"""
|
"""
|
||||||
This method returns a [`TFLogitsProcessorList`] list object that contains all relevant [`TFLogitsProcessor`]
|
This method returns a [`TFLogitsProcessorList`] list object that contains all relevant [`TFLogitsProcessor`]
|
||||||
@@ -1240,8 +1249,31 @@ class TFGenerationMixin:
|
|||||||
)
|
)
|
||||||
if generation_config.forced_decoder_ids is not None:
|
if generation_config.forced_decoder_ids is not None:
|
||||||
processors.append(TFForceTokensLogitsProcessor(generation_config.forced_decoder_ids))
|
processors.append(TFForceTokensLogitsProcessor(generation_config.forced_decoder_ids))
|
||||||
|
|
||||||
|
processors = self._merge_criteria_processor_list(processors, logits_processor)
|
||||||
return processors
|
return processors
|
||||||
|
|
||||||
|
def _merge_criteria_processor_list(
|
||||||
|
self,
|
||||||
|
default_list: TFLogitsProcessorList,
|
||||||
|
custom_list: TFLogitsProcessorList,
|
||||||
|
) -> TFLogitsProcessorList:
|
||||||
|
if len(custom_list) == 0:
|
||||||
|
return default_list
|
||||||
|
for default in default_list:
|
||||||
|
for custom in custom_list:
|
||||||
|
if type(custom) is type(default):
|
||||||
|
object_type = "logits processor"
|
||||||
|
raise ValueError(
|
||||||
|
f"A custom {object_type} of type {type(custom)} with values {custom} has been passed to"
|
||||||
|
f" `generate`, but it has already been created with the values {default}. {default} has been"
|
||||||
|
" created by passing the corresponding arguments to generate or by the model's config default"
|
||||||
|
f" values. If you just want to change the default values of {object_type} consider passing"
|
||||||
|
f" them as arguments to `generate` instead of using a custom {object_type}."
|
||||||
|
)
|
||||||
|
default_list.extend(custom_list)
|
||||||
|
return default_list
|
||||||
|
|
||||||
def greedy_search(
|
def greedy_search(
|
||||||
self,
|
self,
|
||||||
input_ids: tf.Tensor,
|
input_ids: tf.Tensor,
|
||||||
|
|||||||
@@ -23,6 +23,7 @@ import numpy as np
|
|||||||
import tensorflow as tf
|
import tensorflow as tf
|
||||||
|
|
||||||
from ...configuration_utils import PretrainedConfig
|
from ...configuration_utils import PretrainedConfig
|
||||||
|
from ...generation import TFLogitsProcessorList
|
||||||
from ...modeling_tf_utils import (
|
from ...modeling_tf_utils import (
|
||||||
TFCausalLanguageModelingLoss,
|
TFCausalLanguageModelingLoss,
|
||||||
TFModelInputType,
|
TFModelInputType,
|
||||||
@@ -1002,6 +1003,7 @@ class TFRagTokenForGeneration(TFRagPreTrainedModel, TFCausalLanguageModelingLoss
|
|||||||
doc_scores=None,
|
doc_scores=None,
|
||||||
n_docs=None,
|
n_docs=None,
|
||||||
generation_config=None,
|
generation_config=None,
|
||||||
|
logits_processor=TFLogitsProcessorList(),
|
||||||
**kwargs
|
**kwargs
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
@@ -1045,6 +1047,10 @@ class TFRagTokenForGeneration(TFRagPreTrainedModel, TFCausalLanguageModelingLoss
|
|||||||
priority: 1) from the `generation_config.json` model file, if it exists; 2) from the model
|
priority: 1) from the `generation_config.json` model file, if it exists; 2) from the model
|
||||||
configuration. Please note that unspecified parameters will inherit [`~generation.GenerationConfig`]'s
|
configuration. Please note that unspecified parameters will inherit [`~generation.GenerationConfig`]'s
|
||||||
default values, whose documentation should be checked to parameterize generation.
|
default values, whose documentation should be checked to parameterize generation.
|
||||||
|
logits_processor (`TFLogitsProcessorList`, *optional*):
|
||||||
|
Custom logits processors that complement the default logits processors built from arguments and a
|
||||||
|
model's config. If a logit processor is passed that is already created with the arguments or a model's
|
||||||
|
config an error is thrown.
|
||||||
kwargs:
|
kwargs:
|
||||||
Ad hoc parametrization of `generation_config` and/or additional model-specific kwargs that will be
|
Ad hoc parametrization of `generation_config` and/or additional model-specific kwargs that will be
|
||||||
forwarded to the `forward` function of the model.
|
forwarded to the `forward` function of the model.
|
||||||
@@ -1149,6 +1155,7 @@ class TFRagTokenForGeneration(TFRagPreTrainedModel, TFCausalLanguageModelingLoss
|
|||||||
pre_processor = self._get_logits_processor(
|
pre_processor = self._get_logits_processor(
|
||||||
generation_config=generation_config,
|
generation_config=generation_config,
|
||||||
input_ids_seq_length=tf.shape(decoder_input_ids)[-1],
|
input_ids_seq_length=tf.shape(decoder_input_ids)[-1],
|
||||||
|
logits_processor=logits_processor,
|
||||||
)
|
)
|
||||||
|
|
||||||
if generation_config.num_beams == 1:
|
if generation_config.num_beams == 1:
|
||||||
|
|||||||
@@ -12,6 +12,8 @@ class GenerationIntegrationTestsMixin:
|
|||||||
# To be populated by the child classes
|
# To be populated by the child classes
|
||||||
framework_dependent_parameters = {
|
framework_dependent_parameters = {
|
||||||
"AutoModelForSeq2SeqLM": None,
|
"AutoModelForSeq2SeqLM": None,
|
||||||
|
"LogitsProcessorList": None,
|
||||||
|
"MinLengthLogitsProcessor": None,
|
||||||
"create_tensor_fn": None,
|
"create_tensor_fn": None,
|
||||||
"return_tensors": None,
|
"return_tensors": None,
|
||||||
}
|
}
|
||||||
@@ -39,3 +41,23 @@ class GenerationIntegrationTestsMixin:
|
|||||||
# however, valid model_kwargs are accepted
|
# however, valid model_kwargs are accepted
|
||||||
valid_model_kwargs = {"attention_mask": create_tensor_fn(np.zeros_like(input_ids))}
|
valid_model_kwargs = {"attention_mask": create_tensor_fn(np.zeros_like(input_ids))}
|
||||||
model.generate(input_ids, **valid_model_kwargs)
|
model.generate(input_ids, **valid_model_kwargs)
|
||||||
|
|
||||||
|
def test_custom_logits_processor(self):
    """
    Check that `generate` rejects a custom logits processor that duplicates one configured via the model
    config, and accepts it once the config value is cleared.
    """
    params = self.framework_dependent_parameters
    seq2seq_cls = params["AutoModelForSeq2SeqLM"]
    processor_list_cls = params["LogitsProcessorList"]
    min_length_cls = params["MinLengthLogitsProcessor"]
    tensor_format = params["return_tensors"]

    tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-bart")
    text = """Justin Timberlake and Jessica Biel, welcome to parenthood."""
    # `min_length=1` is set at load time so that a min-length value is configured on the model
    model = seq2seq_cls.from_pretrained("hf-internal-testing/tiny-random-bart", min_length=1)
    encoded = tokenizer(text, return_tensors=tensor_format)
    input_ids = encoded.input_ids

    custom_processors = processor_list_cls()
    custom_processors.append(min_length_cls(min_length=10, eos_token_id=0))

    # it should not be allowed to both define `min_length` via config and `logits_processor` list
    with self.assertRaises(ValueError):
        model.generate(input_ids, logits_processor=custom_processors)

    # with the config value cleared, the same custom processor list must be accepted
    model.config.min_length = None
    model.generate(input_ids, logits_processor=custom_processors)
|
||||||
|
|||||||
@@ -25,7 +25,13 @@ from .test_framework_agnostic import GenerationIntegrationTestsMixin
|
|||||||
if is_tf_available():
|
if is_tf_available():
|
||||||
import tensorflow as tf
|
import tensorflow as tf
|
||||||
|
|
||||||
from transformers import TFAutoModelForCausalLM, TFAutoModelForSeq2SeqLM, tf_top_k_top_p_filtering
|
from transformers import (
|
||||||
|
TFAutoModelForCausalLM,
|
||||||
|
TFAutoModelForSeq2SeqLM,
|
||||||
|
TFLogitsProcessorList,
|
||||||
|
TFMinLengthLogitsProcessor,
|
||||||
|
tf_top_k_top_p_filtering,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@require_tf
|
@require_tf
|
||||||
@@ -132,6 +138,8 @@ class TFGenerationIntegrationTests(unittest.TestCase, GenerationIntegrationTests
|
|||||||
if is_tf_available():
|
if is_tf_available():
|
||||||
framework_dependent_parameters = {
|
framework_dependent_parameters = {
|
||||||
"AutoModelForSeq2SeqLM": TFAutoModelForSeq2SeqLM,
|
"AutoModelForSeq2SeqLM": TFAutoModelForSeq2SeqLM,
|
||||||
|
"LogitsProcessorList": TFLogitsProcessorList,
|
||||||
|
"MinLengthLogitsProcessor": TFMinLengthLogitsProcessor,
|
||||||
"create_tensor_fn": tf.convert_to_tensor,
|
"create_tensor_fn": tf.convert_to_tensor,
|
||||||
"return_tensors": "tf",
|
"return_tensors": "tf",
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1797,12 +1797,15 @@ class GenerationIntegrationTests(unittest.TestCase, GenerationIntegrationTestsMi
|
|||||||
if is_torch_available():
|
if is_torch_available():
|
||||||
framework_dependent_parameters = {
|
framework_dependent_parameters = {
|
||||||
"AutoModelForSeq2SeqLM": AutoModelForSeq2SeqLM,
|
"AutoModelForSeq2SeqLM": AutoModelForSeq2SeqLM,
|
||||||
|
"LogitsProcessorList": LogitsProcessorList,
|
||||||
|
"MinLengthLogitsProcessor": MinLengthLogitsProcessor,
|
||||||
"create_tensor_fn": torch.tensor,
|
"create_tensor_fn": torch.tensor,
|
||||||
"return_tensors": "pt",
|
"return_tensors": "pt",
|
||||||
}
|
}
|
||||||
|
|
||||||
@slow
|
@slow
|
||||||
def test_diverse_beam_search(self):
|
def test_diverse_beam_search(self):
|
||||||
|
# PT-only test: TF doesn't have a diverse beam search implementation
|
||||||
article = """Justin Timberlake and Jessica Biel, welcome to parenthood.
|
article = """Justin Timberlake and Jessica Biel, welcome to parenthood.
|
||||||
The celebrity couple announced the arrival of their son, Silas Randall Timberlake, in statements to People.
|
The celebrity couple announced the arrival of their son, Silas Randall Timberlake, in statements to People.
|
||||||
"Silas was the middle name of Timberlake's maternal grandfather Bill Bomar, who died in 2012, while Randall is the musician's own middle name, as well as his father's first," People reports.
|
"Silas was the middle name of Timberlake's maternal grandfather Bill Bomar, who died in 2012, while Randall is the musician's own middle name, as well as his father's first," People reports.
|
||||||
@@ -1836,6 +1839,7 @@ class GenerationIntegrationTests(unittest.TestCase, GenerationIntegrationTestsMi
|
|||||||
)
|
)
|
||||||
|
|
||||||
def test_max_length_backward_compat_greedy(self):
|
def test_max_length_backward_compat_greedy(self):
|
||||||
|
# PT-only test: TF doesn't have StoppingCriteria
|
||||||
article = """Justin Timberlake and Jessica Biel, welcome to parenthood."""
|
article = """Justin Timberlake and Jessica Biel, welcome to parenthood."""
|
||||||
bart_tokenizer = BartTokenizer.from_pretrained("hf-internal-testing/tiny-random-bart")
|
bart_tokenizer = BartTokenizer.from_pretrained("hf-internal-testing/tiny-random-bart")
|
||||||
bart_model = BartForConditionalGeneration.from_pretrained("hf-internal-testing/tiny-random-bart").to(
|
bart_model = BartForConditionalGeneration.from_pretrained("hf-internal-testing/tiny-random-bart").to(
|
||||||
@@ -1862,6 +1866,7 @@ class GenerationIntegrationTests(unittest.TestCase, GenerationIntegrationTestsMi
|
|||||||
)
|
)
|
||||||
|
|
||||||
def test_max_length_backward_compat_sample(self):
|
def test_max_length_backward_compat_sample(self):
|
||||||
|
# PT-only test: TF doesn't have StoppingCriteria
|
||||||
article = """Justin Timberlake and Jessica Biel, welcome to parenthood."""
|
article = """Justin Timberlake and Jessica Biel, welcome to parenthood."""
|
||||||
bart_tokenizer = BartTokenizer.from_pretrained("hf-internal-testing/tiny-random-bart")
|
bart_tokenizer = BartTokenizer.from_pretrained("hf-internal-testing/tiny-random-bart")
|
||||||
bart_model = BartForConditionalGeneration.from_pretrained("hf-internal-testing/tiny-random-bart").to(
|
bart_model = BartForConditionalGeneration.from_pretrained("hf-internal-testing/tiny-random-bart").to(
|
||||||
@@ -1888,6 +1893,7 @@ class GenerationIntegrationTests(unittest.TestCase, GenerationIntegrationTestsMi
|
|||||||
)
|
)
|
||||||
|
|
||||||
def test_max_length_backward_compat_beam_search(self):
|
def test_max_length_backward_compat_beam_search(self):
|
||||||
|
# PT-only test: TF doesn't have StoppingCriteria
|
||||||
article = """Justin Timberlake and Jessica Biel, welcome to parenthood."""
|
article = """Justin Timberlake and Jessica Biel, welcome to parenthood."""
|
||||||
bart_tokenizer = BartTokenizer.from_pretrained("hf-internal-testing/tiny-random-bart")
|
bart_tokenizer = BartTokenizer.from_pretrained("hf-internal-testing/tiny-random-bart")
|
||||||
bart_model = BartForConditionalGeneration.from_pretrained("hf-internal-testing/tiny-random-bart").to(
|
bart_model = BartForConditionalGeneration.from_pretrained("hf-internal-testing/tiny-random-bart").to(
|
||||||
@@ -1918,6 +1924,7 @@ class GenerationIntegrationTests(unittest.TestCase, GenerationIntegrationTestsMi
|
|||||||
)
|
)
|
||||||
|
|
||||||
def test_max_length_backward_compat_group_beam_search(self):
|
def test_max_length_backward_compat_group_beam_search(self):
|
||||||
|
# PT-only test: TF doesn't have StoppingCriteria & group beam search
|
||||||
article = """Justin Timberlake and Jessica Biel, welcome to parenthood."""
|
article = """Justin Timberlake and Jessica Biel, welcome to parenthood."""
|
||||||
bart_tokenizer = BartTokenizer.from_pretrained("hf-internal-testing/tiny-random-bart")
|
bart_tokenizer = BartTokenizer.from_pretrained("hf-internal-testing/tiny-random-bart")
|
||||||
bart_model = BartForConditionalGeneration.from_pretrained("hf-internal-testing/tiny-random-bart").to(
|
bart_model = BartForConditionalGeneration.from_pretrained("hf-internal-testing/tiny-random-bart").to(
|
||||||
@@ -1952,6 +1959,7 @@ class GenerationIntegrationTests(unittest.TestCase, GenerationIntegrationTestsMi
|
|||||||
)
|
)
|
||||||
|
|
||||||
def test_max_length_warning_if_different(self):
|
def test_max_length_warning_if_different(self):
|
||||||
|
# PT-only test: TF doesn't have StoppingCriteria
|
||||||
article = """Justin Timberlake and Jessica Biel, welcome to parenthood."""
|
article = """Justin Timberlake and Jessica Biel, welcome to parenthood."""
|
||||||
bart_tokenizer = BartTokenizer.from_pretrained("hf-internal-testing/tiny-random-bart")
|
bart_tokenizer = BartTokenizer.from_pretrained("hf-internal-testing/tiny-random-bart")
|
||||||
bart_model = BartForConditionalGeneration.from_pretrained("hf-internal-testing/tiny-random-bart").to(
|
bart_model = BartForConditionalGeneration.from_pretrained("hf-internal-testing/tiny-random-bart").to(
|
||||||
@@ -2035,6 +2043,7 @@ class GenerationIntegrationTests(unittest.TestCase, GenerationIntegrationTestsMi
|
|||||||
)
|
)
|
||||||
|
|
||||||
def test_custom_stopping_criteria_overload_error(self):
|
def test_custom_stopping_criteria_overload_error(self):
|
||||||
|
# PT-only test: TF doesn't have StoppingCriteria
|
||||||
article = """Justin Timberlake and Jessica Biel, welcome to parenthood."""
|
article = """Justin Timberlake and Jessica Biel, welcome to parenthood."""
|
||||||
bart_tokenizer = BartTokenizer.from_pretrained("sshleifer/bart-tiny-random")
|
bart_tokenizer = BartTokenizer.from_pretrained("sshleifer/bart-tiny-random")
|
||||||
bart_model = BartForConditionalGeneration.from_pretrained("sshleifer/bart-tiny-random").to(torch_device)
|
bart_model = BartForConditionalGeneration.from_pretrained("sshleifer/bart-tiny-random").to(torch_device)
|
||||||
@@ -2048,6 +2057,7 @@ class GenerationIntegrationTests(unittest.TestCase, GenerationIntegrationTestsMi
|
|||||||
bart_model.generate(input_ids, stopping_criteria=stopping_criteria, max_length=32)
|
bart_model.generate(input_ids, stopping_criteria=stopping_criteria, max_length=32)
|
||||||
|
|
||||||
def test_custom_stopping_criteria(self):
|
def test_custom_stopping_criteria(self):
|
||||||
|
# PT-only test: TF doesn't have StoppingCriteria
|
||||||
article = """Justin Timberlake and Jessica Biel, welcome to parenthood."""
|
article = """Justin Timberlake and Jessica Biel, welcome to parenthood."""
|
||||||
bart_tokenizer = BartTokenizer.from_pretrained("sshleifer/bart-tiny-random")
|
bart_tokenizer = BartTokenizer.from_pretrained("sshleifer/bart-tiny-random")
|
||||||
bart_model = BartForConditionalGeneration.from_pretrained("sshleifer/bart-tiny-random").to(torch_device)
|
bart_model = BartForConditionalGeneration.from_pretrained("sshleifer/bart-tiny-random").to(torch_device)
|
||||||
@@ -2070,7 +2080,7 @@ class GenerationIntegrationTests(unittest.TestCase, GenerationIntegrationTestsMi
|
|||||||
)
|
)
|
||||||
|
|
||||||
def test_stop_sequence_stopping_criteria(self):
|
def test_stop_sequence_stopping_criteria(self):
|
||||||
|
# PT-only test: TF doesn't have StoppingCriteria
|
||||||
prompt = """Hello I believe in"""
|
prompt = """Hello I believe in"""
|
||||||
generator = pipeline("text-generation", model="hf-internal-testing/tiny-random-bart")
|
generator = pipeline("text-generation", model="hf-internal-testing/tiny-random-bart")
|
||||||
output = generator(prompt)
|
output = generator(prompt)
|
||||||
@@ -2088,23 +2098,6 @@ class GenerationIntegrationTests(unittest.TestCase, GenerationIntegrationTestsMi
|
|||||||
output = generator(prompt, stop_sequence=" number")
|
output = generator(prompt, stop_sequence=" number")
|
||||||
self.assertEqual(output, [{"generated_text": "Hello I believe in in in number"}])
|
self.assertEqual(output, [{"generated_text": "Hello I believe in in in number"}])
|
||||||
|
|
||||||
def test_custom_logits_processor(self):
    """
    Check that `generate` rejects a custom `MinLengthLogitsProcessor` while `min_length` is also set on the
    model config, and accepts it once the config value is cleared.
    """
    tokenizer = BartTokenizer.from_pretrained("sshleifer/bart-tiny-random")
    text = """Justin Timberlake and Jessica Biel, welcome to parenthood."""
    # `min_length=1` is set at load time so that a min-length value is configured on the model
    model = BartForConditionalGeneration.from_pretrained("sshleifer/bart-tiny-random", min_length=1)
    model = model.to(torch_device)
    encoded = tokenizer(text, return_tensors="pt")
    input_ids = encoded.input_ids.to(torch_device)

    custom_processors = LogitsProcessorList()
    custom_processors.append(MinLengthLogitsProcessor(min_length=10, eos_token_id=0))

    # it should not be allowed to both define `min_length` via config and `logits_processor` list
    with self.assertRaises(ValueError):
        model.generate(input_ids, logits_processor=custom_processors)

    # with the config value cleared, the same custom processor list must be accepted
    model.config.min_length = None
    model.generate(input_ids, logits_processor=custom_processors)
|
|
||||||
|
|
||||||
def test_max_new_tokens_encoder_decoder(self):
|
def test_max_new_tokens_encoder_decoder(self):
|
||||||
article = """Justin Timberlake and Jessica Biel, welcome to parenthood."""
|
article = """Justin Timberlake and Jessica Biel, welcome to parenthood."""
|
||||||
bart_tokenizer = BartTokenizer.from_pretrained("hf-internal-testing/tiny-random-bart")
|
bart_tokenizer = BartTokenizer.from_pretrained("hf-internal-testing/tiny-random-bart")
|
||||||
|
|||||||
Reference in New Issue
Block a user