remove adjust_logits_during_generation method (#10087)
* add forced logits processors
* delete adjust_logits method
* add forced_eos_token_id argument in config
* add tests for forced logits processors
* update gen utils tests
* add forced option to tf generate
* remove adjust_logits method from tf models
* update adjust_logits for marian
* delete _force_token_id_to_be_generated method
* style
* import warnings
* pass max_length to _get_logits_processor
* set forced_eos_token_id to None
* set forced attributes in conf utils
* typo
* fix rag generate
* add forced_eos_token_id in rag config
* remove force_bos_token_to_be_generated from BartConfig
* remove _force_token_ids_generation from FSMT
* nit
* fix negative constant
* apply suggestions from code review
This commit is contained in:
@@ -26,6 +26,8 @@ if is_torch_available():
|
||||
from transformers import BartForConditionalGeneration, BartTokenizer, top_k_top_p_filtering
|
||||
from transformers.generation_beam_search import BeamSearchScorer
|
||||
from transformers.generation_logits_process import (
|
||||
ForcedBOSTokenLogitsProcessor,
|
||||
ForcedEOSTokenLogitsProcessor,
|
||||
HammingDiversityLogitsProcessor,
|
||||
LogitsProcessorList,
|
||||
MinLengthLogitsProcessor,
|
||||
@@ -70,7 +72,14 @@ class GenerationTesterMixin:
|
||||
return config, input_ids, attention_mask, max_length
|
||||
|
||||
@staticmethod
|
||||
def _get_logits_processor_and_kwargs(input_length, eos_token_id, diversity_penalty=None):
|
||||
def _get_logits_processor_and_kwargs(
|
||||
input_length,
|
||||
eos_token_id,
|
||||
forced_bos_token_id=None,
|
||||
forced_eos_token_id=None,
|
||||
max_length=None,
|
||||
diversity_penalty=None,
|
||||
):
|
||||
process_kwargs = {
|
||||
"min_length": input_length + 1,
|
||||
"bad_words_ids": [[1, 0]],
|
||||
@@ -92,6 +101,18 @@ class GenerationTesterMixin:
|
||||
if eos_token_id is not None
|
||||
else []
|
||||
)
|
||||
+ (
|
||||
[
|
||||
ForcedBOSTokenLogitsProcessor(forced_bos_token_id),
|
||||
]
|
||||
if forced_bos_token_id is not None
|
||||
else []
|
||||
)
|
||||
+ (
|
||||
[ForcedEOSTokenLogitsProcessor(max_length, forced_eos_token_id)]
|
||||
if forced_eos_token_id is not None
|
||||
else []
|
||||
)
|
||||
+ [
|
||||
NoBadWordsLogitsProcessor(process_kwargs["bad_words_ids"], eos_token_id),
|
||||
NoRepeatNGramLogitsProcessor(process_kwargs["no_repeat_ngram_size"]),
|
||||
@@ -182,13 +203,17 @@ class GenerationTesterMixin:
|
||||
output_hidden_states=False,
|
||||
return_dict_in_generate=False,
|
||||
):
|
||||
if model.config.is_encoder_decoder:
|
||||
max_length = 4
|
||||
logits_process_kwargs, logits_processor = self._get_logits_processor_and_kwargs(
|
||||
input_ids.shape[-1], model.config.eos_token_id
|
||||
input_ids.shape[-1],
|
||||
eos_token_id=model.config.eos_token_id,
|
||||
forced_bos_token_id=model.config.forced_bos_token_id,
|
||||
forced_eos_token_id=model.config.forced_eos_token_id,
|
||||
max_length=max_length,
|
||||
)
|
||||
|
||||
kwargs = {}
|
||||
if model.config.is_encoder_decoder:
|
||||
max_length = 4
|
||||
|
||||
output_generate = model.generate(
|
||||
input_ids,
|
||||
@@ -544,14 +569,19 @@ class GenerationTesterMixin:
|
||||
for model_class in self.all_generative_model_classes:
|
||||
config, input_ids, attention_mask, max_length = self._get_input_ids_and_config()
|
||||
model = model_class(config).to(torch_device).eval()
|
||||
process_kwargs, logits_processor = self._get_logits_processor_and_kwargs(
|
||||
input_ids.shape[-1], model.config.eos_token_id
|
||||
)
|
||||
logits_warper_kwargs, logits_warper = self._get_warper_and_kwargs(num_beams=1)
|
||||
|
||||
if model.config.is_encoder_decoder:
|
||||
max_length = 4
|
||||
|
||||
process_kwargs, logits_processor = self._get_logits_processor_and_kwargs(
|
||||
input_ids.shape[-1],
|
||||
model.config.eos_token_id,
|
||||
forced_bos_token_id=model.config.forced_bos_token_id,
|
||||
forced_eos_token_id=model.config.forced_eos_token_id,
|
||||
max_length=max_length,
|
||||
)
|
||||
logits_warper_kwargs, logits_warper = self._get_warper_and_kwargs(num_beams=1)
|
||||
|
||||
# check `generate()` and `sample()` are equal
|
||||
output_sample, output_generate = self._sample_generate(
|
||||
model=model,
|
||||
@@ -586,14 +616,18 @@ class GenerationTesterMixin:
|
||||
config, input_ids, attention_mask, max_length = self._get_input_ids_and_config()
|
||||
config.use_cache = False
|
||||
model = model_class(config).to(torch_device).eval()
|
||||
process_kwargs, logits_processor = self._get_logits_processor_and_kwargs(
|
||||
input_ids.shape[-1], model.config.eos_token_id
|
||||
)
|
||||
logits_warper_kwargs, logits_warper = self._get_warper_and_kwargs(num_beams=1)
|
||||
|
||||
if model.config.is_encoder_decoder:
|
||||
max_length = 4
|
||||
|
||||
process_kwargs, logits_processor = self._get_logits_processor_and_kwargs(
|
||||
input_ids.shape[-1],
|
||||
model.config.eos_token_id,
|
||||
forced_bos_token_id=model.config.forced_bos_token_id,
|
||||
forced_eos_token_id=model.config.forced_eos_token_id,
|
||||
max_length=max_length,
|
||||
)
|
||||
logits_warper_kwargs, logits_warper = self._get_warper_and_kwargs(num_beams=1)
|
||||
|
||||
output_sample, output_generate = self._sample_generate(
|
||||
model=model,
|
||||
input_ids=input_ids,
|
||||
@@ -630,14 +664,19 @@ class GenerationTesterMixin:
|
||||
# shorter than `max_length` can be generated which could lead to flaky circle ci
|
||||
# failures if the top `num_return_sequences` beams are all shorter than the longest beam
|
||||
config.eos_token_id = None
|
||||
config.forced_eos_token_id = None
|
||||
|
||||
model = model_class(config).to(torch_device).eval()
|
||||
|
||||
logits_process_kwargs, logits_processor = self._get_logits_processor_and_kwargs(
|
||||
input_ids.shape[-1], config.eos_token_id
|
||||
)
|
||||
if model.config.is_encoder_decoder:
|
||||
max_length = 4
|
||||
|
||||
logits_process_kwargs, logits_processor = self._get_logits_processor_and_kwargs(
|
||||
input_ids.shape[-1],
|
||||
config.eos_token_id,
|
||||
config.forced_bos_token_id,
|
||||
config.forced_eos_token_id,
|
||||
max_length,
|
||||
)
|
||||
beam_kwargs, beam_scorer = self._get_beam_scorer_and_kwargs(input_ids.shape[0], max_length)
|
||||
|
||||
# check `generate()` and `beam_search()` are equal
|
||||
@@ -684,13 +723,19 @@ class GenerationTesterMixin:
|
||||
# shorter than `max_length` can be generated which could lead to flaky circle ci
|
||||
# failures if the top `num_return_sequences` beams are all shorter than the longest beam
|
||||
config.eos_token_id = None
|
||||
config.forced_eos_token_id = None
|
||||
|
||||
model = model_class(config).to(torch_device).eval()
|
||||
logits_process_kwargs, logits_processor = self._get_logits_processor_and_kwargs(
|
||||
input_ids.shape[-1], config.eos_token_id
|
||||
)
|
||||
if model.config.is_encoder_decoder:
|
||||
max_length = 4
|
||||
|
||||
logits_process_kwargs, logits_processor = self._get_logits_processor_and_kwargs(
|
||||
input_ids.shape[-1],
|
||||
config.eos_token_id,
|
||||
config.forced_bos_token_id,
|
||||
config.forced_eos_token_id,
|
||||
max_length,
|
||||
)
|
||||
beam_kwargs, beam_scorer = self._get_beam_scorer_and_kwargs(input_ids.shape[0], max_length)
|
||||
output_generate, output_beam_search = self._beam_search_generate(
|
||||
model=model,
|
||||
@@ -732,19 +777,24 @@ class GenerationTesterMixin:
|
||||
# shorter than `max_length` can be generated which could lead to flaky circle ci
|
||||
# failures if the top `num_return_sequences` beams are all shorter than the longest beam
|
||||
config.eos_token_id = None
|
||||
config.forced_eos_token_id = None
|
||||
|
||||
if not hasattr(config, "use_cache"):
|
||||
# only relevant if model has "use_cache"
|
||||
return
|
||||
|
||||
model = model_class(config).to(torch_device).eval()
|
||||
|
||||
logits_process_kwargs, logits_processor = self._get_logits_processor_and_kwargs(
|
||||
input_ids.shape[-1], config.eos_token_id
|
||||
)
|
||||
|
||||
if model.config.is_encoder_decoder:
|
||||
max_length = 4
|
||||
|
||||
logits_process_kwargs, logits_processor = self._get_logits_processor_and_kwargs(
|
||||
input_ids.shape[-1],
|
||||
config.eos_token_id,
|
||||
config.forced_bos_token_id,
|
||||
config.forced_eos_token_id,
|
||||
max_length,
|
||||
)
|
||||
|
||||
beam_kwargs, beam_scorer = self._get_beam_scorer_and_kwargs(input_ids.shape[0], max_length)
|
||||
|
||||
config.use_cache = True
|
||||
@@ -780,6 +830,7 @@ class GenerationTesterMixin:
|
||||
# shorter than `max_length` can be generated which could lead to flaky circle ci
|
||||
# failures if the top `num_return_sequences` beams are all shorter than the longest beam
|
||||
config.eos_token_id = None
|
||||
config.forced_eos_token_id = None
|
||||
|
||||
logits_warper_kwargs, logits_warper = self._get_warper_and_kwargs(num_beams=1)
|
||||
|
||||
@@ -819,6 +870,7 @@ class GenerationTesterMixin:
|
||||
# shorter than `max_length` can be generated which could lead to flaky circle ci
|
||||
# failures if the top `num_return_sequences` beams are all shorter than the longest beam
|
||||
config.eos_token_id = None
|
||||
config.forced_eos_token_id = None
|
||||
|
||||
model = model_class(config).to(torch_device).eval()
|
||||
logits_warper_kwargs, logits_warper = self._get_warper_and_kwargs(num_beams=1)
|
||||
@@ -892,16 +944,22 @@ class GenerationTesterMixin:
|
||||
# shorter than `max_length` can be generated which could lead to flaky circle ci
|
||||
# failures if the top `num_return_sequences` beams are all shorter than the longest beam
|
||||
config.eos_token_id = None
|
||||
|
||||
logits_process_kwargs, logits_processor = self._get_logits_processor_and_kwargs(
|
||||
input_ids.shape[-1], config.eos_token_id, diversity_penalty=2.0
|
||||
)
|
||||
config.forced_eos_token_id = None
|
||||
|
||||
model = model_class(config).to(torch_device).eval()
|
||||
|
||||
# check `generate()` and `group_beam_search()` are equal
|
||||
if model.config.is_encoder_decoder:
|
||||
max_length = 4
|
||||
|
||||
logits_process_kwargs, logits_processor = self._get_logits_processor_and_kwargs(
|
||||
input_ids.shape[-1],
|
||||
config.eos_token_id,
|
||||
config.forced_bos_token_id,
|
||||
config.forced_eos_token_id,
|
||||
max_length,
|
||||
diversity_penalty=2.0,
|
||||
)
|
||||
|
||||
# check `generate()` and `group_beam_search()` are equal
|
||||
beam_kwargs, beam_scorer = self._get_diverse_beam_scorer_and_kwargs(input_ids.shape[0], max_length)
|
||||
output_generate, output_group_beam_search = self._group_beam_search_generate(
|
||||
model=model,
|
||||
@@ -943,16 +1001,22 @@ class GenerationTesterMixin:
|
||||
# shorter than `max_length` can be generated which could lead to flaky circle ci
|
||||
# failures if the top `num_return_sequences` beams are all shorter than the longest beam
|
||||
config.eos_token_id = None
|
||||
config.forced_eos_token_id = None
|
||||
|
||||
model = model_class(config).to(torch_device).eval()
|
||||
if model.config.is_encoder_decoder:
|
||||
max_length = 4
|
||||
|
||||
logits_process_kwargs, logits_processor = self._get_logits_processor_and_kwargs(
|
||||
input_ids.shape[-1], config.eos_token_id, diversity_penalty=2.0
|
||||
input_ids.shape[-1],
|
||||
config.eos_token_id,
|
||||
config.forced_bos_token_id,
|
||||
config.forced_eos_token_id,
|
||||
max_length,
|
||||
diversity_penalty=2.0,
|
||||
)
|
||||
|
||||
num_return_sequences = 1
|
||||
if model.config.is_encoder_decoder:
|
||||
max_length = 4
|
||||
beam_kwargs, beam_scorer = self._get_diverse_beam_scorer_and_kwargs(
|
||||
input_ids.shape[0], max_length, num_return_sequences=num_return_sequences
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user