remove adjust_logits_during_generation method (#10087)
* add forced logits processors * delete adjust_logits method * add forced_eos_token_id argument in config * add tests for forced logits processors * update gen utils tests * add forced option to tf generate * remove adjust_logits method from tf models * update adjust_logits for marian * delete _force_token_id_to_be_generated method * style * import warnings * pass max_length to _get_logits_processor * set forced_eos_token_id to None * set forced attributes in conf utils * typo * fix rag generate * add forced_eos_token_id in rag config * remove force_bos_token_to_be_generated from BartConfig * remove _force_token_ids_generation from FSMT * nit * fix negative constant * apply suggestions from code review
This commit is contained in:
@@ -74,6 +74,9 @@ RAG_CONFIG_DOC = r"""
|
||||
:obj:`context_attention_mask` are returned. See returned tensors for more detail.
|
||||
use_cache (:obj:`bool`, `optional`, defaults to :obj:`True`):
|
||||
Whether or not the model should return the last key/values attentions (not used by all models).
|
||||
forced_eos_token_id (:obj:`int`, `optional`):
|
||||
The id of the token to force as the last generated token when :obj:`max_length` is reached. Usually set to
|
||||
:obj:`eos_token_id`.
|
||||
"""
|
||||
|
||||
|
||||
@@ -110,6 +113,7 @@ class RagConfig(PretrainedConfig):
|
||||
do_marginalize=False,
|
||||
output_retrieved=False,
|
||||
use_cache=True,
|
||||
forced_eos_token_id=None,
|
||||
**kwargs
|
||||
):
|
||||
super().__init__(
|
||||
@@ -117,6 +121,7 @@ class RagConfig(PretrainedConfig):
|
||||
pad_token_id=pad_token_id,
|
||||
eos_token_id=eos_token_id,
|
||||
decoder_start_token_id=decoder_start_token_id,
|
||||
forced_eos_token_id=forced_eos_token_id,
|
||||
is_encoder_decoder=is_encoder_decoder,
|
||||
prefix=prefix,
|
||||
vocab_size=vocab_size,
|
||||
@@ -161,6 +166,9 @@ class RagConfig(PretrainedConfig):
|
||||
|
||||
self.use_cache = use_cache
|
||||
|
||||
if self.forced_eos_token_id is None:
|
||||
self.forced_eos_token_id = getattr(self.generator, "forced_eos_token_id", None)
|
||||
|
||||
@classmethod
|
||||
def from_question_encoder_generator_configs(
|
||||
cls, question_encoder_config: PretrainedConfig, generator_config: PretrainedConfig, **kwargs
|
||||
|
||||
@@ -1089,9 +1089,6 @@ class RagTokenForGeneration(RagPreTrainedModel):
|
||||
def set_retriever(self, retriever: RagRetriever):
|
||||
self.rag.retriever = retriever
|
||||
|
||||
def adjust_logits_during_generation(self, logits, cur_len, max_length):
|
||||
return self.rag.generator.adjust_logits_during_generation(logits, cur_len=cur_len, max_length=max_length)
|
||||
|
||||
def prepare_inputs_for_generation(
|
||||
self,
|
||||
decoder_input_ids,
|
||||
@@ -1313,6 +1310,8 @@ class RagTokenForGeneration(RagPreTrainedModel):
|
||||
decoder_start_token_id=None,
|
||||
n_docs=None,
|
||||
prefix_allowed_tokens_fn: Callable[[int, torch.Tensor], List[int]] = None,
|
||||
forced_bos_token_id: Optional[int] = None,
|
||||
forced_eos_token_id: Optional[int] = None,
|
||||
**model_kwargs
|
||||
):
|
||||
"""
|
||||
@@ -1403,6 +1402,12 @@ class RagTokenForGeneration(RagPreTrainedModel):
|
||||
conditioned on the previously generated tokens :obj:`inputs_ids` and the batch ID :obj:`batch_id`. This
|
||||
argument is useful for constrained generation conditioned on the prefix, as described in
|
||||
`Autoregressive Entity Retrieval <https://arxiv.org/abs/2010.00904>`__.
|
||||
forced_bos_token_id (:obj:`int`, `optional`):
|
||||
The id of the token to force as the first generated token after the :obj:`decoder_start_token_id`.
|
||||
Useful for multilingual models like :doc:`mBART <../model_doc/mbart>` where the first generated token
|
||||
needs to be the target language token.
|
||||
forced_eos_token_id (:obj:`int`, `optional`):
|
||||
The id of the token to force as the last generated token when :obj:`max_length` is reached.
|
||||
|
||||
Return:
|
||||
:obj:`torch.LongTensor` of shape :obj:`(batch_size * num_return_sequences, sequence_length)`: The generated
|
||||
@@ -1498,7 +1503,10 @@ class RagTokenForGeneration(RagPreTrainedModel):
|
||||
encoder_input_ids=context_input_ids,
|
||||
bad_words_ids=bad_words_ids,
|
||||
min_length=min_length,
|
||||
max_length=max_length,
|
||||
eos_token_id=eos_token_id,
|
||||
forced_bos_token_id=forced_bos_token_id,
|
||||
forced_eos_token_id=forced_eos_token_id,
|
||||
prefix_allowed_tokens_fn=prefix_allowed_tokens_fn,
|
||||
num_beams=num_beams,
|
||||
num_beam_groups=num_beam_groups,
|
||||
|
||||
Reference in New Issue
Block a user