remove adjust_logits_during_generation method (#10087)
* add forced logits processors * delete adjust_logits method * add forced_eos_token_id argument in config * add tests for forced logits processors * update gen utils tests * add forced option to tf generate * remove adjust_logits method from tf models * update adjust_logits for marian * delete _force_token_id_to_be_generated method * style * import warnings * pass max_length to _get_logits_processor * set forced_eos_token_id to None * set forced attributes in conf utils * typo * fix rag generate * add forced_eos_token_id in rag config * remove force_bos_token_to_be_generated from BartConfig * remove _force_token_ids_generation from FSMT * nit * fix negative constant * apply suggestions from code review
This commit is contained in:
@@ -24,6 +24,8 @@ from .file_utils import ModelOutput
|
||||
from .generation_beam_search import BeamScorer, BeamSearchScorer
|
||||
from .generation_logits_process import (
|
||||
EncoderNoRepeatNGramLogitsProcessor,
|
||||
ForcedBOSTokenLogitsProcessor,
|
||||
ForcedEOSTokenLogitsProcessor,
|
||||
HammingDiversityLogitsProcessor,
|
||||
LogitsProcessorList,
|
||||
MinLengthLogitsProcessor,
|
||||
@@ -542,7 +544,10 @@ class GenerationMixin:
|
||||
encoder_input_ids: torch.LongTensor,
|
||||
bad_words_ids: List[List[int]],
|
||||
min_length: int,
|
||||
max_length: int,
|
||||
eos_token_id: int,
|
||||
forced_bos_token_id: int,
|
||||
forced_eos_token_id: int,
|
||||
prefix_allowed_tokens_fn: Callable[[int, torch.Tensor], List[int]],
|
||||
num_beams: int,
|
||||
num_beam_groups: int,
|
||||
@@ -567,6 +572,12 @@ class GenerationMixin:
|
||||
min_length = min_length if min_length is not None else self.config.min_length
|
||||
eos_token_id = eos_token_id if eos_token_id is not None else self.config.eos_token_id
|
||||
diversity_penalty = diversity_penalty if diversity_penalty is not None else self.config.diversity_penalty
|
||||
forced_bos_token_id = (
|
||||
forced_bos_token_id if forced_bos_token_id is not None else self.config.forced_bos_token_id
|
||||
)
|
||||
forced_eos_token_id = (
|
||||
forced_eos_token_id if forced_eos_token_id is not None else self.config.forced_eos_token_id
|
||||
)
|
||||
# instantiate processors list
|
||||
processors = LogitsProcessorList()
|
||||
|
||||
@@ -595,6 +606,10 @@ class GenerationMixin:
|
||||
processors.append(MinLengthLogitsProcessor(min_length, eos_token_id))
|
||||
if prefix_allowed_tokens_fn is not None:
|
||||
processors.append(PrefixConstrainedLogitsProcessor(prefix_allowed_tokens_fn, num_beams))
|
||||
if forced_bos_token_id is not None:
|
||||
processors.append(ForcedBOSTokenLogitsProcessor(forced_bos_token_id))
|
||||
if forced_eos_token_id is not None:
|
||||
processors.append(ForcedEOSTokenLogitsProcessor(max_length, forced_eos_token_id))
|
||||
return processors
|
||||
|
||||
@torch.no_grad()
|
||||
@@ -627,6 +642,8 @@ class GenerationMixin:
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
output_scores: Optional[bool] = None,
|
||||
return_dict_in_generate: Optional[bool] = None,
|
||||
forced_bos_token_id: Optional[int] = None,
|
||||
forced_eos_token_id: Optional[int] = None,
|
||||
**model_kwargs,
|
||||
) -> Union[GreedySearchOutput, SampleOutput, BeamSearchOutput, BeamSampleOutput, torch.LongTensor]:
|
||||
r"""
|
||||
@@ -720,6 +737,12 @@ class GenerationMixin:
|
||||
Whether or not to return the prediction scores. See ``scores`` under returned tensors for more details.
|
||||
return_dict_in_generate (:obj:`bool`, `optional`, defaults to `False`):
|
||||
Whether or not to return a :class:`~transformers.file_utils.ModelOutput` instead of a plain tuple.
|
||||
forced_bos_token_id (:obj:`int`, `optional`):
|
||||
The id of the token to force as the first generated token after the :obj:`decoder_start_token_id`.
|
||||
Useful for multilingual models like :doc:`mBART <../model_doc/mbart>` where the first generated token
|
||||
needs to be the target language token.
|
||||
forced_eos_token_id (:obj:`int`, `optional`):
|
||||
The id of the token to force as the last generated token when :obj:`max_length` is reached.
|
||||
|
||||
model_kwargs:
|
||||
Additional model specific kwargs will be forwarded to the :obj:`forward` function of the model. If the
|
||||
@@ -888,7 +911,10 @@ class GenerationMixin:
|
||||
encoder_input_ids=encoder_input_ids,
|
||||
bad_words_ids=bad_words_ids,
|
||||
min_length=min_length,
|
||||
max_length=max_length,
|
||||
eos_token_id=eos_token_id,
|
||||
forced_bos_token_id=forced_bos_token_id,
|
||||
forced_eos_token_id=forced_eos_token_id,
|
||||
prefix_allowed_tokens_fn=prefix_allowed_tokens_fn,
|
||||
num_beams=num_beams,
|
||||
num_beam_groups=num_beam_groups,
|
||||
@@ -1611,7 +1637,8 @@ class GenerationMixin:
|
||||
)
|
||||
next_token_logits = outputs.logits[:, -1, :]
|
||||
|
||||
# adjust tokens for Bart, *e.g.*
|
||||
# hack: adjust tokens for Marian. For Marian we have to make sure that the `pad_token_id`
|
||||
# cannot be generated both before and after the `F.log_softmax` operation.
|
||||
next_token_logits = self.adjust_logits_during_generation(
|
||||
next_token_logits, cur_len=cur_len, max_length=max_length
|
||||
)
|
||||
@@ -1866,7 +1893,8 @@ class GenerationMixin:
|
||||
)
|
||||
next_token_logits = outputs.logits[:, -1, :]
|
||||
|
||||
# adjust token scores (a no-op by default)
|
||||
# hack: adjust tokens for Marian. For Marian we have to make sure that the `pad_token_id`
|
||||
# cannot be generated both before and after the `F.log_softmax` operation.
|
||||
next_token_logits = self.adjust_logits_during_generation(
|
||||
next_token_logits, cur_len=cur_len, max_length=max_length
|
||||
)
|
||||
@@ -2150,7 +2178,8 @@ class GenerationMixin:
|
||||
# select outputs of beams of current group only
|
||||
next_token_logits = outputs.logits[batch_group_indices, -1, :]
|
||||
|
||||
# adjust tokens for Bart, *e.g.*
|
||||
# hack: adjust tokens for Marian. For Marian we have to make sure that the `pad_token_id`
|
||||
# cannot be generated both before and after the `F.log_softmax` operation.
|
||||
next_token_logits = self.adjust_logits_during_generation(
|
||||
next_token_logits, cur_len=cur_len, max_length=max_length
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user