remove adjust_logits_during_generation method (#10087)
* add forced logits processors
* delete adjust_logits method
* add forced_eos_token_id argument in config
* add tests for forced logits processors
* update gen utils tests
* add forced option to tf generate
* remove adjust_logits method from tf models
* update adjust_logits for marian
* delete _force_token_id_to_be_generated method
* style
* import warnings
* pass max_length to _get_logits_processor
* set forced_eos_token_id to None
* set forced attributes in conf utils
* typo
* fix rag generate
* add forced_eos_token_id in rag config
* remove force_bos_token_to_be_generated from BartConfig
* remove _force_token_ids_generation from FSMT
* nit
* fix negative constant
* apply suggestions from code review
This commit is contained in:
@@ -84,6 +84,9 @@ class BlenderbotSmallConfig(PretrainedConfig):
|
||||
Scale embeddings by dividing by sqrt(d_model).
|
||||
use_cache (:obj:`bool`, `optional`, defaults to :obj:`True`):
|
||||
Whether or not the model should return the last key/values attentions (not used by all models)
|
||||
forced_eos_token_id (:obj:`int`, `optional`, defaults to 2):
|
||||
The id of the token to force as the last generated token when :obj:`max_length` is reached. Usually set to
|
||||
:obj:`eos_token_id`.
|
||||
|
||||
Example::
|
||||
|
||||
@@ -128,6 +131,7 @@ class BlenderbotSmallConfig(PretrainedConfig):
|
||||
pad_token_id=0,
|
||||
bos_token_id=1,
|
||||
eos_token_id=2,
|
||||
forced_eos_token_id=2,
|
||||
**kwargs
|
||||
):
|
||||
super().__init__(
|
||||
@@ -136,6 +140,7 @@ class BlenderbotSmallConfig(PretrainedConfig):
|
||||
eos_token_id=eos_token_id,
|
||||
is_encoder_decoder=is_encoder_decoder,
|
||||
decoder_start_token_id=decoder_start_token_id,
|
||||
forced_eos_token_id=forced_eos_token_id,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
|
||||
@@ -1310,16 +1310,6 @@ class BlenderbotSmallForConditionalGeneration(BlenderbotSmallPreTrainedModel):
|
||||
"use_cache": use_cache, # change this to avoid caching (presumably for debugging)
|
||||
}
|
||||
|
||||
def adjust_logits_during_generation(self, logits, cur_len, max_length):
|
||||
if cur_len == max_length - 1 and self.config.eos_token_id is not None:
|
||||
self._force_token_id_to_be_generated(logits, self.config.eos_token_id)
|
||||
return logits
|
||||
|
||||
@staticmethod
|
||||
def _force_token_id_to_be_generated(scores, token_id) -> None:
|
||||
"""force one of token_ids to be generated by setting prob of all other tokens to 0 (logprob=-float("inf"))"""
|
||||
scores[:, [x for x in range(scores.shape[1]) if x != token_id]] = -float("inf")
|
||||
|
||||
@staticmethod
|
||||
def _reorder_cache(past, beam_idx):
|
||||
reordered_past = ()
|
||||
|
||||
@@ -1452,10 +1452,3 @@ class TFBlenderbotSmallForConditionalGeneration(TFBlenderbotSmallPreTrainedModel
|
||||
+ layer_past_key_values[2:],
|
||||
)
|
||||
return (past[0], reordered_past)
|
||||
|
||||
def adjust_logits_during_generation(self, logits, cur_len, max_length):
|
||||
if cur_len == max_length - 1:
|
||||
vocab_range = tf.constant(range(self.config.vocab_size))
|
||||
return tf.where(vocab_range != self.config.eos_token_id, LARGE_NEGATIVE, logits)
|
||||
else:
|
||||
return logits
|
||||
|
||||
Reference in New Issue
Block a user