remove adjust_logits_during_generation method (#10087)
* add forced logits processors * delete adjust_logits method * add forced_eos_token_id argument in config * add tests for forced logits processors * update gen utils tests * add forced option to tf generate * remove adjust_logits method from tf models * update adjust_logits for marian * delete _force_token_id_to_be_generated method * style * import warnings * pass max_length to _get_logits_processor * set forced_eos_token_id to None * set forced attributes in conf utils * typo * fix rag generate * add forced_eos_token_id in rag config * remove force_bos_token_to_be_generated from BartConfig * remove _force_token_ids_generation from FSMT * nit * fix negative constant * apply suggestions from code review
This commit is contained in:
@@ -131,6 +131,11 @@ class PretrainedConfig(object):
|
|||||||
logits when used for generation
|
logits when used for generation
|
||||||
- **return_dict_in_generate** (:obj:`bool`, `optional`, defaults to :obj:`False`) -- Whether the model should
|
- **return_dict_in_generate** (:obj:`bool`, `optional`, defaults to :obj:`False`) -- Whether the model should
|
||||||
return a :class:`~transformers.file_utils.ModelOutput` instead of a :obj:`torch.LongTensor`
|
return a :class:`~transformers.file_utils.ModelOutput` instead of a :obj:`torch.LongTensor`
|
||||||
|
- **forced_bos_token_id** (:obj:`int`, `optional`) -- The id of the token to force as the first generated token
|
||||||
|
after the :obj:`decoder_start_token_id`. Useful for multilingual models like :doc:`mBART
|
||||||
|
<../model_doc/mbart>` where the first generated token needs to be the target language token.
|
||||||
|
- **forced_eos_token_id** (:obj:`int`, `optional`) -- The id of the token to force as the last generated token
|
||||||
|
when :obj:`max_length` is reached.
|
||||||
|
|
||||||
|
|
||||||
Parameters for fine-tuning tasks
|
Parameters for fine-tuning tasks
|
||||||
@@ -214,6 +219,8 @@ class PretrainedConfig(object):
|
|||||||
self.chunk_size_feed_forward = kwargs.pop("chunk_size_feed_forward", 0)
|
self.chunk_size_feed_forward = kwargs.pop("chunk_size_feed_forward", 0)
|
||||||
self.output_scores = kwargs.pop("output_scores", False)
|
self.output_scores = kwargs.pop("output_scores", False)
|
||||||
self.return_dict_in_generate = kwargs.pop("return_dict_in_generate", False)
|
self.return_dict_in_generate = kwargs.pop("return_dict_in_generate", False)
|
||||||
|
self.forced_bos_token_id = kwargs.pop("forced_bos_token_id", None)
|
||||||
|
self.forced_eos_token_id = kwargs.pop("forced_eos_token_id", None)
|
||||||
|
|
||||||
# Fine-tuning task arguments
|
# Fine-tuning task arguments
|
||||||
self.architectures = kwargs.pop("architectures", None)
|
self.architectures = kwargs.pop("architectures", None)
|
||||||
|
|||||||
@@ -520,3 +520,49 @@ class HammingDiversityLogitsProcessor(LogitsProcessor):
|
|||||||
scores[batch_idx * group_size : (batch_idx + 1) * group_size] -= self._diversity_penalty * token_frequency
|
scores[batch_idx * group_size : (batch_idx + 1) * group_size] -= self._diversity_penalty * token_frequency
|
||||||
|
|
||||||
return scores
|
return scores
|
||||||
|
|
||||||
|
|
||||||
|
class ForcedBOSTokenLogitsProcessor(LogitsProcessor):
|
||||||
|
r"""
|
||||||
|
:class:`~transformers.LogitsProcessor` that enforces the specified token as the first generated token.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
bos_token_id (:obj:`int`):
|
||||||
|
The id of the token to force as the first generated token.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, bos_token_id: int):
|
||||||
|
self.bos_token_id = bos_token_id
|
||||||
|
|
||||||
|
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
|
||||||
|
cur_len = input_ids.shape[-1]
|
||||||
|
if cur_len == 1:
|
||||||
|
num_tokens = scores.shape[1]
|
||||||
|
scores[:, [i for i in range(num_tokens) if i != self.bos_token_id]] = -float("inf")
|
||||||
|
scores[:, self.bos_token_id] = 0
|
||||||
|
return scores
|
||||||
|
|
||||||
|
|
||||||
|
class ForcedEOSTokenLogitsProcessor(LogitsProcessor):
|
||||||
|
r"""
|
||||||
|
:class:`~transformers.LogitsProcessor` that enforces the specified token as the last generated token when
|
||||||
|
:obj:`max_length` is reached.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
max_length (:obj:`int`):
|
||||||
|
The maximum length of the sequence to be generated.
|
||||||
|
eos_token_id (:obj:`int`):
|
||||||
|
The id of the token to force as the last generated token when :obj:`max_length` is reached.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, max_length: int, eos_token_id: int):
|
||||||
|
self.max_length = max_length
|
||||||
|
self.eos_token_id = eos_token_id
|
||||||
|
|
||||||
|
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
|
||||||
|
cur_len = input_ids.shape[-1]
|
||||||
|
if cur_len == self.max_length - 1:
|
||||||
|
num_tokens = scores.shape[1]
|
||||||
|
scores[:, [i for i in range(num_tokens) if i != self.eos_token_id]] = -float("inf")
|
||||||
|
scores[:, self.eos_token_id] = 0
|
||||||
|
return scores
|
||||||
|
|||||||
@@ -67,6 +67,8 @@ class TFGenerationMixin:
|
|||||||
attention_mask=None,
|
attention_mask=None,
|
||||||
decoder_start_token_id=None,
|
decoder_start_token_id=None,
|
||||||
use_cache=None,
|
use_cache=None,
|
||||||
|
forced_bos_token_id=None,
|
||||||
|
forced_eos_token_id=None,
|
||||||
):
|
):
|
||||||
r"""
|
r"""
|
||||||
Generates sequences for models with a language modeling head. The method currently supports greedy decoding,
|
Generates sequences for models with a language modeling head. The method currently supports greedy decoding,
|
||||||
@@ -137,6 +139,12 @@ class TFGenerationMixin:
|
|||||||
use_cache: (:obj:`bool`, `optional`, defaults to :obj:`True`):
|
use_cache: (:obj:`bool`, `optional`, defaults to :obj:`True`):
|
||||||
Whether or not the model should use the past last key/values attentions (if applicable to the model) to
|
Whether or not the model should use the past last key/values attentions (if applicable to the model) to
|
||||||
speed up decoding.
|
speed up decoding.
|
||||||
|
forced_bos_token_id (:obj:`int`, `optional`):
|
||||||
|
The id of the token to force as the first generated token after the :obj:`decoder_start_token_id`.
|
||||||
|
Useful for multilingual models like :doc:`mBART <../model_doc/mbart>` where the first generated token
|
||||||
|
needs to be the target language token.
|
||||||
|
forced_eos_token_id (:obj:`int`, `optional`):
|
||||||
|
The id of the token to force as the last generated token when :obj:`max_length` is reached.
|
||||||
model_specific_kwargs:
|
model_specific_kwargs:
|
||||||
Additional model specific kwargs will be forwarded to the :obj:`forward` function of the model.
|
Additional model specific kwargs will be forwarded to the :obj:`forward` function of the model.
|
||||||
|
|
||||||
@@ -214,6 +222,12 @@ class TFGenerationMixin:
|
|||||||
decoder_start_token_id = (
|
decoder_start_token_id = (
|
||||||
decoder_start_token_id if decoder_start_token_id is not None else self.config.decoder_start_token_id
|
decoder_start_token_id if decoder_start_token_id is not None else self.config.decoder_start_token_id
|
||||||
)
|
)
|
||||||
|
forced_bos_token_id = (
|
||||||
|
forced_bos_token_id if forced_bos_token_id is not None else self.config.forced_bos_token_id
|
||||||
|
)
|
||||||
|
forced_eos_token_id = (
|
||||||
|
forced_eos_token_id if forced_eos_token_id is not None else self.config.forced_eos_token_id
|
||||||
|
)
|
||||||
|
|
||||||
if input_ids is not None:
|
if input_ids is not None:
|
||||||
batch_size = shape_list(input_ids)[0] # overridden by the input batch_size
|
batch_size = shape_list(input_ids)[0] # overridden by the input batch_size
|
||||||
@@ -380,6 +394,8 @@ class TFGenerationMixin:
|
|||||||
encoder_outputs=encoder_outputs,
|
encoder_outputs=encoder_outputs,
|
||||||
attention_mask=attention_mask,
|
attention_mask=attention_mask,
|
||||||
use_cache=use_cache,
|
use_cache=use_cache,
|
||||||
|
forced_bos_token_id=forced_bos_token_id,
|
||||||
|
forced_eos_token_id=forced_eos_token_id,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
output = self._generate_no_beam_search(
|
output = self._generate_no_beam_search(
|
||||||
@@ -591,6 +607,8 @@ class TFGenerationMixin:
|
|||||||
encoder_outputs,
|
encoder_outputs,
|
||||||
attention_mask,
|
attention_mask,
|
||||||
use_cache,
|
use_cache,
|
||||||
|
forced_bos_token_id,
|
||||||
|
forced_eos_token_id,
|
||||||
):
|
):
|
||||||
"""Generate sequences for each example with beam search."""
|
"""Generate sequences for each example with beam search."""
|
||||||
|
|
||||||
@@ -641,7 +659,11 @@ class TFGenerationMixin:
|
|||||||
|
|
||||||
if self.config.is_encoder_decoder and do_sample is False:
|
if self.config.is_encoder_decoder and do_sample is False:
|
||||||
next_token_logits = self.adjust_logits_during_generation(
|
next_token_logits = self.adjust_logits_during_generation(
|
||||||
next_token_logits, cur_len=cur_len, max_length=max_length
|
next_token_logits,
|
||||||
|
cur_len=cur_len,
|
||||||
|
max_length=max_length,
|
||||||
|
forced_bos_token_id=forced_bos_token_id,
|
||||||
|
forced_eos_token_id=forced_eos_token_id,
|
||||||
)
|
)
|
||||||
# calculate log softmax score
|
# calculate log softmax score
|
||||||
scores = tf.nn.log_softmax(next_token_logits, axis=-1) # (batch_size * num_beams, vocab_size)
|
scores = tf.nn.log_softmax(next_token_logits, axis=-1) # (batch_size * num_beams, vocab_size)
|
||||||
@@ -893,11 +915,20 @@ class TFGenerationMixin:
|
|||||||
def _reorder_cache(past, beam_idx):
|
def _reorder_cache(past, beam_idx):
|
||||||
return tuple(tf.gather(layer_past, beam_idx, axis=1) for layer_past in past)
|
return tuple(tf.gather(layer_past, beam_idx, axis=1) for layer_past in past)
|
||||||
|
|
||||||
def adjust_logits_during_generation(self, logits, **kwargs):
|
def adjust_logits_during_generation(
|
||||||
|
self, logits, cur_len, max_length, forced_bos_token_id, forced_eos_token_id, **kwargs
|
||||||
|
):
|
||||||
"""
|
"""
|
||||||
Implement in subclasses of :class:`~transformers.PreTrainedModel` for custom behavior to adjust the logits in
|
Implement in subclasses of :class:`~transformers.PreTrainedModel` for custom behavior to adjust the logits in
|
||||||
the generate method.
|
the generate method.
|
||||||
"""
|
"""
|
||||||
|
if cur_len == 1 and forced_bos_token_id is not None:
|
||||||
|
vocab_range = tf.constant(range(self.config.vocab_size))
|
||||||
|
return tf.where(vocab_range != forced_bos_token_id, -1e8, logits)
|
||||||
|
elif cur_len == max_length - 1 and forced_eos_token_id is not None:
|
||||||
|
vocab_range = tf.constant(range(self.config.vocab_size))
|
||||||
|
return tf.where(vocab_range != forced_eos_token_id, -1e8, logits)
|
||||||
|
else:
|
||||||
return logits
|
return logits
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -24,6 +24,8 @@ from .file_utils import ModelOutput
|
|||||||
from .generation_beam_search import BeamScorer, BeamSearchScorer
|
from .generation_beam_search import BeamScorer, BeamSearchScorer
|
||||||
from .generation_logits_process import (
|
from .generation_logits_process import (
|
||||||
EncoderNoRepeatNGramLogitsProcessor,
|
EncoderNoRepeatNGramLogitsProcessor,
|
||||||
|
ForcedBOSTokenLogitsProcessor,
|
||||||
|
ForcedEOSTokenLogitsProcessor,
|
||||||
HammingDiversityLogitsProcessor,
|
HammingDiversityLogitsProcessor,
|
||||||
LogitsProcessorList,
|
LogitsProcessorList,
|
||||||
MinLengthLogitsProcessor,
|
MinLengthLogitsProcessor,
|
||||||
@@ -542,7 +544,10 @@ class GenerationMixin:
|
|||||||
encoder_input_ids: torch.LongTensor,
|
encoder_input_ids: torch.LongTensor,
|
||||||
bad_words_ids: List[List[int]],
|
bad_words_ids: List[List[int]],
|
||||||
min_length: int,
|
min_length: int,
|
||||||
|
max_length: int,
|
||||||
eos_token_id: int,
|
eos_token_id: int,
|
||||||
|
forced_bos_token_id: int,
|
||||||
|
forced_eos_token_id: int,
|
||||||
prefix_allowed_tokens_fn: Callable[[int, torch.Tensor], List[int]],
|
prefix_allowed_tokens_fn: Callable[[int, torch.Tensor], List[int]],
|
||||||
num_beams: int,
|
num_beams: int,
|
||||||
num_beam_groups: int,
|
num_beam_groups: int,
|
||||||
@@ -567,6 +572,12 @@ class GenerationMixin:
|
|||||||
min_length = min_length if min_length is not None else self.config.min_length
|
min_length = min_length if min_length is not None else self.config.min_length
|
||||||
eos_token_id = eos_token_id if eos_token_id is not None else self.config.eos_token_id
|
eos_token_id = eos_token_id if eos_token_id is not None else self.config.eos_token_id
|
||||||
diversity_penalty = diversity_penalty if diversity_penalty is not None else self.config.diversity_penalty
|
diversity_penalty = diversity_penalty if diversity_penalty is not None else self.config.diversity_penalty
|
||||||
|
forced_bos_token_id = (
|
||||||
|
forced_bos_token_id if forced_bos_token_id is not None else self.config.forced_bos_token_id
|
||||||
|
)
|
||||||
|
forced_eos_token_id = (
|
||||||
|
forced_eos_token_id if forced_eos_token_id is not None else self.config.forced_eos_token_id
|
||||||
|
)
|
||||||
# instantiate processors list
|
# instantiate processors list
|
||||||
processors = LogitsProcessorList()
|
processors = LogitsProcessorList()
|
||||||
|
|
||||||
@@ -595,6 +606,10 @@ class GenerationMixin:
|
|||||||
processors.append(MinLengthLogitsProcessor(min_length, eos_token_id))
|
processors.append(MinLengthLogitsProcessor(min_length, eos_token_id))
|
||||||
if prefix_allowed_tokens_fn is not None:
|
if prefix_allowed_tokens_fn is not None:
|
||||||
processors.append(PrefixConstrainedLogitsProcessor(prefix_allowed_tokens_fn, num_beams))
|
processors.append(PrefixConstrainedLogitsProcessor(prefix_allowed_tokens_fn, num_beams))
|
||||||
|
if forced_bos_token_id is not None:
|
||||||
|
processors.append(ForcedBOSTokenLogitsProcessor(forced_bos_token_id))
|
||||||
|
if forced_eos_token_id is not None:
|
||||||
|
processors.append(ForcedEOSTokenLogitsProcessor(max_length, forced_eos_token_id))
|
||||||
return processors
|
return processors
|
||||||
|
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
@@ -627,6 +642,8 @@ class GenerationMixin:
|
|||||||
output_hidden_states: Optional[bool] = None,
|
output_hidden_states: Optional[bool] = None,
|
||||||
output_scores: Optional[bool] = None,
|
output_scores: Optional[bool] = None,
|
||||||
return_dict_in_generate: Optional[bool] = None,
|
return_dict_in_generate: Optional[bool] = None,
|
||||||
|
forced_bos_token_id: Optional[int] = None,
|
||||||
|
forced_eos_token_id: Optional[int] = None,
|
||||||
**model_kwargs,
|
**model_kwargs,
|
||||||
) -> Union[GreedySearchOutput, SampleOutput, BeamSearchOutput, BeamSampleOutput, torch.LongTensor]:
|
) -> Union[GreedySearchOutput, SampleOutput, BeamSearchOutput, BeamSampleOutput, torch.LongTensor]:
|
||||||
r"""
|
r"""
|
||||||
@@ -720,6 +737,12 @@ class GenerationMixin:
|
|||||||
Whether or not to return the prediction scores. See ``scores`` under returned tensors for more details.
|
Whether or not to return the prediction scores. See ``scores`` under returned tensors for more details.
|
||||||
return_dict_in_generate (:obj:`bool`, `optional`, defaults to `False`):
|
return_dict_in_generate (:obj:`bool`, `optional`, defaults to `False`):
|
||||||
Whether or not to return a :class:`~transformers.file_utils.ModelOutput` instead of a plain tuple.
|
Whether or not to return a :class:`~transformers.file_utils.ModelOutput` instead of a plain tuple.
|
||||||
|
forced_bos_token_id (:obj:`int`, `optional`):
|
||||||
|
The id of the token to force as the first generated token after the :obj:`decoder_start_token_id`.
|
||||||
|
Useful for multilingual models like :doc:`mBART <../model_doc/mbart>` where the first generated token
|
||||||
|
needs to be the target language token.
|
||||||
|
forced_eos_token_id (:obj:`int`, `optional`):
|
||||||
|
The id of the token to force as the last generated token when :obj:`max_length` is reached.
|
||||||
|
|
||||||
model_kwargs:
|
model_kwargs:
|
||||||
Additional model specific kwargs will be forwarded to the :obj:`forward` function of the model. If the
|
Additional model specific kwargs will be forwarded to the :obj:`forward` function of the model. If the
|
||||||
@@ -888,7 +911,10 @@ class GenerationMixin:
|
|||||||
encoder_input_ids=encoder_input_ids,
|
encoder_input_ids=encoder_input_ids,
|
||||||
bad_words_ids=bad_words_ids,
|
bad_words_ids=bad_words_ids,
|
||||||
min_length=min_length,
|
min_length=min_length,
|
||||||
|
max_length=max_length,
|
||||||
eos_token_id=eos_token_id,
|
eos_token_id=eos_token_id,
|
||||||
|
forced_bos_token_id=forced_bos_token_id,
|
||||||
|
forced_eos_token_id=forced_eos_token_id,
|
||||||
prefix_allowed_tokens_fn=prefix_allowed_tokens_fn,
|
prefix_allowed_tokens_fn=prefix_allowed_tokens_fn,
|
||||||
num_beams=num_beams,
|
num_beams=num_beams,
|
||||||
num_beam_groups=num_beam_groups,
|
num_beam_groups=num_beam_groups,
|
||||||
@@ -1611,7 +1637,8 @@ class GenerationMixin:
|
|||||||
)
|
)
|
||||||
next_token_logits = outputs.logits[:, -1, :]
|
next_token_logits = outputs.logits[:, -1, :]
|
||||||
|
|
||||||
# adjust tokens for Bart, *e.g.*
|
# hack: adjust tokens for Marian. For Marian we have to make sure that the `pad_token_id`
|
||||||
|
# cannot be generated both before and after the `F.log_softmax` operation.
|
||||||
next_token_logits = self.adjust_logits_during_generation(
|
next_token_logits = self.adjust_logits_during_generation(
|
||||||
next_token_logits, cur_len=cur_len, max_length=max_length
|
next_token_logits, cur_len=cur_len, max_length=max_length
|
||||||
)
|
)
|
||||||
@@ -1866,7 +1893,8 @@ class GenerationMixin:
|
|||||||
)
|
)
|
||||||
next_token_logits = outputs.logits[:, -1, :]
|
next_token_logits = outputs.logits[:, -1, :]
|
||||||
|
|
||||||
# adjust token scores (a no-op by default)
|
# hack: adjust tokens for Marian. For Marian we have to make sure that the `pad_token_id`
|
||||||
|
# cannot be generated both before and after the `F.log_softmax` operation.
|
||||||
next_token_logits = self.adjust_logits_during_generation(
|
next_token_logits = self.adjust_logits_during_generation(
|
||||||
next_token_logits, cur_len=cur_len, max_length=max_length
|
next_token_logits, cur_len=cur_len, max_length=max_length
|
||||||
)
|
)
|
||||||
@@ -2150,7 +2178,8 @@ class GenerationMixin:
|
|||||||
# select outputs of beams of current group only
|
# select outputs of beams of current group only
|
||||||
next_token_logits = outputs.logits[batch_group_indices, -1, :]
|
next_token_logits = outputs.logits[batch_group_indices, -1, :]
|
||||||
|
|
||||||
# adjust tokens for Bart, *e.g.*
|
# hack: adjust tokens for Marian. For Marian we have to make sure that the `pad_token_id`
|
||||||
|
# cannot be generated both before and after the `F.log_softmax` operation.
|
||||||
next_token_logits = self.adjust_logits_during_generation(
|
next_token_logits = self.adjust_logits_during_generation(
|
||||||
next_token_logits, cur_len=cur_len, max_length=max_length
|
next_token_logits, cur_len=cur_len, max_length=max_length
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -13,6 +13,7 @@
|
|||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
""" BART model configuration """
|
""" BART model configuration """
|
||||||
|
import warnings
|
||||||
|
|
||||||
from ...configuration_utils import PretrainedConfig
|
from ...configuration_utils import PretrainedConfig
|
||||||
from ...utils import logging
|
from ...utils import logging
|
||||||
@@ -72,9 +73,6 @@ class BartConfig(PretrainedConfig):
|
|||||||
just in case (e.g., 512 or 1024 or 2048).
|
just in case (e.g., 512 or 1024 or 2048).
|
||||||
init_std (:obj:`float`, `optional`, defaults to 0.02):
|
init_std (:obj:`float`, `optional`, defaults to 0.02):
|
||||||
The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
|
The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
|
||||||
force_bos_token_to_be_generated (:obj:`bool`, `optional`, defaults to :obj:`False`):
|
|
||||||
Whether or not to force BOS token to be generated at step 1 (after ``decoder_start_token_id``), only
|
|
||||||
:obj:`True` for `bart-large-cnn`.
|
|
||||||
encoder_layerdrop: (:obj:`float`, `optional`, defaults to 0.0):
|
encoder_layerdrop: (:obj:`float`, `optional`, defaults to 0.0):
|
||||||
The LayerDrop probability for the encoder. See the `LayerDrop paper <see
|
The LayerDrop probability for the encoder. See the `LayerDrop paper <see
|
||||||
https://arxiv.org/abs/1909.11556>`__ for more details.
|
https://arxiv.org/abs/1909.11556>`__ for more details.
|
||||||
@@ -89,6 +87,9 @@ class BartConfig(PretrainedConfig):
|
|||||||
Whether or not the model should return the last key/values attentions (not used by all models).
|
Whether or not the model should return the last key/values attentions (not used by all models).
|
||||||
num_labels: (:obj:`int`, `optional`, defaults to 3):
|
num_labels: (:obj:`int`, `optional`, defaults to 3):
|
||||||
The number of labels to use in :class:`~transformers.BartForSequenceClassification`.
|
The number of labels to use in :class:`~transformers.BartForSequenceClassification`.
|
||||||
|
forced_eos_token_id (:obj:`int`, `optional`, defaults to 2):
|
||||||
|
The id of the token to force as the last generated token when :obj:`max_length` is reached. Usually set to
|
||||||
|
:obj:`eos_token_id`.
|
||||||
|
|
||||||
Example::
|
Example::
|
||||||
|
|
||||||
@@ -127,7 +128,6 @@ class BartConfig(PretrainedConfig):
|
|||||||
classifier_dropout=0.0,
|
classifier_dropout=0.0,
|
||||||
scale_embedding=False,
|
scale_embedding=False,
|
||||||
gradient_checkpointing=False,
|
gradient_checkpointing=False,
|
||||||
force_bos_token_to_be_generated=False,
|
|
||||||
use_cache=True,
|
use_cache=True,
|
||||||
num_labels=3,
|
num_labels=3,
|
||||||
pad_token_id=1,
|
pad_token_id=1,
|
||||||
@@ -135,6 +135,7 @@ class BartConfig(PretrainedConfig):
|
|||||||
eos_token_id=2,
|
eos_token_id=2,
|
||||||
is_encoder_decoder=True,
|
is_encoder_decoder=True,
|
||||||
decoder_start_token_id=2,
|
decoder_start_token_id=2,
|
||||||
|
forced_eos_token_id=2,
|
||||||
**kwargs
|
**kwargs
|
||||||
):
|
):
|
||||||
super().__init__(
|
super().__init__(
|
||||||
@@ -144,6 +145,7 @@ class BartConfig(PretrainedConfig):
|
|||||||
eos_token_id=eos_token_id,
|
eos_token_id=eos_token_id,
|
||||||
is_encoder_decoder=is_encoder_decoder,
|
is_encoder_decoder=is_encoder_decoder,
|
||||||
decoder_start_token_id=decoder_start_token_id,
|
decoder_start_token_id=decoder_start_token_id,
|
||||||
|
forced_eos_token_id=forced_eos_token_id,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -168,7 +170,14 @@ class BartConfig(PretrainedConfig):
|
|||||||
self.num_hidden_layers = encoder_layers
|
self.num_hidden_layers = encoder_layers
|
||||||
self.gradient_checkpointing = gradient_checkpointing
|
self.gradient_checkpointing = gradient_checkpointing
|
||||||
self.scale_embedding = scale_embedding # scale factor will be sqrt(d_model) if True
|
self.scale_embedding = scale_embedding # scale factor will be sqrt(d_model) if True
|
||||||
self.force_bos_token_to_be_generated = force_bos_token_to_be_generated # only relevant for CNN
|
|
||||||
|
# ensure backward compatibilty for BART CNN models
|
||||||
|
if self.forced_bos_token_id is None and kwargs.get("force_bos_token_to_be_generated", False):
|
||||||
|
self.forced_bos_token_id = self.bos_token_id
|
||||||
|
warnings.warn(
|
||||||
|
f"Please make sure the config includes `forced_bos_token_id={self.bos_token_id}` in future versions."
|
||||||
|
"The config can simply be saved and uploaded again to be fixed."
|
||||||
|
)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def num_attention_heads(self) -> int:
|
def num_attention_heads(self) -> int:
|
||||||
|
|||||||
@@ -1344,18 +1344,6 @@ class BartForConditionalGeneration(BartPretrainedModel):
|
|||||||
def prepare_decoder_input_ids_from_labels(self, labels: torch.Tensor):
|
def prepare_decoder_input_ids_from_labels(self, labels: torch.Tensor):
|
||||||
return shift_tokens_right(labels, self.config.pad_token_id, self.config.decoder_start_token_id)
|
return shift_tokens_right(labels, self.config.pad_token_id, self.config.decoder_start_token_id)
|
||||||
|
|
||||||
def adjust_logits_during_generation(self, logits, cur_len, max_length):
|
|
||||||
if cur_len == 1 and self.config.force_bos_token_to_be_generated:
|
|
||||||
self._force_token_id_to_be_generated(logits, self.config.bos_token_id)
|
|
||||||
elif cur_len == max_length - 1 and self.config.eos_token_id is not None:
|
|
||||||
self._force_token_id_to_be_generated(logits, self.config.eos_token_id)
|
|
||||||
return logits
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def _force_token_id_to_be_generated(scores, token_id) -> None:
|
|
||||||
"""force one of token_ids to be generated by setting prob of all other tokens to 0 (logprob=-float("inf"))"""
|
|
||||||
scores[:, [x for x in range(scores.shape[1]) if x != token_id]] = -float("inf")
|
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _reorder_cache(past, beam_idx):
|
def _reorder_cache(past, beam_idx):
|
||||||
reordered_past = ()
|
reordered_past = ()
|
||||||
|
|||||||
@@ -1444,13 +1444,3 @@ class TFBartForConditionalGeneration(TFBartPretrainedModel, TFCausalLanguageMode
|
|||||||
+ layer_past_key_values[2:],
|
+ layer_past_key_values[2:],
|
||||||
)
|
)
|
||||||
return (past[0], reordered_past)
|
return (past[0], reordered_past)
|
||||||
|
|
||||||
def adjust_logits_during_generation(self, logits, cur_len, max_length):
|
|
||||||
if cur_len == 1 and self.config.force_bos_token_to_be_generated:
|
|
||||||
vocab_range = tf.constant(range(self.config.vocab_size))
|
|
||||||
return tf.where(vocab_range != self.config.bos_token_id, LARGE_NEGATIVE, logits)
|
|
||||||
elif cur_len == max_length - 1:
|
|
||||||
vocab_range = tf.constant(range(self.config.vocab_size))
|
|
||||||
return tf.where(vocab_range != self.config.eos_token_id, LARGE_NEGATIVE, logits)
|
|
||||||
else:
|
|
||||||
return logits
|
|
||||||
|
|||||||
@@ -84,6 +84,9 @@ class BlenderbotConfig(PretrainedConfig):
|
|||||||
Scale embeddings by diving by sqrt(d_model).
|
Scale embeddings by diving by sqrt(d_model).
|
||||||
use_cache (:obj:`bool`, `optional`, defaults to :obj:`True`):
|
use_cache (:obj:`bool`, `optional`, defaults to :obj:`True`):
|
||||||
Whether or not the model should return the last key/values attentions (not used by all models)
|
Whether or not the model should return the last key/values attentions (not used by all models)
|
||||||
|
forced_eos_token_id (:obj:`int`, `optional`, defaults to 2):
|
||||||
|
The id of the token to force as the last generated token when :obj:`max_length` is reached. Usually set to
|
||||||
|
:obj:`eos_token_id`.
|
||||||
|
|
||||||
Example::
|
Example::
|
||||||
|
|
||||||
@@ -129,6 +132,7 @@ class BlenderbotConfig(PretrainedConfig):
|
|||||||
bos_token_id=1,
|
bos_token_id=1,
|
||||||
eos_token_id=2,
|
eos_token_id=2,
|
||||||
encoder_no_repeat_ngram_size=3,
|
encoder_no_repeat_ngram_size=3,
|
||||||
|
forced_eos_token_id=2,
|
||||||
**kwargs
|
**kwargs
|
||||||
):
|
):
|
||||||
super().__init__(
|
super().__init__(
|
||||||
@@ -138,6 +142,7 @@ class BlenderbotConfig(PretrainedConfig):
|
|||||||
is_encoder_decoder=is_encoder_decoder,
|
is_encoder_decoder=is_encoder_decoder,
|
||||||
decoder_start_token_id=decoder_start_token_id,
|
decoder_start_token_id=decoder_start_token_id,
|
||||||
encoder_no_repeat_ngram_size=encoder_no_repeat_ngram_size,
|
encoder_no_repeat_ngram_size=encoder_no_repeat_ngram_size,
|
||||||
|
forced_eos_token_id=forced_eos_token_id,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@@ -1335,16 +1335,6 @@ class BlenderbotForConditionalGeneration(BlenderbotPreTrainedModel):
|
|||||||
"use_cache": use_cache, # change this to avoid caching (presumably for debugging)
|
"use_cache": use_cache, # change this to avoid caching (presumably for debugging)
|
||||||
}
|
}
|
||||||
|
|
||||||
def adjust_logits_during_generation(self, logits, cur_len, max_length):
|
|
||||||
if cur_len == max_length - 1 and self.config.eos_token_id is not None:
|
|
||||||
self._force_token_id_to_be_generated(logits, self.config.eos_token_id)
|
|
||||||
return logits
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def _force_token_id_to_be_generated(scores, token_id) -> None:
|
|
||||||
"""force one of token_ids to be generated by setting prob of all other tokens to 0 (logprob=-float("inf"))"""
|
|
||||||
scores[:, [x for x in range(scores.shape[1]) if x != token_id]] = -float("inf")
|
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _reorder_cache(past, beam_idx):
|
def _reorder_cache(past, beam_idx):
|
||||||
reordered_past = ()
|
reordered_past = ()
|
||||||
|
|||||||
@@ -1477,10 +1477,3 @@ class TFBlenderbotForConditionalGeneration(TFBlenderbotPreTrainedModel, TFCausal
|
|||||||
+ layer_past_key_values[2:],
|
+ layer_past_key_values[2:],
|
||||||
)
|
)
|
||||||
return (past[0], reordered_past)
|
return (past[0], reordered_past)
|
||||||
|
|
||||||
def adjust_logits_during_generation(self, logits, cur_len, max_length):
|
|
||||||
if cur_len == max_length - 1:
|
|
||||||
vocab_range = tf.constant(range(self.config.vocab_size))
|
|
||||||
return tf.where(vocab_range != self.config.eos_token_id, LARGE_NEGATIVE, logits)
|
|
||||||
else:
|
|
||||||
return logits
|
|
||||||
|
|||||||
@@ -84,6 +84,9 @@ class BlenderbotSmallConfig(PretrainedConfig):
|
|||||||
Scale embeddings by diving by sqrt(d_model).
|
Scale embeddings by diving by sqrt(d_model).
|
||||||
use_cache (:obj:`bool`, `optional`, defaults to :obj:`True`):
|
use_cache (:obj:`bool`, `optional`, defaults to :obj:`True`):
|
||||||
Whether or not the model should return the last key/values attentions (not used by all models)
|
Whether or not the model should return the last key/values attentions (not used by all models)
|
||||||
|
forced_eos_token_id (:obj:`int`, `optional`, defaults to 2):
|
||||||
|
The id of the token to force as the last generated token when :obj:`max_length` is reached. Usually set to
|
||||||
|
:obj:`eos_token_id`.
|
||||||
|
|
||||||
Example::
|
Example::
|
||||||
|
|
||||||
@@ -128,6 +131,7 @@ class BlenderbotSmallConfig(PretrainedConfig):
|
|||||||
pad_token_id=0,
|
pad_token_id=0,
|
||||||
bos_token_id=1,
|
bos_token_id=1,
|
||||||
eos_token_id=2,
|
eos_token_id=2,
|
||||||
|
forced_eos_token_id=2,
|
||||||
**kwargs
|
**kwargs
|
||||||
):
|
):
|
||||||
super().__init__(
|
super().__init__(
|
||||||
@@ -136,6 +140,7 @@ class BlenderbotSmallConfig(PretrainedConfig):
|
|||||||
eos_token_id=eos_token_id,
|
eos_token_id=eos_token_id,
|
||||||
is_encoder_decoder=is_encoder_decoder,
|
is_encoder_decoder=is_encoder_decoder,
|
||||||
decoder_start_token_id=decoder_start_token_id,
|
decoder_start_token_id=decoder_start_token_id,
|
||||||
|
forced_eos_token_id=forced_eos_token_id,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@@ -1310,16 +1310,6 @@ class BlenderbotSmallForConditionalGeneration(BlenderbotSmallPreTrainedModel):
|
|||||||
"use_cache": use_cache, # change this to avoid caching (presumably for debugging)
|
"use_cache": use_cache, # change this to avoid caching (presumably for debugging)
|
||||||
}
|
}
|
||||||
|
|
||||||
def adjust_logits_during_generation(self, logits, cur_len, max_length):
|
|
||||||
if cur_len == max_length - 1 and self.config.eos_token_id is not None:
|
|
||||||
self._force_token_id_to_be_generated(logits, self.config.eos_token_id)
|
|
||||||
return logits
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def _force_token_id_to_be_generated(scores, token_id) -> None:
|
|
||||||
"""force one of token_ids to be generated by setting prob of all other tokens to 0 (logprob=-float("inf"))"""
|
|
||||||
scores[:, [x for x in range(scores.shape[1]) if x != token_id]] = -float("inf")
|
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _reorder_cache(past, beam_idx):
|
def _reorder_cache(past, beam_idx):
|
||||||
reordered_past = ()
|
reordered_past = ()
|
||||||
|
|||||||
@@ -1452,10 +1452,3 @@ class TFBlenderbotSmallForConditionalGeneration(TFBlenderbotSmallPreTrainedModel
|
|||||||
+ layer_past_key_values[2:],
|
+ layer_past_key_values[2:],
|
||||||
)
|
)
|
||||||
return (past[0], reordered_past)
|
return (past[0], reordered_past)
|
||||||
|
|
||||||
def adjust_logits_during_generation(self, logits, cur_len, max_length):
|
|
||||||
if cur_len == max_length - 1:
|
|
||||||
vocab_range = tf.constant(range(self.config.vocab_size))
|
|
||||||
return tf.where(vocab_range != self.config.eos_token_id, LARGE_NEGATIVE, logits)
|
|
||||||
else:
|
|
||||||
return logits
|
|
||||||
|
|||||||
@@ -111,6 +111,9 @@ class FSMTConfig(PretrainedConfig):
|
|||||||
search when at least ``num_beams`` sentences are finished per batch or not.
|
search when at least ``num_beams`` sentences are finished per batch or not.
|
||||||
use_cache (:obj:`bool`, `optional`, defaults to :obj:`True`):
|
use_cache (:obj:`bool`, `optional`, defaults to :obj:`True`):
|
||||||
Whether or not the model should return the last key/values attentions (not used by all models).
|
Whether or not the model should return the last key/values attentions (not used by all models).
|
||||||
|
forced_eos_token_id (:obj:`int`, `optional`, defaults to 2):
|
||||||
|
The id of the token to force as the last generated token when :obj:`max_length` is reached. Usually set to
|
||||||
|
:obj:`eos_token_id`.
|
||||||
|
|
||||||
Examples::
|
Examples::
|
||||||
|
|
||||||
@@ -155,6 +158,7 @@ class FSMTConfig(PretrainedConfig):
|
|||||||
pad_token_id=1,
|
pad_token_id=1,
|
||||||
bos_token_id=0,
|
bos_token_id=0,
|
||||||
eos_token_id=2,
|
eos_token_id=2,
|
||||||
|
forced_eos_token_id=2,
|
||||||
**common_kwargs
|
**common_kwargs
|
||||||
):
|
):
|
||||||
if "hidden_size" in common_kwargs:
|
if "hidden_size" in common_kwargs:
|
||||||
@@ -166,6 +170,7 @@ class FSMTConfig(PretrainedConfig):
|
|||||||
decoder_start_token_id=decoder_start_token_id,
|
decoder_start_token_id=decoder_start_token_id,
|
||||||
is_encoder_decoder=is_encoder_decoder,
|
is_encoder_decoder=is_encoder_decoder,
|
||||||
tie_word_embeddings=tie_word_embeddings,
|
tie_word_embeddings=tie_word_embeddings,
|
||||||
|
forced_eos_token_id=forced_eos_token_id,
|
||||||
**common_kwargs,
|
**common_kwargs,
|
||||||
)
|
)
|
||||||
self.langs = langs
|
self.langs = langs
|
||||||
|
|||||||
@@ -1210,23 +1210,6 @@ class FSMTForConditionalGeneration(PretrainedFSMTModel):
|
|||||||
def prepare_decoder_input_ids_from_labels(self, labels: torch.Tensor):
|
def prepare_decoder_input_ids_from_labels(self, labels: torch.Tensor):
|
||||||
return shift_tokens_right(labels, self.config.pad_token_id)
|
return shift_tokens_right(labels, self.config.pad_token_id)
|
||||||
|
|
||||||
def adjust_logits_during_generation(self, logits, cur_len, max_length):
|
|
||||||
if cur_len == max_length - 1 and self.config.eos_token_id is not None:
|
|
||||||
self._force_token_ids_generation(logits, self.config.eos_token_id)
|
|
||||||
return logits
|
|
||||||
|
|
||||||
def _force_token_ids_generation(self, scores, token_ids) -> None:
|
|
||||||
"""force one of token_ids to be generated by setting prob of all other tokens to 0"""
|
|
||||||
if isinstance(token_ids, int):
|
|
||||||
token_ids = [token_ids]
|
|
||||||
all_but_token_ids_mask = torch.tensor(
|
|
||||||
[x for x in range(self.config.tgt_vocab_size) if x not in token_ids],
|
|
||||||
dtype=torch.long,
|
|
||||||
device=next(self.parameters()).device,
|
|
||||||
)
|
|
||||||
assert len(scores.shape) == 2, "scores should be of rank 2 with shape: [batch_size, vocab_size]"
|
|
||||||
scores[:, all_but_token_ids_mask] = -float("inf")
|
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _reorder_cache(past, beam_idx):
|
def _reorder_cache(past, beam_idx):
|
||||||
reordered_past = []
|
reordered_past = []
|
||||||
|
|||||||
@@ -84,6 +84,9 @@ class MarianConfig(PretrainedConfig):
|
|||||||
Scale embeddings by diving by sqrt(d_model).
|
Scale embeddings by diving by sqrt(d_model).
|
||||||
use_cache (:obj:`bool`, `optional`, defaults to :obj:`True`):
|
use_cache (:obj:`bool`, `optional`, defaults to :obj:`True`):
|
||||||
Whether or not the model should return the last key/values attentions (not used by all models)
|
Whether or not the model should return the last key/values attentions (not used by all models)
|
||||||
|
forced_eos_token_id (:obj:`int`, `optional`, defaults to 0):
|
||||||
|
The id of the token to force as the last generated token when :obj:`max_length` is reached. Usually set to
|
||||||
|
:obj:`eos_token_id`.
|
||||||
|
|
||||||
Examples::
|
Examples::
|
||||||
|
|
||||||
@@ -127,6 +130,7 @@ class MarianConfig(PretrainedConfig):
|
|||||||
gradient_checkpointing=False,
|
gradient_checkpointing=False,
|
||||||
pad_token_id=58100,
|
pad_token_id=58100,
|
||||||
eos_token_id=0,
|
eos_token_id=0,
|
||||||
|
forced_eos_token_id=0,
|
||||||
**kwargs
|
**kwargs
|
||||||
):
|
):
|
||||||
super().__init__(
|
super().__init__(
|
||||||
@@ -134,6 +138,7 @@ class MarianConfig(PretrainedConfig):
|
|||||||
eos_token_id=eos_token_id,
|
eos_token_id=eos_token_id,
|
||||||
is_encoder_decoder=is_encoder_decoder,
|
is_encoder_decoder=is_encoder_decoder,
|
||||||
decoder_start_token_id=decoder_start_token_id,
|
decoder_start_token_id=decoder_start_token_id,
|
||||||
|
forced_eos_token_id=forced_eos_token_id,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@@ -1325,15 +1325,8 @@ class MarianMTModel(MarianPreTrainedModel):
|
|||||||
|
|
||||||
def adjust_logits_during_generation(self, logits, cur_len, max_length):
|
def adjust_logits_during_generation(self, logits, cur_len, max_length):
|
||||||
logits[:, self.config.pad_token_id] = float("-inf") # never predict pad token.
|
logits[:, self.config.pad_token_id] = float("-inf") # never predict pad token.
|
||||||
if cur_len == max_length - 1 and self.config.eos_token_id is not None:
|
|
||||||
self._force_token_id_to_be_generated(logits, self.config.eos_token_id)
|
|
||||||
return logits
|
return logits
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def _force_token_id_to_be_generated(scores, token_id) -> None:
|
|
||||||
"""force one of token_ids to be generated by setting prob of all other tokens to 0 (logprob=-float("inf"))"""
|
|
||||||
scores[:, [x for x in range(scores.shape[1]) if x != token_id]] = -float("inf")
|
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _reorder_cache(past, beam_idx):
|
def _reorder_cache(past, beam_idx):
|
||||||
reordered_past = ()
|
reordered_past = ()
|
||||||
|
|||||||
@@ -1470,10 +1470,17 @@ class TFMarianMTModel(TFMarianPreTrainedModel, TFCausalLanguageModelingLoss):
|
|||||||
)
|
)
|
||||||
return (past[0], reordered_past)
|
return (past[0], reordered_past)
|
||||||
|
|
||||||
def adjust_logits_during_generation(self, logits, cur_len, max_length):
|
def adjust_logits_during_generation(
|
||||||
|
self, logits, cur_len, max_length, forced_bos_token_id, forced_eos_token_id, **kwargs
|
||||||
|
):
|
||||||
"""Never predict pad_token_id. Predict </s> when max_length is reached."""
|
"""Never predict pad_token_id. Predict </s> when max_length is reached."""
|
||||||
vocab_range = tf.constant(range(self.config.vocab_size))
|
vocab_range = tf.constant(range(self.config.vocab_size))
|
||||||
logits = tf.where(vocab_range == self.config.pad_token_id, LARGE_NEGATIVE, logits)
|
logits = tf.where(vocab_range == self.config.pad_token_id, LARGE_NEGATIVE, logits)
|
||||||
if cur_len == max_length - 1:
|
if cur_len == 1 and forced_bos_token_id is not None:
|
||||||
logits = tf.where(vocab_range != self.config.eos_token_id, LARGE_NEGATIVE, logits)
|
vocab_range = tf.constant(range(self.config.vocab_size))
|
||||||
|
return tf.where(vocab_range != forced_bos_token_id, LARGE_NEGATIVE, logits)
|
||||||
|
elif cur_len == max_length - 1 and forced_eos_token_id is not None:
|
||||||
|
vocab_range = tf.constant(range(self.config.vocab_size))
|
||||||
|
return tf.where(vocab_range != forced_eos_token_id, LARGE_NEGATIVE, logits)
|
||||||
|
else:
|
||||||
return logits
|
return logits
|
||||||
|
|||||||
@@ -84,6 +84,9 @@ class MBartConfig(PretrainedConfig):
|
|||||||
Scale embeddings by diving by sqrt(d_model).
|
Scale embeddings by diving by sqrt(d_model).
|
||||||
use_cache (:obj:`bool`, `optional`, defaults to :obj:`True`):
|
use_cache (:obj:`bool`, `optional`, defaults to :obj:`True`):
|
||||||
Whether or not the model should return the last key/values attentions (not used by all models)
|
Whether or not the model should return the last key/values attentions (not used by all models)
|
||||||
|
forced_eos_token_id (:obj:`int`, `optional`, defaults to 2):
|
||||||
|
The id of the token to force as the last generated token when :obj:`max_length` is reached. Usually set to
|
||||||
|
:obj:`eos_token_id`.
|
||||||
|
|
||||||
Example::
|
Example::
|
||||||
|
|
||||||
@@ -127,6 +130,7 @@ class MBartConfig(PretrainedConfig):
|
|||||||
pad_token_id=1,
|
pad_token_id=1,
|
||||||
bos_token_id=0,
|
bos_token_id=0,
|
||||||
eos_token_id=2,
|
eos_token_id=2,
|
||||||
|
forced_eos_token_id=2,
|
||||||
**kwargs
|
**kwargs
|
||||||
):
|
):
|
||||||
super().__init__(
|
super().__init__(
|
||||||
@@ -134,6 +138,7 @@ class MBartConfig(PretrainedConfig):
|
|||||||
bos_token_id=bos_token_id,
|
bos_token_id=bos_token_id,
|
||||||
eos_token_id=eos_token_id,
|
eos_token_id=eos_token_id,
|
||||||
is_encoder_decoder=is_encoder_decoder,
|
is_encoder_decoder=is_encoder_decoder,
|
||||||
|
forced_eos_token_id=forced_eos_token_id,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@@ -1344,16 +1344,6 @@ class MBartForConditionalGeneration(MBartPreTrainedModel):
|
|||||||
def prepare_decoder_input_ids_from_labels(self, labels: torch.Tensor):
|
def prepare_decoder_input_ids_from_labels(self, labels: torch.Tensor):
|
||||||
return shift_tokens_right(labels, self.config.pad_token_id)
|
return shift_tokens_right(labels, self.config.pad_token_id)
|
||||||
|
|
||||||
def adjust_logits_during_generation(self, logits, cur_len, max_length):
|
|
||||||
if cur_len == max_length - 1 and self.config.eos_token_id is not None:
|
|
||||||
self._force_token_id_to_be_generated(logits, self.config.eos_token_id)
|
|
||||||
return logits
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def _force_token_id_to_be_generated(scores, token_id) -> None:
|
|
||||||
"""force one of token_ids to be generated by setting prob of all other tokens to 0 (logprob=-float("inf"))"""
|
|
||||||
scores[:, [x for x in range(scores.shape[1]) if x != token_id]] = -float("inf")
|
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _reorder_cache(past, beam_idx):
|
def _reorder_cache(past, beam_idx):
|
||||||
reordered_past = ()
|
reordered_past = ()
|
||||||
|
|||||||
@@ -1468,10 +1468,3 @@ class TFMBartForConditionalGeneration(TFMBartPreTrainedModel, TFCausalLanguageMo
|
|||||||
+ layer_past_key_values[2:],
|
+ layer_past_key_values[2:],
|
||||||
)
|
)
|
||||||
return (past[0], reordered_past)
|
return (past[0], reordered_past)
|
||||||
|
|
||||||
def adjust_logits_during_generation(self, logits, cur_len, max_length):
|
|
||||||
if cur_len == max_length - 1:
|
|
||||||
vocab_range = tf.constant(range(self.config.vocab_size))
|
|
||||||
return tf.where(vocab_range != self.config.eos_token_id, LARGE_NEGATIVE, logits)
|
|
||||||
else:
|
|
||||||
return logits
|
|
||||||
|
|||||||
@@ -84,6 +84,9 @@ class PegasusConfig(PretrainedConfig):
|
|||||||
Scale embeddings by diving by sqrt(d_model).
|
Scale embeddings by diving by sqrt(d_model).
|
||||||
use_cache (:obj:`bool`, `optional`, defaults to :obj:`True`):
|
use_cache (:obj:`bool`, `optional`, defaults to :obj:`True`):
|
||||||
Whether or not the model should return the last key/values attentions (not used by all models)
|
Whether or not the model should return the last key/values attentions (not used by all models)
|
||||||
|
forced_eos_token_id (:obj:`int`, `optional`, defaults to 1):
|
||||||
|
The id of the token to force as the last generated token when :obj:`max_length` is reached. Usually set to
|
||||||
|
:obj:`eos_token_id`.
|
||||||
|
|
||||||
Example::
|
Example::
|
||||||
|
|
||||||
@@ -127,6 +130,7 @@ class PegasusConfig(PretrainedConfig):
|
|||||||
gradient_checkpointing=False,
|
gradient_checkpointing=False,
|
||||||
pad_token_id=0,
|
pad_token_id=0,
|
||||||
eos_token_id=1,
|
eos_token_id=1,
|
||||||
|
forced_eos_token_id=1,
|
||||||
**kwargs
|
**kwargs
|
||||||
):
|
):
|
||||||
super().__init__(
|
super().__init__(
|
||||||
@@ -134,6 +138,7 @@ class PegasusConfig(PretrainedConfig):
|
|||||||
eos_token_id=eos_token_id,
|
eos_token_id=eos_token_id,
|
||||||
is_encoder_decoder=is_encoder_decoder,
|
is_encoder_decoder=is_encoder_decoder,
|
||||||
decoder_start_token_id=decoder_start_token_id,
|
decoder_start_token_id=decoder_start_token_id,
|
||||||
|
forced_eos_token_id=forced_eos_token_id,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@@ -1327,16 +1327,6 @@ class PegasusForConditionalGeneration(PegasusPreTrainedModel):
|
|||||||
def prepare_decoder_input_ids_from_labels(self, labels: torch.Tensor):
|
def prepare_decoder_input_ids_from_labels(self, labels: torch.Tensor):
|
||||||
return shift_tokens_right(labels, self.config.pad_token_id, self.config.decoder_start_token_id)
|
return shift_tokens_right(labels, self.config.pad_token_id, self.config.decoder_start_token_id)
|
||||||
|
|
||||||
def adjust_logits_during_generation(self, logits, cur_len, max_length):
|
|
||||||
if cur_len == max_length - 1 and self.config.eos_token_id is not None:
|
|
||||||
self._force_token_id_to_be_generated(logits, self.config.eos_token_id)
|
|
||||||
return logits
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def _force_token_id_to_be_generated(scores, token_id) -> None:
|
|
||||||
"""force one of token_ids to be generated by setting prob of all other tokens to 0 (logprob=-float("inf"))"""
|
|
||||||
scores[:, [x for x in range(scores.shape[1]) if x != token_id]] = -float("inf")
|
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _reorder_cache(past, beam_idx):
|
def _reorder_cache(past, beam_idx):
|
||||||
reordered_past = ()
|
reordered_past = ()
|
||||||
|
|||||||
@@ -1483,10 +1483,3 @@ class TFPegasusForConditionalGeneration(TFPegasusPreTrainedModel, TFCausalLangua
|
|||||||
+ layer_past_key_values[2:],
|
+ layer_past_key_values[2:],
|
||||||
)
|
)
|
||||||
return (past[0], reordered_past)
|
return (past[0], reordered_past)
|
||||||
|
|
||||||
def adjust_logits_during_generation(self, logits, cur_len, max_length):
|
|
||||||
if cur_len == max_length - 1:
|
|
||||||
vocab_range = tf.constant(range(self.config.vocab_size))
|
|
||||||
return tf.where(vocab_range != self.config.eos_token_id, LARGE_NEGATIVE, logits)
|
|
||||||
else:
|
|
||||||
return logits
|
|
||||||
|
|||||||
@@ -74,6 +74,9 @@ RAG_CONFIG_DOC = r"""
|
|||||||
:obj:`context_attention_mask` are returned. See returned tensors for more detail.
|
:obj:`context_attention_mask` are returned. See returned tensors for more detail.
|
||||||
use_cache (:obj:`bool`, `optional`, defaults to :obj:`True`):
|
use_cache (:obj:`bool`, `optional`, defaults to :obj:`True`):
|
||||||
Whether or not the model should return the last key/values attentions (not used by all models).
|
Whether or not the model should return the last key/values attentions (not used by all models).
|
||||||
|
forced_eos_token_id (:obj:`int`, `optional`):
|
||||||
|
The id of the token to force as the last generated token when :obj:`max_length` is reached. Usually set to
|
||||||
|
:obj:`eos_token_id`.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
|
||||||
@@ -110,6 +113,7 @@ class RagConfig(PretrainedConfig):
|
|||||||
do_marginalize=False,
|
do_marginalize=False,
|
||||||
output_retrieved=False,
|
output_retrieved=False,
|
||||||
use_cache=True,
|
use_cache=True,
|
||||||
|
forced_eos_token_id=None,
|
||||||
**kwargs
|
**kwargs
|
||||||
):
|
):
|
||||||
super().__init__(
|
super().__init__(
|
||||||
@@ -117,6 +121,7 @@ class RagConfig(PretrainedConfig):
|
|||||||
pad_token_id=pad_token_id,
|
pad_token_id=pad_token_id,
|
||||||
eos_token_id=eos_token_id,
|
eos_token_id=eos_token_id,
|
||||||
decoder_start_token_id=decoder_start_token_id,
|
decoder_start_token_id=decoder_start_token_id,
|
||||||
|
forced_eos_token_id=forced_eos_token_id,
|
||||||
is_encoder_decoder=is_encoder_decoder,
|
is_encoder_decoder=is_encoder_decoder,
|
||||||
prefix=prefix,
|
prefix=prefix,
|
||||||
vocab_size=vocab_size,
|
vocab_size=vocab_size,
|
||||||
@@ -161,6 +166,9 @@ class RagConfig(PretrainedConfig):
|
|||||||
|
|
||||||
self.use_cache = use_cache
|
self.use_cache = use_cache
|
||||||
|
|
||||||
|
if self.forced_eos_token_id is None:
|
||||||
|
self.forced_eos_token_id = getattr(self.generator, "forced_eos_token_id", None)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_question_encoder_generator_configs(
|
def from_question_encoder_generator_configs(
|
||||||
cls, question_encoder_config: PretrainedConfig, generator_config: PretrainedConfig, **kwargs
|
cls, question_encoder_config: PretrainedConfig, generator_config: PretrainedConfig, **kwargs
|
||||||
|
|||||||
@@ -1089,9 +1089,6 @@ class RagTokenForGeneration(RagPreTrainedModel):
|
|||||||
def set_retriever(self, retriever: RagRetriever):
|
def set_retriever(self, retriever: RagRetriever):
|
||||||
self.rag.retriever = retriever
|
self.rag.retriever = retriever
|
||||||
|
|
||||||
def adjust_logits_during_generation(self, logits, cur_len, max_length):
|
|
||||||
return self.rag.generator.adjust_logits_during_generation(logits, cur_len=cur_len, max_length=max_length)
|
|
||||||
|
|
||||||
def prepare_inputs_for_generation(
|
def prepare_inputs_for_generation(
|
||||||
self,
|
self,
|
||||||
decoder_input_ids,
|
decoder_input_ids,
|
||||||
@@ -1313,6 +1310,8 @@ class RagTokenForGeneration(RagPreTrainedModel):
|
|||||||
decoder_start_token_id=None,
|
decoder_start_token_id=None,
|
||||||
n_docs=None,
|
n_docs=None,
|
||||||
prefix_allowed_tokens_fn: Callable[[int, torch.Tensor], List[int]] = None,
|
prefix_allowed_tokens_fn: Callable[[int, torch.Tensor], List[int]] = None,
|
||||||
|
forced_bos_token_id: Optional[int] = None,
|
||||||
|
forced_eos_token_id: Optional[int] = None,
|
||||||
**model_kwargs
|
**model_kwargs
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
@@ -1403,6 +1402,12 @@ class RagTokenForGeneration(RagPreTrainedModel):
|
|||||||
conditioned on the previously generated tokens :obj:`inputs_ids` and the batch ID :obj:`batch_id`. This
|
conditioned on the previously generated tokens :obj:`inputs_ids` and the batch ID :obj:`batch_id`. This
|
||||||
argument is useful for constrained generation conditioned on the prefix, as described in
|
argument is useful for constrained generation conditioned on the prefix, as described in
|
||||||
`Autoregressive Entity Retrieval <https://arxiv.org/abs/2010.00904>`__.
|
`Autoregressive Entity Retrieval <https://arxiv.org/abs/2010.00904>`__.
|
||||||
|
forced_bos_token_id (:obj:`int`, `optional`):
|
||||||
|
The id of the token to force as the first generated token after the :obj:`decoder_start_token_id`.
|
||||||
|
Useful for multilingual models like :doc:`mBART <../model_doc/mbart>` where the first generated token
|
||||||
|
needs to be the target language token.
|
||||||
|
forced_eos_token_id (:obj:`int`, `optional`):
|
||||||
|
The id of the token to force as the last generated token when :obj:`max_length` is reached.
|
||||||
|
|
||||||
Return:
|
Return:
|
||||||
:obj:`torch.LongTensor` of shape :obj:`(batch_size * num_return_sequences, sequence_length)`: The generated
|
:obj:`torch.LongTensor` of shape :obj:`(batch_size * num_return_sequences, sequence_length)`: The generated
|
||||||
@@ -1498,7 +1503,10 @@ class RagTokenForGeneration(RagPreTrainedModel):
|
|||||||
encoder_input_ids=context_input_ids,
|
encoder_input_ids=context_input_ids,
|
||||||
bad_words_ids=bad_words_ids,
|
bad_words_ids=bad_words_ids,
|
||||||
min_length=min_length,
|
min_length=min_length,
|
||||||
|
max_length=max_length,
|
||||||
eos_token_id=eos_token_id,
|
eos_token_id=eos_token_id,
|
||||||
|
forced_bos_token_id=forced_bos_token_id,
|
||||||
|
forced_eos_token_id=forced_eos_token_id,
|
||||||
prefix_allowed_tokens_fn=prefix_allowed_tokens_fn,
|
prefix_allowed_tokens_fn=prefix_allowed_tokens_fn,
|
||||||
num_beams=num_beams,
|
num_beams=num_beams,
|
||||||
num_beam_groups=num_beam_groups,
|
num_beam_groups=num_beam_groups,
|
||||||
|
|||||||
@@ -28,6 +28,8 @@ if is_torch_available():
|
|||||||
|
|
||||||
from transformers.generation_logits_process import (
|
from transformers.generation_logits_process import (
|
||||||
EncoderNoRepeatNGramLogitsProcessor,
|
EncoderNoRepeatNGramLogitsProcessor,
|
||||||
|
ForcedBOSTokenLogitsProcessor,
|
||||||
|
ForcedEOSTokenLogitsProcessor,
|
||||||
HammingDiversityLogitsProcessor,
|
HammingDiversityLogitsProcessor,
|
||||||
LogitsProcessorList,
|
LogitsProcessorList,
|
||||||
MinLengthLogitsProcessor,
|
MinLengthLogitsProcessor,
|
||||||
@@ -393,3 +395,44 @@ class LogitsProcessorTest(unittest.TestCase):
|
|||||||
processed_scores[1], torch.tensor([0.2500, -0.7500, 0.2500, 0.2500], device=torch_device), atol=1e-3
|
processed_scores[1], torch.tensor([0.2500, -0.7500, 0.2500, 0.2500], device=torch_device), atol=1e-3
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def test_forced_bos_token_logits_processor(self):
|
||||||
|
vocab_size = 20
|
||||||
|
batch_size = 4
|
||||||
|
bos_token_id = 0
|
||||||
|
|
||||||
|
logits_processor = ForcedBOSTokenLogitsProcessor(bos_token_id=bos_token_id)
|
||||||
|
|
||||||
|
# check that all scores are -inf except the bos_token_id score
|
||||||
|
input_ids = ids_tensor((batch_size, 1), vocab_size=20)
|
||||||
|
scores = self._get_uniform_logits(batch_size, vocab_size)
|
||||||
|
scores = logits_processor(input_ids, scores)
|
||||||
|
self.assertTrue(torch.isneginf(scores[:, bos_token_id + 1 :]).all())
|
||||||
|
self.assertListEqual(scores[:, bos_token_id].tolist(), 4 * [0]) # score for bos_token_id shold be zero
|
||||||
|
|
||||||
|
# check that bos_token_id is not forced if current length is greater than 1
|
||||||
|
input_ids = ids_tensor((batch_size, 4), vocab_size=20)
|
||||||
|
scores = self._get_uniform_logits(batch_size, vocab_size)
|
||||||
|
scores = logits_processor(input_ids, scores)
|
||||||
|
self.assertFalse(torch.isinf(scores).any())
|
||||||
|
|
||||||
|
def test_forced_eos_token_logits_processor(self):
|
||||||
|
vocab_size = 20
|
||||||
|
batch_size = 4
|
||||||
|
eos_token_id = 0
|
||||||
|
max_length = 5
|
||||||
|
|
||||||
|
logits_processor = ForcedEOSTokenLogitsProcessor(max_length=max_length, eos_token_id=eos_token_id)
|
||||||
|
|
||||||
|
# check that all scores are -inf except the eos_token_id when max_length is reached
|
||||||
|
input_ids = ids_tensor((batch_size, 4), vocab_size=20)
|
||||||
|
scores = self._get_uniform_logits(batch_size, vocab_size)
|
||||||
|
scores = logits_processor(input_ids, scores)
|
||||||
|
self.assertTrue(torch.isneginf(scores[:, eos_token_id + 1 :]).all())
|
||||||
|
self.assertListEqual(scores[:, eos_token_id].tolist(), 4 * [0]) # score for eos_token_id should be zero
|
||||||
|
|
||||||
|
# check that eos_token_id is not forced if max_length is not reached
|
||||||
|
input_ids = ids_tensor((batch_size, 3), vocab_size=20)
|
||||||
|
scores = self._get_uniform_logits(batch_size, vocab_size)
|
||||||
|
scores = logits_processor(input_ids, scores)
|
||||||
|
self.assertFalse(torch.isinf(scores).any())
|
||||||
|
|||||||
@@ -26,6 +26,8 @@ if is_torch_available():
|
|||||||
from transformers import BartForConditionalGeneration, BartTokenizer, top_k_top_p_filtering
|
from transformers import BartForConditionalGeneration, BartTokenizer, top_k_top_p_filtering
|
||||||
from transformers.generation_beam_search import BeamSearchScorer
|
from transformers.generation_beam_search import BeamSearchScorer
|
||||||
from transformers.generation_logits_process import (
|
from transformers.generation_logits_process import (
|
||||||
|
ForcedBOSTokenLogitsProcessor,
|
||||||
|
ForcedEOSTokenLogitsProcessor,
|
||||||
HammingDiversityLogitsProcessor,
|
HammingDiversityLogitsProcessor,
|
||||||
LogitsProcessorList,
|
LogitsProcessorList,
|
||||||
MinLengthLogitsProcessor,
|
MinLengthLogitsProcessor,
|
||||||
@@ -70,7 +72,14 @@ class GenerationTesterMixin:
|
|||||||
return config, input_ids, attention_mask, max_length
|
return config, input_ids, attention_mask, max_length
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _get_logits_processor_and_kwargs(input_length, eos_token_id, diversity_penalty=None):
|
def _get_logits_processor_and_kwargs(
|
||||||
|
input_length,
|
||||||
|
eos_token_id,
|
||||||
|
forced_bos_token_id=None,
|
||||||
|
forced_eos_token_id=None,
|
||||||
|
max_length=None,
|
||||||
|
diversity_penalty=None,
|
||||||
|
):
|
||||||
process_kwargs = {
|
process_kwargs = {
|
||||||
"min_length": input_length + 1,
|
"min_length": input_length + 1,
|
||||||
"bad_words_ids": [[1, 0]],
|
"bad_words_ids": [[1, 0]],
|
||||||
@@ -92,6 +101,18 @@ class GenerationTesterMixin:
|
|||||||
if eos_token_id is not None
|
if eos_token_id is not None
|
||||||
else []
|
else []
|
||||||
)
|
)
|
||||||
|
+ (
|
||||||
|
[
|
||||||
|
ForcedBOSTokenLogitsProcessor(forced_bos_token_id),
|
||||||
|
]
|
||||||
|
if forced_bos_token_id is not None
|
||||||
|
else []
|
||||||
|
)
|
||||||
|
+ (
|
||||||
|
[ForcedEOSTokenLogitsProcessor(max_length, forced_eos_token_id)]
|
||||||
|
if forced_eos_token_id is not None
|
||||||
|
else []
|
||||||
|
)
|
||||||
+ [
|
+ [
|
||||||
NoBadWordsLogitsProcessor(process_kwargs["bad_words_ids"], eos_token_id),
|
NoBadWordsLogitsProcessor(process_kwargs["bad_words_ids"], eos_token_id),
|
||||||
NoRepeatNGramLogitsProcessor(process_kwargs["no_repeat_ngram_size"]),
|
NoRepeatNGramLogitsProcessor(process_kwargs["no_repeat_ngram_size"]),
|
||||||
@@ -182,13 +203,17 @@ class GenerationTesterMixin:
|
|||||||
output_hidden_states=False,
|
output_hidden_states=False,
|
||||||
return_dict_in_generate=False,
|
return_dict_in_generate=False,
|
||||||
):
|
):
|
||||||
|
if model.config.is_encoder_decoder:
|
||||||
|
max_length = 4
|
||||||
logits_process_kwargs, logits_processor = self._get_logits_processor_and_kwargs(
|
logits_process_kwargs, logits_processor = self._get_logits_processor_and_kwargs(
|
||||||
input_ids.shape[-1], model.config.eos_token_id
|
input_ids.shape[-1],
|
||||||
|
eos_token_id=model.config.eos_token_id,
|
||||||
|
forced_bos_token_id=model.config.forced_bos_token_id,
|
||||||
|
forced_eos_token_id=model.config.forced_eos_token_id,
|
||||||
|
max_length=max_length,
|
||||||
)
|
)
|
||||||
|
|
||||||
kwargs = {}
|
kwargs = {}
|
||||||
if model.config.is_encoder_decoder:
|
|
||||||
max_length = 4
|
|
||||||
|
|
||||||
output_generate = model.generate(
|
output_generate = model.generate(
|
||||||
input_ids,
|
input_ids,
|
||||||
@@ -544,14 +569,19 @@ class GenerationTesterMixin:
|
|||||||
for model_class in self.all_generative_model_classes:
|
for model_class in self.all_generative_model_classes:
|
||||||
config, input_ids, attention_mask, max_length = self._get_input_ids_and_config()
|
config, input_ids, attention_mask, max_length = self._get_input_ids_and_config()
|
||||||
model = model_class(config).to(torch_device).eval()
|
model = model_class(config).to(torch_device).eval()
|
||||||
process_kwargs, logits_processor = self._get_logits_processor_and_kwargs(
|
|
||||||
input_ids.shape[-1], model.config.eos_token_id
|
|
||||||
)
|
|
||||||
logits_warper_kwargs, logits_warper = self._get_warper_and_kwargs(num_beams=1)
|
|
||||||
|
|
||||||
if model.config.is_encoder_decoder:
|
if model.config.is_encoder_decoder:
|
||||||
max_length = 4
|
max_length = 4
|
||||||
|
|
||||||
|
process_kwargs, logits_processor = self._get_logits_processor_and_kwargs(
|
||||||
|
input_ids.shape[-1],
|
||||||
|
model.config.eos_token_id,
|
||||||
|
forced_bos_token_id=model.config.forced_bos_token_id,
|
||||||
|
forced_eos_token_id=model.config.forced_eos_token_id,
|
||||||
|
max_length=max_length,
|
||||||
|
)
|
||||||
|
logits_warper_kwargs, logits_warper = self._get_warper_and_kwargs(num_beams=1)
|
||||||
|
|
||||||
# check `generate()` and `sample()` are equal
|
# check `generate()` and `sample()` are equal
|
||||||
output_sample, output_generate = self._sample_generate(
|
output_sample, output_generate = self._sample_generate(
|
||||||
model=model,
|
model=model,
|
||||||
@@ -586,14 +616,18 @@ class GenerationTesterMixin:
|
|||||||
config, input_ids, attention_mask, max_length = self._get_input_ids_and_config()
|
config, input_ids, attention_mask, max_length = self._get_input_ids_and_config()
|
||||||
config.use_cache = False
|
config.use_cache = False
|
||||||
model = model_class(config).to(torch_device).eval()
|
model = model_class(config).to(torch_device).eval()
|
||||||
process_kwargs, logits_processor = self._get_logits_processor_and_kwargs(
|
|
||||||
input_ids.shape[-1], model.config.eos_token_id
|
|
||||||
)
|
|
||||||
logits_warper_kwargs, logits_warper = self._get_warper_and_kwargs(num_beams=1)
|
|
||||||
|
|
||||||
if model.config.is_encoder_decoder:
|
if model.config.is_encoder_decoder:
|
||||||
max_length = 4
|
max_length = 4
|
||||||
|
|
||||||
|
process_kwargs, logits_processor = self._get_logits_processor_and_kwargs(
|
||||||
|
input_ids.shape[-1],
|
||||||
|
model.config.eos_token_id,
|
||||||
|
forced_bos_token_id=model.config.forced_bos_token_id,
|
||||||
|
forced_eos_token_id=model.config.forced_eos_token_id,
|
||||||
|
max_length=max_length,
|
||||||
|
)
|
||||||
|
logits_warper_kwargs, logits_warper = self._get_warper_and_kwargs(num_beams=1)
|
||||||
|
|
||||||
output_sample, output_generate = self._sample_generate(
|
output_sample, output_generate = self._sample_generate(
|
||||||
model=model,
|
model=model,
|
||||||
input_ids=input_ids,
|
input_ids=input_ids,
|
||||||
@@ -630,14 +664,19 @@ class GenerationTesterMixin:
|
|||||||
# shorter than `max_length` can be generated which could lead to flaky circle ci
|
# shorter than `max_length` can be generated which could lead to flaky circle ci
|
||||||
# failures if the top `num_return_sequences` beams are all shorter than the longest beam
|
# failures if the top `num_return_sequences` beams are all shorter than the longest beam
|
||||||
config.eos_token_id = None
|
config.eos_token_id = None
|
||||||
|
config.forced_eos_token_id = None
|
||||||
|
|
||||||
model = model_class(config).to(torch_device).eval()
|
model = model_class(config).to(torch_device).eval()
|
||||||
|
|
||||||
logits_process_kwargs, logits_processor = self._get_logits_processor_and_kwargs(
|
|
||||||
input_ids.shape[-1], config.eos_token_id
|
|
||||||
)
|
|
||||||
if model.config.is_encoder_decoder:
|
if model.config.is_encoder_decoder:
|
||||||
max_length = 4
|
max_length = 4
|
||||||
|
|
||||||
|
logits_process_kwargs, logits_processor = self._get_logits_processor_and_kwargs(
|
||||||
|
input_ids.shape[-1],
|
||||||
|
config.eos_token_id,
|
||||||
|
config.forced_bos_token_id,
|
||||||
|
config.forced_eos_token_id,
|
||||||
|
max_length,
|
||||||
|
)
|
||||||
beam_kwargs, beam_scorer = self._get_beam_scorer_and_kwargs(input_ids.shape[0], max_length)
|
beam_kwargs, beam_scorer = self._get_beam_scorer_and_kwargs(input_ids.shape[0], max_length)
|
||||||
|
|
||||||
# check `generate()` and `beam_search()` are equal
|
# check `generate()` and `beam_search()` are equal
|
||||||
@@ -684,13 +723,19 @@ class GenerationTesterMixin:
|
|||||||
# shorter than `max_length` can be generated which could lead to flaky circle ci
|
# shorter than `max_length` can be generated which could lead to flaky circle ci
|
||||||
# failures if the top `num_return_sequences` beams are all shorter than the longest beam
|
# failures if the top `num_return_sequences` beams are all shorter than the longest beam
|
||||||
config.eos_token_id = None
|
config.eos_token_id = None
|
||||||
|
config.forced_eos_token_id = None
|
||||||
|
|
||||||
model = model_class(config).to(torch_device).eval()
|
model = model_class(config).to(torch_device).eval()
|
||||||
logits_process_kwargs, logits_processor = self._get_logits_processor_and_kwargs(
|
|
||||||
input_ids.shape[-1], config.eos_token_id
|
|
||||||
)
|
|
||||||
if model.config.is_encoder_decoder:
|
if model.config.is_encoder_decoder:
|
||||||
max_length = 4
|
max_length = 4
|
||||||
|
|
||||||
|
logits_process_kwargs, logits_processor = self._get_logits_processor_and_kwargs(
|
||||||
|
input_ids.shape[-1],
|
||||||
|
config.eos_token_id,
|
||||||
|
config.forced_bos_token_id,
|
||||||
|
config.forced_eos_token_id,
|
||||||
|
max_length,
|
||||||
|
)
|
||||||
beam_kwargs, beam_scorer = self._get_beam_scorer_and_kwargs(input_ids.shape[0], max_length)
|
beam_kwargs, beam_scorer = self._get_beam_scorer_and_kwargs(input_ids.shape[0], max_length)
|
||||||
output_generate, output_beam_search = self._beam_search_generate(
|
output_generate, output_beam_search = self._beam_search_generate(
|
||||||
model=model,
|
model=model,
|
||||||
@@ -732,19 +777,24 @@ class GenerationTesterMixin:
|
|||||||
# shorter than `max_length` can be generated which could lead to flaky circle ci
|
# shorter than `max_length` can be generated which could lead to flaky circle ci
|
||||||
# failures if the top `num_return_sequences` beams are all shorter than the longest beam
|
# failures if the top `num_return_sequences` beams are all shorter than the longest beam
|
||||||
config.eos_token_id = None
|
config.eos_token_id = None
|
||||||
|
config.forced_eos_token_id = None
|
||||||
|
|
||||||
if not hasattr(config, "use_cache"):
|
if not hasattr(config, "use_cache"):
|
||||||
# only relevant if model has "use_cache"
|
# only relevant if model has "use_cache"
|
||||||
return
|
return
|
||||||
|
|
||||||
model = model_class(config).to(torch_device).eval()
|
model = model_class(config).to(torch_device).eval()
|
||||||
|
|
||||||
logits_process_kwargs, logits_processor = self._get_logits_processor_and_kwargs(
|
|
||||||
input_ids.shape[-1], config.eos_token_id
|
|
||||||
)
|
|
||||||
|
|
||||||
if model.config.is_encoder_decoder:
|
if model.config.is_encoder_decoder:
|
||||||
max_length = 4
|
max_length = 4
|
||||||
|
|
||||||
|
logits_process_kwargs, logits_processor = self._get_logits_processor_and_kwargs(
|
||||||
|
input_ids.shape[-1],
|
||||||
|
config.eos_token_id,
|
||||||
|
config.forced_bos_token_id,
|
||||||
|
config.forced_eos_token_id,
|
||||||
|
max_length,
|
||||||
|
)
|
||||||
|
|
||||||
beam_kwargs, beam_scorer = self._get_beam_scorer_and_kwargs(input_ids.shape[0], max_length)
|
beam_kwargs, beam_scorer = self._get_beam_scorer_and_kwargs(input_ids.shape[0], max_length)
|
||||||
|
|
||||||
config.use_cache = True
|
config.use_cache = True
|
||||||
@@ -780,6 +830,7 @@ class GenerationTesterMixin:
|
|||||||
# shorter than `max_length` can be generated which could lead to flaky circle ci
|
# shorter than `max_length` can be generated which could lead to flaky circle ci
|
||||||
# failures if the top `num_return_sequences` beams are all shorter than the longest beam
|
# failures if the top `num_return_sequences` beams are all shorter than the longest beam
|
||||||
config.eos_token_id = None
|
config.eos_token_id = None
|
||||||
|
config.forced_eos_token_id = None
|
||||||
|
|
||||||
logits_warper_kwargs, logits_warper = self._get_warper_and_kwargs(num_beams=1)
|
logits_warper_kwargs, logits_warper = self._get_warper_and_kwargs(num_beams=1)
|
||||||
|
|
||||||
@@ -819,6 +870,7 @@ class GenerationTesterMixin:
|
|||||||
# shorter than `max_length` can be generated which could lead to flaky circle ci
|
# shorter than `max_length` can be generated which could lead to flaky circle ci
|
||||||
# failures if the top `num_return_sequences` beams are all shorter than the longest beam
|
# failures if the top `num_return_sequences` beams are all shorter than the longest beam
|
||||||
config.eos_token_id = None
|
config.eos_token_id = None
|
||||||
|
config.forced_eos_token_id = None
|
||||||
|
|
||||||
model = model_class(config).to(torch_device).eval()
|
model = model_class(config).to(torch_device).eval()
|
||||||
logits_warper_kwargs, logits_warper = self._get_warper_and_kwargs(num_beams=1)
|
logits_warper_kwargs, logits_warper = self._get_warper_and_kwargs(num_beams=1)
|
||||||
@@ -892,16 +944,22 @@ class GenerationTesterMixin:
|
|||||||
# shorter than `max_length` can be generated which could lead to flaky circle ci
|
# shorter than `max_length` can be generated which could lead to flaky circle ci
|
||||||
# failures if the top `num_return_sequences` beams are all shorter than the longest beam
|
# failures if the top `num_return_sequences` beams are all shorter than the longest beam
|
||||||
config.eos_token_id = None
|
config.eos_token_id = None
|
||||||
|
config.forced_eos_token_id = None
|
||||||
logits_process_kwargs, logits_processor = self._get_logits_processor_and_kwargs(
|
|
||||||
input_ids.shape[-1], config.eos_token_id, diversity_penalty=2.0
|
|
||||||
)
|
|
||||||
|
|
||||||
model = model_class(config).to(torch_device).eval()
|
model = model_class(config).to(torch_device).eval()
|
||||||
|
|
||||||
# check `generate()` and `group_beam_search()` are equal
|
|
||||||
if model.config.is_encoder_decoder:
|
if model.config.is_encoder_decoder:
|
||||||
max_length = 4
|
max_length = 4
|
||||||
|
|
||||||
|
logits_process_kwargs, logits_processor = self._get_logits_processor_and_kwargs(
|
||||||
|
input_ids.shape[-1],
|
||||||
|
config.eos_token_id,
|
||||||
|
config.forced_bos_token_id,
|
||||||
|
config.forced_eos_token_id,
|
||||||
|
max_length,
|
||||||
|
diversity_penalty=2.0,
|
||||||
|
)
|
||||||
|
|
||||||
|
# check `generate()` and `group_beam_search()` are equal
|
||||||
beam_kwargs, beam_scorer = self._get_diverse_beam_scorer_and_kwargs(input_ids.shape[0], max_length)
|
beam_kwargs, beam_scorer = self._get_diverse_beam_scorer_and_kwargs(input_ids.shape[0], max_length)
|
||||||
output_generate, output_group_beam_search = self._group_beam_search_generate(
|
output_generate, output_group_beam_search = self._group_beam_search_generate(
|
||||||
model=model,
|
model=model,
|
||||||
@@ -943,16 +1001,22 @@ class GenerationTesterMixin:
|
|||||||
# shorter than `max_length` can be generated which could lead to flaky circle ci
|
# shorter than `max_length` can be generated which could lead to flaky circle ci
|
||||||
# failures if the top `num_return_sequences` beams are all shorter than the longest beam
|
# failures if the top `num_return_sequences` beams are all shorter than the longest beam
|
||||||
config.eos_token_id = None
|
config.eos_token_id = None
|
||||||
|
config.forced_eos_token_id = None
|
||||||
|
|
||||||
model = model_class(config).to(torch_device).eval()
|
model = model_class(config).to(torch_device).eval()
|
||||||
|
if model.config.is_encoder_decoder:
|
||||||
|
max_length = 4
|
||||||
|
|
||||||
logits_process_kwargs, logits_processor = self._get_logits_processor_and_kwargs(
|
logits_process_kwargs, logits_processor = self._get_logits_processor_and_kwargs(
|
||||||
input_ids.shape[-1], config.eos_token_id, diversity_penalty=2.0
|
input_ids.shape[-1],
|
||||||
|
config.eos_token_id,
|
||||||
|
config.forced_bos_token_id,
|
||||||
|
config.forced_eos_token_id,
|
||||||
|
max_length,
|
||||||
|
diversity_penalty=2.0,
|
||||||
)
|
)
|
||||||
|
|
||||||
num_return_sequences = 1
|
num_return_sequences = 1
|
||||||
if model.config.is_encoder_decoder:
|
|
||||||
max_length = 4
|
|
||||||
beam_kwargs, beam_scorer = self._get_diverse_beam_scorer_and_kwargs(
|
beam_kwargs, beam_scorer = self._get_diverse_beam_scorer_and_kwargs(
|
||||||
input_ids.shape[0], max_length, num_return_sequences=num_return_sequences
|
input_ids.shape[0], max_length, num_return_sequences=num_return_sequences
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -46,6 +46,7 @@ class SimpleSummarizationPipelineTests(unittest.TestCase):
|
|||||||
decoder_attention_heads=1,
|
decoder_attention_heads=1,
|
||||||
max_length=4,
|
max_length=4,
|
||||||
min_length=1,
|
min_length=1,
|
||||||
|
forced_eos_token_id=None,
|
||||||
)
|
)
|
||||||
model = BartForConditionalGeneration(config)
|
model = BartForConditionalGeneration(config)
|
||||||
# Bias output towards L
|
# Bias output towards L
|
||||||
|
|||||||
Reference in New Issue
Block a user