remove adjust_logits_during_generation method (#10087)

* add forced logits processors

* delete adjust_logits method

* add forced_eos_token_id argument in config

* add tests for forced logits processors

* update gen utils tests

* add forced option to tf generate

* remove adjust_logits method from tf models

* update adjust_logits for marian

* delete _force_token_id_to_be_generated method

* style

* import warnings

* pass max_length to _get_logits_processor

* set forced_eos_token_id to None

* set forced attributes in conf utils

* typo

* fix rag generate

* add forced_eos_token_id in rag config

* remove force_bos_token_to_be_generated from BartConfig

* remove _force_token_ids_generation from FSMT

* nit

* fix negative constant

* apply suggestions from code review
This commit is contained in:
Suraj Patil
2021-02-10 22:39:09 +05:30
committed by GitHub
parent 22a32cf485
commit c130e67dce
29 changed files with 335 additions and 166 deletions

View File

@@ -24,6 +24,8 @@ from .file_utils import ModelOutput
from .generation_beam_search import BeamScorer, BeamSearchScorer
from .generation_logits_process import (
EncoderNoRepeatNGramLogitsProcessor,
ForcedBOSTokenLogitsProcessor,
ForcedEOSTokenLogitsProcessor,
HammingDiversityLogitsProcessor,
LogitsProcessorList,
MinLengthLogitsProcessor,
@@ -542,7 +544,10 @@ class GenerationMixin:
encoder_input_ids: torch.LongTensor,
bad_words_ids: List[List[int]],
min_length: int,
max_length: int,
eos_token_id: int,
forced_bos_token_id: int,
forced_eos_token_id: int,
prefix_allowed_tokens_fn: Callable[[int, torch.Tensor], List[int]],
num_beams: int,
num_beam_groups: int,
@@ -567,6 +572,12 @@ class GenerationMixin:
min_length = min_length if min_length is not None else self.config.min_length
eos_token_id = eos_token_id if eos_token_id is not None else self.config.eos_token_id
diversity_penalty = diversity_penalty if diversity_penalty is not None else self.config.diversity_penalty
forced_bos_token_id = (
forced_bos_token_id if forced_bos_token_id is not None else self.config.forced_bos_token_id
)
forced_eos_token_id = (
forced_eos_token_id if forced_eos_token_id is not None else self.config.forced_eos_token_id
)
# instantiate processors list
processors = LogitsProcessorList()
@@ -595,6 +606,10 @@ class GenerationMixin:
processors.append(MinLengthLogitsProcessor(min_length, eos_token_id))
if prefix_allowed_tokens_fn is not None:
processors.append(PrefixConstrainedLogitsProcessor(prefix_allowed_tokens_fn, num_beams))
if forced_bos_token_id is not None:
processors.append(ForcedBOSTokenLogitsProcessor(forced_bos_token_id))
if forced_eos_token_id is not None:
processors.append(ForcedEOSTokenLogitsProcessor(max_length, forced_eos_token_id))
return processors
@torch.no_grad()
@@ -627,6 +642,8 @@ class GenerationMixin:
output_hidden_states: Optional[bool] = None,
output_scores: Optional[bool] = None,
return_dict_in_generate: Optional[bool] = None,
forced_bos_token_id: Optional[int] = None,
forced_eos_token_id: Optional[int] = None,
**model_kwargs,
) -> Union[GreedySearchOutput, SampleOutput, BeamSearchOutput, BeamSampleOutput, torch.LongTensor]:
r"""
@@ -720,6 +737,12 @@ class GenerationMixin:
Whether or not to return the prediction scores. See ``scores`` under returned tensors for more details.
return_dict_in_generate (:obj:`bool`, `optional`, defaults to `False`):
Whether or not to return a :class:`~transformers.file_utils.ModelOutput` instead of a plain tuple.
forced_bos_token_id (:obj:`int`, `optional`):
The id of the token to force as the first generated token after the :obj:`decoder_start_token_id`.
Useful for multilingual models like :doc:`mBART <../model_doc/mbart>` where the first generated token
needs to be the target language token.
forced_eos_token_id (:obj:`int`, `optional`):
The id of the token to force as the last generated token when :obj:`max_length` is reached.
model_kwargs:
Additional model specific kwargs will be forwarded to the :obj:`forward` function of the model. If the
@@ -888,7 +911,10 @@ class GenerationMixin:
encoder_input_ids=encoder_input_ids,
bad_words_ids=bad_words_ids,
min_length=min_length,
max_length=max_length,
eos_token_id=eos_token_id,
forced_bos_token_id=forced_bos_token_id,
forced_eos_token_id=forced_eos_token_id,
prefix_allowed_tokens_fn=prefix_allowed_tokens_fn,
num_beams=num_beams,
num_beam_groups=num_beam_groups,
@@ -1611,7 +1637,8 @@ class GenerationMixin:
)
next_token_logits = outputs.logits[:, -1, :]
# adjust tokens for Bart, *e.g.*
# hack: adjust tokens for Marian. For Marian we have to make sure that the `pad_token_id`
# cannot be generated both before and after the `F.log_softmax` operation.
next_token_logits = self.adjust_logits_during_generation(
next_token_logits, cur_len=cur_len, max_length=max_length
)
@@ -1866,7 +1893,8 @@ class GenerationMixin:
)
next_token_logits = outputs.logits[:, -1, :]
# adjust token scores (a no-op by default)
# hack: adjust tokens for Marian. For Marian we have to make sure that the `pad_token_id`
# cannot be generated both before and after the `F.log_softmax` operation.
next_token_logits = self.adjust_logits_during_generation(
next_token_logits, cur_len=cur_len, max_length=max_length
)
@@ -2150,7 +2178,8 @@ class GenerationMixin:
# select outputs of beams of current group only
next_token_logits = outputs.logits[batch_group_indices, -1, :]
# adjust tokens for Bart, *e.g.*
# hack: adjust tokens for Marian. For Marian we have to make sure that the `pad_token_id`
# cannot be generated both before and after the `F.log_softmax` operation.
next_token_logits = self.adjust_logits_during_generation(
next_token_logits, cur_len=cur_len, max_length=max_length
)