remove adjust_logits_during_generation method (#10087)

* add forced logits processors

* delete adjust_logits method

* add forced_eos_token_id argument in config

* add tests for forced logits processors

* update gen utils tests

* add forced option to tf generate

* remove adjust_logits method from tf models

* update adjust_logits for marian

* delete _force_token_id_to_be_generated method

* style

* import warnings

* pass max_length to _get_logits_processor

* set forced_eos_token_id to None

* set forced attributes in conf utils

* typo

* fix rag generate

* add forced_eos_token_id in rag config

* remove force_bos_token_to_be_generated from BartConfig

* remove _force_token_ids_generation from FSMT

* nit

* fix negative constant

* apply suggestions from code review
This commit is contained in:
Suraj Patil
2021-02-10 22:39:09 +05:30
committed by GitHub
parent 22a32cf485
commit c130e67dce
29 changed files with 335 additions and 166 deletions

View File

@@ -74,6 +74,9 @@ RAG_CONFIG_DOC = r"""
:obj:`context_attention_mask` are returned. See returned tensors for more detail.
use_cache (:obj:`bool`, `optional`, defaults to :obj:`True`):
Whether or not the model should return the last key/values attentions (not used by all models).
forced_eos_token_id (:obj:`int`, `optional`):
The id of the token to force as the last generated token when :obj:`max_length` is reached. Usually set to
:obj:`eos_token_id`.
"""
@@ -110,6 +113,7 @@ class RagConfig(PretrainedConfig):
do_marginalize=False,
output_retrieved=False,
use_cache=True,
forced_eos_token_id=None,
**kwargs
):
super().__init__(
@@ -117,6 +121,7 @@ class RagConfig(PretrainedConfig):
pad_token_id=pad_token_id,
eos_token_id=eos_token_id,
decoder_start_token_id=decoder_start_token_id,
forced_eos_token_id=forced_eos_token_id,
is_encoder_decoder=is_encoder_decoder,
prefix=prefix,
vocab_size=vocab_size,
@@ -161,6 +166,9 @@ class RagConfig(PretrainedConfig):
self.use_cache = use_cache
if self.forced_eos_token_id is None:
self.forced_eos_token_id = getattr(self.generator, "forced_eos_token_id", None)
@classmethod
def from_question_encoder_generator_configs(
cls, question_encoder_config: PretrainedConfig, generator_config: PretrainedConfig, **kwargs

View File

@@ -1089,9 +1089,6 @@ class RagTokenForGeneration(RagPreTrainedModel):
def set_retriever(self, retriever: RagRetriever):
self.rag.retriever = retriever
def adjust_logits_during_generation(self, logits, cur_len, max_length):
return self.rag.generator.adjust_logits_during_generation(logits, cur_len=cur_len, max_length=max_length)
def prepare_inputs_for_generation(
self,
decoder_input_ids,
@@ -1313,6 +1310,8 @@ class RagTokenForGeneration(RagPreTrainedModel):
decoder_start_token_id=None,
n_docs=None,
prefix_allowed_tokens_fn: Callable[[int, torch.Tensor], List[int]] = None,
forced_bos_token_id: Optional[int] = None,
forced_eos_token_id: Optional[int] = None,
**model_kwargs
):
"""
@@ -1403,6 +1402,12 @@ class RagTokenForGeneration(RagPreTrainedModel):
conditioned on the previously generated tokens :obj:`inputs_ids` and the batch ID :obj:`batch_id`. This
argument is useful for constrained generation conditioned on the prefix, as described in
`Autoregressive Entity Retrieval <https://arxiv.org/abs/2010.00904>`__.
forced_bos_token_id (:obj:`int`, `optional`):
The id of the token to force as the first generated token after the :obj:`decoder_start_token_id`.
Useful for multilingual models like :doc:`mBART <../model_doc/mbart>` where the first generated token
needs to be the target language token.
forced_eos_token_id (:obj:`int`, `optional`):
The id of the token to force as the last generated token when :obj:`max_length` is reached.
Return:
:obj:`torch.LongTensor` of shape :obj:`(batch_size * num_return_sequences, sequence_length)`: The generated
@@ -1498,7 +1503,10 @@ class RagTokenForGeneration(RagPreTrainedModel):
encoder_input_ids=context_input_ids,
bad_words_ids=bad_words_ids,
min_length=min_length,
max_length=max_length,
eos_token_id=eos_token_id,
forced_bos_token_id=forced_bos_token_id,
forced_eos_token_id=forced_eos_token_id,
prefix_allowed_tokens_fn=prefix_allowed_tokens_fn,
num_beams=num_beams,
num_beam_groups=num_beam_groups,