remove adjust_logits_during_generation method (#10087)
* add forced logits processors * delete adjust_logits method * add forced_eos_token_id argument in config * add tests for forced logits processors * update gen utils tests * add forced option to tf generate * remove adjust_logits method from tf models * update adjust_logits for marian * delete _force_token_id_to_be_generated method * style * import warnings * pass max_length to _get_logits_processor * set forced_eos_token_id to None * set forced attributes in conf utils * typo * fix rag generate * add forced_eos_token_id in rag config * remove force_bos_token_to_be_generated from BartConfig * remove _force_token_ids_generation from FSMT * nit * fix negative constant * apply suggestions from code review
This commit is contained in:
@@ -67,6 +67,8 @@ class TFGenerationMixin:
|
||||
attention_mask=None,
|
||||
decoder_start_token_id=None,
|
||||
use_cache=None,
|
||||
forced_bos_token_id=None,
|
||||
forced_eos_token_id=None,
|
||||
):
|
||||
r"""
|
||||
Generates sequences for models with a language modeling head. The method currently supports greedy decoding,
|
||||
@@ -137,6 +139,12 @@ class TFGenerationMixin:
|
||||
use_cache: (:obj:`bool`, `optional`, defaults to :obj:`True`):
|
||||
Whether or not the model should use the past last key/values attentions (if applicable to the model) to
|
||||
speed up decoding.
|
||||
forced_bos_token_id (:obj:`int`, `optional`):
|
||||
The id of the token to force as the first generated token after the :obj:`decoder_start_token_id`.
|
||||
Useful for multilingual models like :doc:`mBART <../model_doc/mbart>` where the first generated token
|
||||
needs to be the target language token.
|
||||
forced_eos_token_id (:obj:`int`, `optional`):
|
||||
The id of the token to force as the last generated token when :obj:`max_length` is reached.
|
||||
model_specific_kwargs:
|
||||
Additional model specific kwargs will be forwarded to the :obj:`forward` function of the model.
|
||||
|
||||
@@ -214,6 +222,12 @@ class TFGenerationMixin:
|
||||
decoder_start_token_id = (
|
||||
decoder_start_token_id if decoder_start_token_id is not None else self.config.decoder_start_token_id
|
||||
)
|
||||
forced_bos_token_id = (
|
||||
forced_bos_token_id if forced_bos_token_id is not None else self.config.forced_bos_token_id
|
||||
)
|
||||
forced_eos_token_id = (
|
||||
forced_eos_token_id if forced_eos_token_id is not None else self.config.forced_eos_token_id
|
||||
)
|
||||
|
||||
if input_ids is not None:
|
||||
batch_size = shape_list(input_ids)[0] # overridden by the input batch_size
|
||||
@@ -380,6 +394,8 @@ class TFGenerationMixin:
|
||||
encoder_outputs=encoder_outputs,
|
||||
attention_mask=attention_mask,
|
||||
use_cache=use_cache,
|
||||
forced_bos_token_id=forced_bos_token_id,
|
||||
forced_eos_token_id=forced_eos_token_id,
|
||||
)
|
||||
else:
|
||||
output = self._generate_no_beam_search(
|
||||
@@ -591,6 +607,8 @@ class TFGenerationMixin:
|
||||
encoder_outputs,
|
||||
attention_mask,
|
||||
use_cache,
|
||||
forced_bos_token_id,
|
||||
forced_eos_token_id,
|
||||
):
|
||||
"""Generate sequences for each example with beam search."""
|
||||
|
||||
@@ -641,7 +659,11 @@ class TFGenerationMixin:
|
||||
|
||||
if self.config.is_encoder_decoder and do_sample is False:
|
||||
next_token_logits = self.adjust_logits_during_generation(
|
||||
next_token_logits, cur_len=cur_len, max_length=max_length
|
||||
next_token_logits,
|
||||
cur_len=cur_len,
|
||||
max_length=max_length,
|
||||
forced_bos_token_id=forced_bos_token_id,
|
||||
forced_eos_token_id=forced_eos_token_id,
|
||||
)
|
||||
# calculate log softmax score
|
||||
scores = tf.nn.log_softmax(next_token_logits, axis=-1) # (batch_size * num_beams, vocab_size)
|
||||
@@ -893,12 +915,21 @@ class TFGenerationMixin:
|
||||
def _reorder_cache(past, beam_idx):
|
||||
return tuple(tf.gather(layer_past, beam_idx, axis=1) for layer_past in past)
|
||||
|
||||
def adjust_logits_during_generation(self, logits, **kwargs):
|
||||
def adjust_logits_during_generation(
|
||||
self, logits, cur_len, max_length, forced_bos_token_id, forced_eos_token_id, **kwargs
|
||||
):
|
||||
"""
|
||||
Implement in subclasses of :class:`~transformers.PreTrainedModel` for custom behavior to adjust the logits in
|
||||
the generate method.
|
||||
"""
|
||||
return logits
|
||||
if cur_len == 1 and forced_bos_token_id is not None:
|
||||
vocab_range = tf.constant(range(self.config.vocab_size))
|
||||
return tf.where(vocab_range != forced_bos_token_id, -1e8, logits)
|
||||
elif cur_len == max_length - 1 and forced_eos_token_id is not None:
|
||||
vocab_range = tf.constant(range(self.config.vocab_size))
|
||||
return tf.where(vocab_range != forced_eos_token_id, -1e8, logits)
|
||||
else:
|
||||
return logits
|
||||
|
||||
|
||||
def _create_next_token_logits_penalties(input_ids, logits, repetition_penalty):
|
||||
|
||||
Reference in New Issue
Block a user