remove adjust_logits_during_generation method (#10087)
* add forced logits processors * delete adjust_logits method * add forced_eos_token_id argument in config * add tests for forced logits processors * update gen utils tests * add forced option to tf generate * remove adjust_logits method from tf models * update adjust_logits for marian * delete _force_token_id_to_be_generated method * style * import warnings * pass max_length to _get_logits_processor * set forced_eos_token_id to None * set forced attributes in conf utils * typo * fix rag generate * add forced_eos_token_id in rag config * remove force_bos_token_to_be_generated from BartConfig * remove _force_token_ids_generation from FSMT * nit * fix negative constant * apply suggestions from code review
This commit is contained in:
@@ -13,6 +13,7 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
""" BART model configuration """
|
||||
import warnings
|
||||
|
||||
from ...configuration_utils import PretrainedConfig
|
||||
from ...utils import logging
|
||||
@@ -72,9 +73,6 @@ class BartConfig(PretrainedConfig):
|
||||
just in case (e.g., 512 or 1024 or 2048).
|
||||
init_std (:obj:`float`, `optional`, defaults to 0.02):
|
||||
The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
|
||||
force_bos_token_to_be_generated (:obj:`bool`, `optional`, defaults to :obj:`False`):
|
||||
Whether or not to force BOS token to be generated at step 1 (after ``decoder_start_token_id``), only
|
||||
:obj:`True` for `bart-large-cnn`.
|
||||
encoder_layerdrop: (:obj:`float`, `optional`, defaults to 0.0):
|
||||
The LayerDrop probability for the encoder. See the `LayerDrop paper <see
|
||||
https://arxiv.org/abs/1909.11556>`__ for more details.
|
||||
@@ -89,6 +87,9 @@ class BartConfig(PretrainedConfig):
|
||||
Whether or not the model should return the last key/values attentions (not used by all models).
|
||||
num_labels: (:obj:`int`, `optional`, defaults to 3):
|
||||
The number of labels to use in :class:`~transformers.BartForSequenceClassification`.
|
||||
forced_eos_token_id (:obj:`int`, `optional`, defaults to 2):
|
||||
The id of the token to force as the last generated token when :obj:`max_length` is reached. Usually set to
|
||||
:obj:`eos_token_id`.
|
||||
|
||||
Example::
|
||||
|
||||
@@ -127,7 +128,6 @@ class BartConfig(PretrainedConfig):
|
||||
classifier_dropout=0.0,
|
||||
scale_embedding=False,
|
||||
gradient_checkpointing=False,
|
||||
force_bos_token_to_be_generated=False,
|
||||
use_cache=True,
|
||||
num_labels=3,
|
||||
pad_token_id=1,
|
||||
@@ -135,6 +135,7 @@ class BartConfig(PretrainedConfig):
|
||||
eos_token_id=2,
|
||||
is_encoder_decoder=True,
|
||||
decoder_start_token_id=2,
|
||||
forced_eos_token_id=2,
|
||||
**kwargs
|
||||
):
|
||||
super().__init__(
|
||||
@@ -144,6 +145,7 @@ class BartConfig(PretrainedConfig):
|
||||
eos_token_id=eos_token_id,
|
||||
is_encoder_decoder=is_encoder_decoder,
|
||||
decoder_start_token_id=decoder_start_token_id,
|
||||
forced_eos_token_id=forced_eos_token_id,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
@@ -168,7 +170,14 @@ class BartConfig(PretrainedConfig):
|
||||
self.num_hidden_layers = encoder_layers
|
||||
self.gradient_checkpointing = gradient_checkpointing
|
||||
self.scale_embedding = scale_embedding # scale factor will be sqrt(d_model) if True
|
||||
self.force_bos_token_to_be_generated = force_bos_token_to_be_generated # only relevant for CNN
|
||||
|
||||
# ensure backward compatibilty for BART CNN models
|
||||
if self.forced_bos_token_id is None and kwargs.get("force_bos_token_to_be_generated", False):
|
||||
self.forced_bos_token_id = self.bos_token_id
|
||||
warnings.warn(
|
||||
f"Please make sure the config includes `forced_bos_token_id={self.bos_token_id}` in future versions."
|
||||
"The config can simply be saved and uploaded again to be fixed."
|
||||
)
|
||||
|
||||
@property
|
||||
def num_attention_heads(self) -> int:
|
||||
|
||||
@@ -1344,18 +1344,6 @@ class BartForConditionalGeneration(BartPretrainedModel):
|
||||
def prepare_decoder_input_ids_from_labels(self, labels: torch.Tensor):
|
||||
return shift_tokens_right(labels, self.config.pad_token_id, self.config.decoder_start_token_id)
|
||||
|
||||
def adjust_logits_during_generation(self, logits, cur_len, max_length):
|
||||
if cur_len == 1 and self.config.force_bos_token_to_be_generated:
|
||||
self._force_token_id_to_be_generated(logits, self.config.bos_token_id)
|
||||
elif cur_len == max_length - 1 and self.config.eos_token_id is not None:
|
||||
self._force_token_id_to_be_generated(logits, self.config.eos_token_id)
|
||||
return logits
|
||||
|
||||
@staticmethod
|
||||
def _force_token_id_to_be_generated(scores, token_id) -> None:
|
||||
"""force one of token_ids to be generated by setting prob of all other tokens to 0 (logprob=-float("inf"))"""
|
||||
scores[:, [x for x in range(scores.shape[1]) if x != token_id]] = -float("inf")
|
||||
|
||||
@staticmethod
|
||||
def _reorder_cache(past, beam_idx):
|
||||
reordered_past = ()
|
||||
|
||||
@@ -1444,13 +1444,3 @@ class TFBartForConditionalGeneration(TFBartPretrainedModel, TFCausalLanguageMode
|
||||
+ layer_past_key_values[2:],
|
||||
)
|
||||
return (past[0], reordered_past)
|
||||
|
||||
def adjust_logits_during_generation(self, logits, cur_len, max_length):
|
||||
if cur_len == 1 and self.config.force_bos_token_to_be_generated:
|
||||
vocab_range = tf.constant(range(self.config.vocab_size))
|
||||
return tf.where(vocab_range != self.config.bos_token_id, LARGE_NEGATIVE, logits)
|
||||
elif cur_len == max_length - 1:
|
||||
vocab_range = tf.constant(range(self.config.vocab_size))
|
||||
return tf.where(vocab_range != self.config.eos_token_id, LARGE_NEGATIVE, logits)
|
||||
else:
|
||||
return logits
|
||||
|
||||
@@ -84,6 +84,9 @@ class BlenderbotConfig(PretrainedConfig):
|
||||
Scale embeddings by diving by sqrt(d_model).
|
||||
use_cache (:obj:`bool`, `optional`, defaults to :obj:`True`):
|
||||
Whether or not the model should return the last key/values attentions (not used by all models)
|
||||
forced_eos_token_id (:obj:`int`, `optional`, defaults to 2):
|
||||
The id of the token to force as the last generated token when :obj:`max_length` is reached. Usually set to
|
||||
:obj:`eos_token_id`.
|
||||
|
||||
Example::
|
||||
|
||||
@@ -129,6 +132,7 @@ class BlenderbotConfig(PretrainedConfig):
|
||||
bos_token_id=1,
|
||||
eos_token_id=2,
|
||||
encoder_no_repeat_ngram_size=3,
|
||||
forced_eos_token_id=2,
|
||||
**kwargs
|
||||
):
|
||||
super().__init__(
|
||||
@@ -138,6 +142,7 @@ class BlenderbotConfig(PretrainedConfig):
|
||||
is_encoder_decoder=is_encoder_decoder,
|
||||
decoder_start_token_id=decoder_start_token_id,
|
||||
encoder_no_repeat_ngram_size=encoder_no_repeat_ngram_size,
|
||||
forced_eos_token_id=forced_eos_token_id,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
|
||||
@@ -1335,16 +1335,6 @@ class BlenderbotForConditionalGeneration(BlenderbotPreTrainedModel):
|
||||
"use_cache": use_cache, # change this to avoid caching (presumably for debugging)
|
||||
}
|
||||
|
||||
def adjust_logits_during_generation(self, logits, cur_len, max_length):
|
||||
if cur_len == max_length - 1 and self.config.eos_token_id is not None:
|
||||
self._force_token_id_to_be_generated(logits, self.config.eos_token_id)
|
||||
return logits
|
||||
|
||||
@staticmethod
|
||||
def _force_token_id_to_be_generated(scores, token_id) -> None:
|
||||
"""force one of token_ids to be generated by setting prob of all other tokens to 0 (logprob=-float("inf"))"""
|
||||
scores[:, [x for x in range(scores.shape[1]) if x != token_id]] = -float("inf")
|
||||
|
||||
@staticmethod
|
||||
def _reorder_cache(past, beam_idx):
|
||||
reordered_past = ()
|
||||
|
||||
@@ -1477,10 +1477,3 @@ class TFBlenderbotForConditionalGeneration(TFBlenderbotPreTrainedModel, TFCausal
|
||||
+ layer_past_key_values[2:],
|
||||
)
|
||||
return (past[0], reordered_past)
|
||||
|
||||
def adjust_logits_during_generation(self, logits, cur_len, max_length):
|
||||
if cur_len == max_length - 1:
|
||||
vocab_range = tf.constant(range(self.config.vocab_size))
|
||||
return tf.where(vocab_range != self.config.eos_token_id, LARGE_NEGATIVE, logits)
|
||||
else:
|
||||
return logits
|
||||
|
||||
@@ -84,6 +84,9 @@ class BlenderbotSmallConfig(PretrainedConfig):
|
||||
Scale embeddings by diving by sqrt(d_model).
|
||||
use_cache (:obj:`bool`, `optional`, defaults to :obj:`True`):
|
||||
Whether or not the model should return the last key/values attentions (not used by all models)
|
||||
forced_eos_token_id (:obj:`int`, `optional`, defaults to 2):
|
||||
The id of the token to force as the last generated token when :obj:`max_length` is reached. Usually set to
|
||||
:obj:`eos_token_id`.
|
||||
|
||||
Example::
|
||||
|
||||
@@ -128,6 +131,7 @@ class BlenderbotSmallConfig(PretrainedConfig):
|
||||
pad_token_id=0,
|
||||
bos_token_id=1,
|
||||
eos_token_id=2,
|
||||
forced_eos_token_id=2,
|
||||
**kwargs
|
||||
):
|
||||
super().__init__(
|
||||
@@ -136,6 +140,7 @@ class BlenderbotSmallConfig(PretrainedConfig):
|
||||
eos_token_id=eos_token_id,
|
||||
is_encoder_decoder=is_encoder_decoder,
|
||||
decoder_start_token_id=decoder_start_token_id,
|
||||
forced_eos_token_id=forced_eos_token_id,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
|
||||
@@ -1310,16 +1310,6 @@ class BlenderbotSmallForConditionalGeneration(BlenderbotSmallPreTrainedModel):
|
||||
"use_cache": use_cache, # change this to avoid caching (presumably for debugging)
|
||||
}
|
||||
|
||||
def adjust_logits_during_generation(self, logits, cur_len, max_length):
|
||||
if cur_len == max_length - 1 and self.config.eos_token_id is not None:
|
||||
self._force_token_id_to_be_generated(logits, self.config.eos_token_id)
|
||||
return logits
|
||||
|
||||
@staticmethod
|
||||
def _force_token_id_to_be_generated(scores, token_id) -> None:
|
||||
"""force one of token_ids to be generated by setting prob of all other tokens to 0 (logprob=-float("inf"))"""
|
||||
scores[:, [x for x in range(scores.shape[1]) if x != token_id]] = -float("inf")
|
||||
|
||||
@staticmethod
|
||||
def _reorder_cache(past, beam_idx):
|
||||
reordered_past = ()
|
||||
|
||||
@@ -1452,10 +1452,3 @@ class TFBlenderbotSmallForConditionalGeneration(TFBlenderbotSmallPreTrainedModel
|
||||
+ layer_past_key_values[2:],
|
||||
)
|
||||
return (past[0], reordered_past)
|
||||
|
||||
def adjust_logits_during_generation(self, logits, cur_len, max_length):
|
||||
if cur_len == max_length - 1:
|
||||
vocab_range = tf.constant(range(self.config.vocab_size))
|
||||
return tf.where(vocab_range != self.config.eos_token_id, LARGE_NEGATIVE, logits)
|
||||
else:
|
||||
return logits
|
||||
|
||||
@@ -111,6 +111,9 @@ class FSMTConfig(PretrainedConfig):
|
||||
search when at least ``num_beams`` sentences are finished per batch or not.
|
||||
use_cache (:obj:`bool`, `optional`, defaults to :obj:`True`):
|
||||
Whether or not the model should return the last key/values attentions (not used by all models).
|
||||
forced_eos_token_id (:obj:`int`, `optional`, defaults to 2):
|
||||
The id of the token to force as the last generated token when :obj:`max_length` is reached. Usually set to
|
||||
:obj:`eos_token_id`.
|
||||
|
||||
Examples::
|
||||
|
||||
@@ -155,6 +158,7 @@ class FSMTConfig(PretrainedConfig):
|
||||
pad_token_id=1,
|
||||
bos_token_id=0,
|
||||
eos_token_id=2,
|
||||
forced_eos_token_id=2,
|
||||
**common_kwargs
|
||||
):
|
||||
if "hidden_size" in common_kwargs:
|
||||
@@ -166,6 +170,7 @@ class FSMTConfig(PretrainedConfig):
|
||||
decoder_start_token_id=decoder_start_token_id,
|
||||
is_encoder_decoder=is_encoder_decoder,
|
||||
tie_word_embeddings=tie_word_embeddings,
|
||||
forced_eos_token_id=forced_eos_token_id,
|
||||
**common_kwargs,
|
||||
)
|
||||
self.langs = langs
|
||||
|
||||
@@ -1210,23 +1210,6 @@ class FSMTForConditionalGeneration(PretrainedFSMTModel):
|
||||
def prepare_decoder_input_ids_from_labels(self, labels: torch.Tensor):
|
||||
return shift_tokens_right(labels, self.config.pad_token_id)
|
||||
|
||||
def adjust_logits_during_generation(self, logits, cur_len, max_length):
|
||||
if cur_len == max_length - 1 and self.config.eos_token_id is not None:
|
||||
self._force_token_ids_generation(logits, self.config.eos_token_id)
|
||||
return logits
|
||||
|
||||
def _force_token_ids_generation(self, scores, token_ids) -> None:
|
||||
"""force one of token_ids to be generated by setting prob of all other tokens to 0"""
|
||||
if isinstance(token_ids, int):
|
||||
token_ids = [token_ids]
|
||||
all_but_token_ids_mask = torch.tensor(
|
||||
[x for x in range(self.config.tgt_vocab_size) if x not in token_ids],
|
||||
dtype=torch.long,
|
||||
device=next(self.parameters()).device,
|
||||
)
|
||||
assert len(scores.shape) == 2, "scores should be of rank 2 with shape: [batch_size, vocab_size]"
|
||||
scores[:, all_but_token_ids_mask] = -float("inf")
|
||||
|
||||
@staticmethod
|
||||
def _reorder_cache(past, beam_idx):
|
||||
reordered_past = []
|
||||
|
||||
@@ -84,6 +84,9 @@ class MarianConfig(PretrainedConfig):
|
||||
Scale embeddings by diving by sqrt(d_model).
|
||||
use_cache (:obj:`bool`, `optional`, defaults to :obj:`True`):
|
||||
Whether or not the model should return the last key/values attentions (not used by all models)
|
||||
forced_eos_token_id (:obj:`int`, `optional`, defaults to 0):
|
||||
The id of the token to force as the last generated token when :obj:`max_length` is reached. Usually set to
|
||||
:obj:`eos_token_id`.
|
||||
|
||||
Examples::
|
||||
|
||||
@@ -127,6 +130,7 @@ class MarianConfig(PretrainedConfig):
|
||||
gradient_checkpointing=False,
|
||||
pad_token_id=58100,
|
||||
eos_token_id=0,
|
||||
forced_eos_token_id=0,
|
||||
**kwargs
|
||||
):
|
||||
super().__init__(
|
||||
@@ -134,6 +138,7 @@ class MarianConfig(PretrainedConfig):
|
||||
eos_token_id=eos_token_id,
|
||||
is_encoder_decoder=is_encoder_decoder,
|
||||
decoder_start_token_id=decoder_start_token_id,
|
||||
forced_eos_token_id=forced_eos_token_id,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
|
||||
@@ -1325,15 +1325,8 @@ class MarianMTModel(MarianPreTrainedModel):
|
||||
|
||||
def adjust_logits_during_generation(self, logits, cur_len, max_length):
|
||||
logits[:, self.config.pad_token_id] = float("-inf") # never predict pad token.
|
||||
if cur_len == max_length - 1 and self.config.eos_token_id is not None:
|
||||
self._force_token_id_to_be_generated(logits, self.config.eos_token_id)
|
||||
return logits
|
||||
|
||||
@staticmethod
|
||||
def _force_token_id_to_be_generated(scores, token_id) -> None:
|
||||
"""force one of token_ids to be generated by setting prob of all other tokens to 0 (logprob=-float("inf"))"""
|
||||
scores[:, [x for x in range(scores.shape[1]) if x != token_id]] = -float("inf")
|
||||
|
||||
@staticmethod
|
||||
def _reorder_cache(past, beam_idx):
|
||||
reordered_past = ()
|
||||
|
||||
@@ -1470,10 +1470,17 @@ class TFMarianMTModel(TFMarianPreTrainedModel, TFCausalLanguageModelingLoss):
|
||||
)
|
||||
return (past[0], reordered_past)
|
||||
|
||||
def adjust_logits_during_generation(self, logits, cur_len, max_length):
|
||||
def adjust_logits_during_generation(
|
||||
self, logits, cur_len, max_length, forced_bos_token_id, forced_eos_token_id, **kwargs
|
||||
):
|
||||
"""Never predict pad_token_id. Predict </s> when max_length is reached."""
|
||||
vocab_range = tf.constant(range(self.config.vocab_size))
|
||||
logits = tf.where(vocab_range == self.config.pad_token_id, LARGE_NEGATIVE, logits)
|
||||
if cur_len == max_length - 1:
|
||||
logits = tf.where(vocab_range != self.config.eos_token_id, LARGE_NEGATIVE, logits)
|
||||
return logits
|
||||
if cur_len == 1 and forced_bos_token_id is not None:
|
||||
vocab_range = tf.constant(range(self.config.vocab_size))
|
||||
return tf.where(vocab_range != forced_bos_token_id, LARGE_NEGATIVE, logits)
|
||||
elif cur_len == max_length - 1 and forced_eos_token_id is not None:
|
||||
vocab_range = tf.constant(range(self.config.vocab_size))
|
||||
return tf.where(vocab_range != forced_eos_token_id, LARGE_NEGATIVE, logits)
|
||||
else:
|
||||
return logits
|
||||
|
||||
@@ -84,6 +84,9 @@ class MBartConfig(PretrainedConfig):
|
||||
Scale embeddings by diving by sqrt(d_model).
|
||||
use_cache (:obj:`bool`, `optional`, defaults to :obj:`True`):
|
||||
Whether or not the model should return the last key/values attentions (not used by all models)
|
||||
forced_eos_token_id (:obj:`int`, `optional`, defaults to 2):
|
||||
The id of the token to force as the last generated token when :obj:`max_length` is reached. Usually set to
|
||||
:obj:`eos_token_id`.
|
||||
|
||||
Example::
|
||||
|
||||
@@ -127,6 +130,7 @@ class MBartConfig(PretrainedConfig):
|
||||
pad_token_id=1,
|
||||
bos_token_id=0,
|
||||
eos_token_id=2,
|
||||
forced_eos_token_id=2,
|
||||
**kwargs
|
||||
):
|
||||
super().__init__(
|
||||
@@ -134,6 +138,7 @@ class MBartConfig(PretrainedConfig):
|
||||
bos_token_id=bos_token_id,
|
||||
eos_token_id=eos_token_id,
|
||||
is_encoder_decoder=is_encoder_decoder,
|
||||
forced_eos_token_id=forced_eos_token_id,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
|
||||
@@ -1344,16 +1344,6 @@ class MBartForConditionalGeneration(MBartPreTrainedModel):
|
||||
def prepare_decoder_input_ids_from_labels(self, labels: torch.Tensor):
|
||||
return shift_tokens_right(labels, self.config.pad_token_id)
|
||||
|
||||
def adjust_logits_during_generation(self, logits, cur_len, max_length):
|
||||
if cur_len == max_length - 1 and self.config.eos_token_id is not None:
|
||||
self._force_token_id_to_be_generated(logits, self.config.eos_token_id)
|
||||
return logits
|
||||
|
||||
@staticmethod
|
||||
def _force_token_id_to_be_generated(scores, token_id) -> None:
|
||||
"""force one of token_ids to be generated by setting prob of all other tokens to 0 (logprob=-float("inf"))"""
|
||||
scores[:, [x for x in range(scores.shape[1]) if x != token_id]] = -float("inf")
|
||||
|
||||
@staticmethod
|
||||
def _reorder_cache(past, beam_idx):
|
||||
reordered_past = ()
|
||||
|
||||
@@ -1468,10 +1468,3 @@ class TFMBartForConditionalGeneration(TFMBartPreTrainedModel, TFCausalLanguageMo
|
||||
+ layer_past_key_values[2:],
|
||||
)
|
||||
return (past[0], reordered_past)
|
||||
|
||||
def adjust_logits_during_generation(self, logits, cur_len, max_length):
|
||||
if cur_len == max_length - 1:
|
||||
vocab_range = tf.constant(range(self.config.vocab_size))
|
||||
return tf.where(vocab_range != self.config.eos_token_id, LARGE_NEGATIVE, logits)
|
||||
else:
|
||||
return logits
|
||||
|
||||
@@ -84,6 +84,9 @@ class PegasusConfig(PretrainedConfig):
|
||||
Scale embeddings by diving by sqrt(d_model).
|
||||
use_cache (:obj:`bool`, `optional`, defaults to :obj:`True`):
|
||||
Whether or not the model should return the last key/values attentions (not used by all models)
|
||||
forced_eos_token_id (:obj:`int`, `optional`, defaults to 1):
|
||||
The id of the token to force as the last generated token when :obj:`max_length` is reached. Usually set to
|
||||
:obj:`eos_token_id`.
|
||||
|
||||
Example::
|
||||
|
||||
@@ -127,6 +130,7 @@ class PegasusConfig(PretrainedConfig):
|
||||
gradient_checkpointing=False,
|
||||
pad_token_id=0,
|
||||
eos_token_id=1,
|
||||
forced_eos_token_id=1,
|
||||
**kwargs
|
||||
):
|
||||
super().__init__(
|
||||
@@ -134,6 +138,7 @@ class PegasusConfig(PretrainedConfig):
|
||||
eos_token_id=eos_token_id,
|
||||
is_encoder_decoder=is_encoder_decoder,
|
||||
decoder_start_token_id=decoder_start_token_id,
|
||||
forced_eos_token_id=forced_eos_token_id,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
|
||||
@@ -1327,16 +1327,6 @@ class PegasusForConditionalGeneration(PegasusPreTrainedModel):
|
||||
def prepare_decoder_input_ids_from_labels(self, labels: torch.Tensor):
|
||||
return shift_tokens_right(labels, self.config.pad_token_id, self.config.decoder_start_token_id)
|
||||
|
||||
def adjust_logits_during_generation(self, logits, cur_len, max_length):
|
||||
if cur_len == max_length - 1 and self.config.eos_token_id is not None:
|
||||
self._force_token_id_to_be_generated(logits, self.config.eos_token_id)
|
||||
return logits
|
||||
|
||||
@staticmethod
|
||||
def _force_token_id_to_be_generated(scores, token_id) -> None:
|
||||
"""force one of token_ids to be generated by setting prob of all other tokens to 0 (logprob=-float("inf"))"""
|
||||
scores[:, [x for x in range(scores.shape[1]) if x != token_id]] = -float("inf")
|
||||
|
||||
@staticmethod
|
||||
def _reorder_cache(past, beam_idx):
|
||||
reordered_past = ()
|
||||
|
||||
@@ -1483,10 +1483,3 @@ class TFPegasusForConditionalGeneration(TFPegasusPreTrainedModel, TFCausalLangua
|
||||
+ layer_past_key_values[2:],
|
||||
)
|
||||
return (past[0], reordered_past)
|
||||
|
||||
def adjust_logits_during_generation(self, logits, cur_len, max_length):
|
||||
if cur_len == max_length - 1:
|
||||
vocab_range = tf.constant(range(self.config.vocab_size))
|
||||
return tf.where(vocab_range != self.config.eos_token_id, LARGE_NEGATIVE, logits)
|
||||
else:
|
||||
return logits
|
||||
|
||||
@@ -74,6 +74,9 @@ RAG_CONFIG_DOC = r"""
|
||||
:obj:`context_attention_mask` are returned. See returned tensors for more detail.
|
||||
use_cache (:obj:`bool`, `optional`, defaults to :obj:`True`):
|
||||
Whether or not the model should return the last key/values attentions (not used by all models).
|
||||
forced_eos_token_id (:obj:`int`, `optional`):
|
||||
The id of the token to force as the last generated token when :obj:`max_length` is reached. Usually set to
|
||||
:obj:`eos_token_id`.
|
||||
"""
|
||||
|
||||
|
||||
@@ -110,6 +113,7 @@ class RagConfig(PretrainedConfig):
|
||||
do_marginalize=False,
|
||||
output_retrieved=False,
|
||||
use_cache=True,
|
||||
forced_eos_token_id=None,
|
||||
**kwargs
|
||||
):
|
||||
super().__init__(
|
||||
@@ -117,6 +121,7 @@ class RagConfig(PretrainedConfig):
|
||||
pad_token_id=pad_token_id,
|
||||
eos_token_id=eos_token_id,
|
||||
decoder_start_token_id=decoder_start_token_id,
|
||||
forced_eos_token_id=forced_eos_token_id,
|
||||
is_encoder_decoder=is_encoder_decoder,
|
||||
prefix=prefix,
|
||||
vocab_size=vocab_size,
|
||||
@@ -161,6 +166,9 @@ class RagConfig(PretrainedConfig):
|
||||
|
||||
self.use_cache = use_cache
|
||||
|
||||
if self.forced_eos_token_id is None:
|
||||
self.forced_eos_token_id = getattr(self.generator, "forced_eos_token_id", None)
|
||||
|
||||
@classmethod
|
||||
def from_question_encoder_generator_configs(
|
||||
cls, question_encoder_config: PretrainedConfig, generator_config: PretrainedConfig, **kwargs
|
||||
|
||||
@@ -1089,9 +1089,6 @@ class RagTokenForGeneration(RagPreTrainedModel):
|
||||
def set_retriever(self, retriever: RagRetriever):
|
||||
self.rag.retriever = retriever
|
||||
|
||||
def adjust_logits_during_generation(self, logits, cur_len, max_length):
|
||||
return self.rag.generator.adjust_logits_during_generation(logits, cur_len=cur_len, max_length=max_length)
|
||||
|
||||
def prepare_inputs_for_generation(
|
||||
self,
|
||||
decoder_input_ids,
|
||||
@@ -1313,6 +1310,8 @@ class RagTokenForGeneration(RagPreTrainedModel):
|
||||
decoder_start_token_id=None,
|
||||
n_docs=None,
|
||||
prefix_allowed_tokens_fn: Callable[[int, torch.Tensor], List[int]] = None,
|
||||
forced_bos_token_id: Optional[int] = None,
|
||||
forced_eos_token_id: Optional[int] = None,
|
||||
**model_kwargs
|
||||
):
|
||||
"""
|
||||
@@ -1403,6 +1402,12 @@ class RagTokenForGeneration(RagPreTrainedModel):
|
||||
conditioned on the previously generated tokens :obj:`inputs_ids` and the batch ID :obj:`batch_id`. This
|
||||
argument is useful for constrained generation conditioned on the prefix, as described in
|
||||
`Autoregressive Entity Retrieval <https://arxiv.org/abs/2010.00904>`__.
|
||||
forced_bos_token_id (:obj:`int`, `optional`):
|
||||
The id of the token to force as the first generated token after the :obj:`decoder_start_token_id`.
|
||||
Useful for multilingual models like :doc:`mBART <../model_doc/mbart>` where the first generated token
|
||||
needs to be the target language token.
|
||||
forced_eos_token_id (:obj:`int`, `optional`):
|
||||
The id of the token to force as the last generated token when :obj:`max_length` is reached.
|
||||
|
||||
Return:
|
||||
:obj:`torch.LongTensor` of shape :obj:`(batch_size * num_return_sequences, sequence_length)`: The generated
|
||||
@@ -1498,7 +1503,10 @@ class RagTokenForGeneration(RagPreTrainedModel):
|
||||
encoder_input_ids=context_input_ids,
|
||||
bad_words_ids=bad_words_ids,
|
||||
min_length=min_length,
|
||||
max_length=max_length,
|
||||
eos_token_id=eos_token_id,
|
||||
forced_bos_token_id=forced_bos_token_id,
|
||||
forced_eos_token_id=forced_eos_token_id,
|
||||
prefix_allowed_tokens_fn=prefix_allowed_tokens_fn,
|
||||
num_beams=num_beams,
|
||||
num_beam_groups=num_beam_groups,
|
||||
|
||||
Reference in New Issue
Block a user