[Generate] Add save mode logits processor to remove nans and infs if necessary (#10769)
* push * finish * finish * make fix copies * change name
This commit is contained in:
committed by
GitHub
parent
9f8fa4e973
commit
77bf3fe787
@@ -151,6 +151,16 @@ generation.
|
||||
.. autoclass:: transformers.HammingDiversityLogitsProcessor
|
||||
:members: __call__
|
||||
|
||||
.. autoclass:: transformers.ForcedBOSTokenLogitsProcessor
|
||||
:members: __call__
|
||||
|
||||
.. autoclass:: transformers.ForcedEOSTokenLogitsProcessor
|
||||
:members: __call__
|
||||
|
||||
.. autoclass:: transformers.InfNanRemoveLogitsProcessor
|
||||
:members: __call__
|
||||
|
||||
|
||||
StoppingCriteria
|
||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
|
||||
@@ -369,7 +369,10 @@ if is_torch_available():
|
||||
]
|
||||
_import_structure["generation_beam_search"] = ["BeamScorer", "BeamSearchScorer"]
|
||||
_import_structure["generation_logits_process"] = [
|
||||
"ForcedBOSTokenLogitsProcessor",
|
||||
"ForcedEOSTokenLogitsProcessor",
|
||||
"HammingDiversityLogitsProcessor",
|
||||
"InfNanRemoveLogitsProcessor",
|
||||
"LogitsProcessor",
|
||||
"LogitsProcessorList",
|
||||
"LogitsWarper",
|
||||
@@ -1560,7 +1563,10 @@ if TYPE_CHECKING:
|
||||
)
|
||||
from .generation_beam_search import BeamScorer, BeamSearchScorer
|
||||
from .generation_logits_process import (
|
||||
ForcedBOSTokenLogitsProcessor,
|
||||
ForcedEOSTokenLogitsProcessor,
|
||||
HammingDiversityLogitsProcessor,
|
||||
InfNanRemoveLogitsProcessor,
|
||||
LogitsProcessor,
|
||||
LogitsProcessorList,
|
||||
LogitsWarper,
|
||||
|
||||
@@ -134,6 +134,9 @@ class PretrainedConfig(object):
|
||||
<../model_doc/mbart>` where the first generated token needs to be the target language token.
|
||||
- **forced_eos_token_id** (:obj:`int`, `optional`) -- The id of the token to force as the last generated token
|
||||
when :obj:`max_length` is reached.
|
||||
- **remove_invalid_values** (:obj:`bool`, `optional`) -- Whether to remove possible `nan` and `inf` outputs of
|
||||
the model to prevent the generation method to crash. Note that using ``remove_invalid_values`` can slow down
|
||||
generation.
|
||||
|
||||
|
||||
Parameters for fine-tuning tasks
|
||||
@@ -219,6 +222,7 @@ class PretrainedConfig(object):
|
||||
self.return_dict_in_generate = kwargs.pop("return_dict_in_generate", False)
|
||||
self.forced_bos_token_id = kwargs.pop("forced_bos_token_id", None)
|
||||
self.forced_eos_token_id = kwargs.pop("forced_eos_token_id", None)
|
||||
self.remove_invalid_values = kwargs.pop("remove_invalid_values", False)
|
||||
|
||||
# Fine-tuning task arguments
|
||||
self.architectures = kwargs.pop("architectures", None)
|
||||
|
||||
@@ -566,3 +566,20 @@ class ForcedEOSTokenLogitsProcessor(LogitsProcessor):
|
||||
scores[:, [i for i in range(num_tokens) if i != self.eos_token_id]] = -float("inf")
|
||||
scores[:, self.eos_token_id] = 0
|
||||
return scores
|
||||
|
||||
|
||||
class InfNanRemoveLogitsProcessor(LogitsProcessor):
|
||||
r"""
|
||||
:class:`~transformers.LogitsProcessor` that removes all :obj:`nan` and :obj:`inf` values to avoid the generation
|
||||
method to fail. Note that using the logits processor should only be used if necessary since it can slow down the
|
||||
generation method. :obj:`max_length` is reached.
|
||||
"""
|
||||
|
||||
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
|
||||
# set all nan values to 0.0
|
||||
scores[scores != scores] = 0.0
|
||||
|
||||
# set all inf values to max possible value
|
||||
scores[scores == float("inf")] = torch.finfo(scores.dtype).max
|
||||
|
||||
return scores
|
||||
|
||||
@@ -27,6 +27,7 @@ from .generation_logits_process import (
|
||||
ForcedBOSTokenLogitsProcessor,
|
||||
ForcedEOSTokenLogitsProcessor,
|
||||
HammingDiversityLogitsProcessor,
|
||||
InfNanRemoveLogitsProcessor,
|
||||
LogitsProcessorList,
|
||||
MinLengthLogitsProcessor,
|
||||
NoBadWordsLogitsProcessor,
|
||||
@@ -581,6 +582,7 @@ class GenerationMixin:
|
||||
num_beams: int,
|
||||
num_beam_groups: int,
|
||||
diversity_penalty: float,
|
||||
remove_invalid_values: bool,
|
||||
) -> LogitsProcessorList:
|
||||
"""
|
||||
This class returns a :obj:`~transformers.LogitsProcessorList` list object that contains all relevant
|
||||
@@ -607,6 +609,9 @@ class GenerationMixin:
|
||||
forced_eos_token_id = (
|
||||
forced_eos_token_id if forced_eos_token_id is not None else self.config.forced_eos_token_id
|
||||
)
|
||||
remove_invalid_values = (
|
||||
remove_invalid_values if remove_invalid_values is not None else self.config.remove_invalid_values
|
||||
)
|
||||
# instantiate processors list
|
||||
processors = LogitsProcessorList()
|
||||
|
||||
@@ -639,6 +644,8 @@ class GenerationMixin:
|
||||
processors.append(ForcedBOSTokenLogitsProcessor(forced_bos_token_id))
|
||||
if forced_eos_token_id is not None:
|
||||
processors.append(ForcedEOSTokenLogitsProcessor(max_length, forced_eos_token_id))
|
||||
if remove_invalid_values is True:
|
||||
processors.append(InfNanRemoveLogitsProcessor())
|
||||
return processors
|
||||
|
||||
def _get_stopping_criteria(
|
||||
@@ -687,6 +694,7 @@ class GenerationMixin:
|
||||
return_dict_in_generate: Optional[bool] = None,
|
||||
forced_bos_token_id: Optional[int] = None,
|
||||
forced_eos_token_id: Optional[int] = None,
|
||||
remove_invalid_values: Optional[bool] = None,
|
||||
**model_kwargs,
|
||||
) -> Union[GreedySearchOutput, SampleOutput, BeamSearchOutput, BeamSampleOutput, torch.LongTensor]:
|
||||
r"""
|
||||
@@ -789,6 +797,9 @@ class GenerationMixin:
|
||||
needs to be the target language token.
|
||||
forced_eos_token_id (:obj:`int`, `optional`):
|
||||
The id of the token to force as the last generated token when :obj:`max_length` is reached.
|
||||
remove_invalid_values (:obj:`bool`, `optional`):
|
||||
Whether to remove possible `nan` and `inf` outputs of the model to prevent the generation method to
|
||||
crash. Note that using ``remove_invalid_values`` can slow down generation.
|
||||
|
||||
model_kwargs:
|
||||
Additional model specific kwargs will be forwarded to the :obj:`forward` function of the model. If the
|
||||
@@ -965,6 +976,7 @@ class GenerationMixin:
|
||||
num_beams=num_beams,
|
||||
num_beam_groups=num_beam_groups,
|
||||
diversity_penalty=diversity_penalty,
|
||||
remove_invalid_values=remove_invalid_values,
|
||||
)
|
||||
|
||||
stopping_criteria = self._get_stopping_criteria(
|
||||
@@ -1511,6 +1523,7 @@ class GenerationMixin:
|
||||
|
||||
# sample
|
||||
probs = F.softmax(next_token_scores, dim=-1)
|
||||
|
||||
next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1)
|
||||
|
||||
# add code that transfomers next_tokens to tokens_to_add
|
||||
@@ -2026,6 +2039,7 @@ class GenerationMixin:
|
||||
next_token_scores = next_token_scores.view(batch_size, num_beams * vocab_size)
|
||||
|
||||
probs = F.softmax(next_token_scores, dim=-1)
|
||||
|
||||
next_tokens = torch.multinomial(probs, num_samples=2 * num_beams)
|
||||
next_token_scores = torch.gather(next_token_scores, -1, next_tokens)
|
||||
|
||||
|
||||
@@ -1316,6 +1316,7 @@ class RagTokenForGeneration(RagPreTrainedModel):
|
||||
prefix_allowed_tokens_fn: Callable[[int, torch.Tensor], List[int]] = None,
|
||||
forced_bos_token_id: Optional[int] = None,
|
||||
forced_eos_token_id: Optional[int] = None,
|
||||
remove_invalid_values: Optional[bool] = None,
|
||||
**model_kwargs
|
||||
):
|
||||
"""
|
||||
@@ -1412,6 +1413,9 @@ class RagTokenForGeneration(RagPreTrainedModel):
|
||||
needs to be the target language token.
|
||||
forced_eos_token_id (:obj:`int`, `optional`):
|
||||
The id of the token to force as the last generated token when :obj:`max_length` is reached.
|
||||
remove_invalid_values (:obj:`bool`, `optional`):
|
||||
Whether to remove possible `nan` and `inf` outputs of the model to prevent the generation method to
|
||||
crash. Note that using ``remove_invalid_values`` can slow down generation.
|
||||
|
||||
Return:
|
||||
:obj:`torch.LongTensor` of shape :obj:`(batch_size * num_return_sequences, sequence_length)`: The generated
|
||||
@@ -1435,6 +1439,9 @@ class RagTokenForGeneration(RagPreTrainedModel):
|
||||
if decoder_start_token_id is not None
|
||||
else self.config.generator.decoder_start_token_id
|
||||
)
|
||||
remove_invalid_values = (
|
||||
remove_invalid_values if remove_invalid_values is not None else self.config.remove_invalid_values
|
||||
)
|
||||
|
||||
# retrieve docs
|
||||
if self.retriever is not None and context_input_ids is None:
|
||||
@@ -1515,6 +1522,7 @@ class RagTokenForGeneration(RagPreTrainedModel):
|
||||
num_beams=num_beams,
|
||||
num_beam_groups=num_beam_groups,
|
||||
diversity_penalty=diversity_penalty,
|
||||
remove_invalid_values=remove_invalid_values,
|
||||
)
|
||||
|
||||
if num_beams == 1:
|
||||
|
||||
@@ -123,11 +123,26 @@ class BeamSearchScorer:
|
||||
requires_pytorch(self)
|
||||
|
||||
|
||||
class ForcedBOSTokenLogitsProcessor:
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_pytorch(self)
|
||||
|
||||
|
||||
class ForcedEOSTokenLogitsProcessor:
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_pytorch(self)
|
||||
|
||||
|
||||
class HammingDiversityLogitsProcessor:
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_pytorch(self)
|
||||
|
||||
|
||||
class InfNanRemoveLogitsProcessor:
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_pytorch(self)
|
||||
|
||||
|
||||
class LogitsProcessor:
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_pytorch(self)
|
||||
|
||||
@@ -31,6 +31,7 @@ if is_torch_available():
|
||||
ForcedBOSTokenLogitsProcessor,
|
||||
ForcedEOSTokenLogitsProcessor,
|
||||
HammingDiversityLogitsProcessor,
|
||||
InfNanRemoveLogitsProcessor,
|
||||
LogitsProcessorList,
|
||||
MinLengthLogitsProcessor,
|
||||
NoBadWordsLogitsProcessor,
|
||||
@@ -436,3 +437,24 @@ class LogitsProcessorTest(unittest.TestCase):
|
||||
scores = self._get_uniform_logits(batch_size, vocab_size)
|
||||
scores = logits_processor(input_ids, scores)
|
||||
self.assertFalse(torch.isinf(scores).any())
|
||||
|
||||
def test_remove_nan_inf_logits_processor(self):
|
||||
scores = torch.tensor(
|
||||
[[0.0, 0.7, 0.8, float("nan")], [0.1, float("inf"), 0.3, float("-inf")]], device=torch_device
|
||||
)
|
||||
input_ids = ids_tensor((2, 4), vocab_size=20)
|
||||
|
||||
logits_processor = InfNanRemoveLogitsProcessor()
|
||||
|
||||
scores = logits_processor(input_ids, scores)
|
||||
|
||||
self.assertTrue(
|
||||
torch.allclose(
|
||||
scores,
|
||||
torch.tensor(
|
||||
[[0.0, 0.7, 0.8, 0.0], [0.1, torch.finfo(scores.dtype).max, 0.3, float("-inf")]],
|
||||
device=torch_device,
|
||||
),
|
||||
atol=1e-6,
|
||||
)
|
||||
)
|
||||
|
||||
@@ -29,6 +29,7 @@ if is_torch_available():
|
||||
ForcedBOSTokenLogitsProcessor,
|
||||
ForcedEOSTokenLogitsProcessor,
|
||||
HammingDiversityLogitsProcessor,
|
||||
InfNanRemoveLogitsProcessor,
|
||||
LogitsProcessorList,
|
||||
MinLengthLogitsProcessor,
|
||||
NoBadWordsLogitsProcessor,
|
||||
@@ -229,6 +230,7 @@ class GenerationTesterMixin:
|
||||
output_hidden_states=output_hidden_states,
|
||||
output_scores=output_scores,
|
||||
return_dict_in_generate=return_dict_in_generate,
|
||||
remove_invalid_values=True,
|
||||
**logits_process_kwargs,
|
||||
)
|
||||
|
||||
@@ -284,6 +286,7 @@ class GenerationTesterMixin:
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict_in_generate=return_dict_in_generate,
|
||||
remove_invalid_values=True,
|
||||
**logits_warper_kwargs,
|
||||
**process_kwargs,
|
||||
)
|
||||
@@ -305,6 +308,10 @@ class GenerationTesterMixin:
|
||||
attention_mask_clone = attention_mask.repeat_interleave(num_return_sequences, dim=0)
|
||||
input_ids_clone = input_ids.repeat_interleave(num_return_sequences, dim=0)
|
||||
|
||||
# prevent flaky generation test failures
|
||||
logits_processor.append(InfNanRemoveLogitsProcessor())
|
||||
|
||||
with torch.no_grad():
|
||||
with torch.no_grad():
|
||||
output_sample = model.sample(
|
||||
input_ids_clone,
|
||||
@@ -344,6 +351,7 @@ class GenerationTesterMixin:
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict_in_generate=return_dict_in_generate,
|
||||
remove_invalid_values=True,
|
||||
**beam_kwargs,
|
||||
**logits_process_kwargs,
|
||||
)
|
||||
@@ -406,6 +414,7 @@ class GenerationTesterMixin:
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict_in_generate=return_dict_in_generate,
|
||||
remove_invalid_values=True,
|
||||
**beam_kwargs,
|
||||
**logits_warper_kwargs,
|
||||
)
|
||||
@@ -424,6 +433,10 @@ class GenerationTesterMixin:
|
||||
else:
|
||||
attention_mask = attention_mask.repeat_interleave(beam_scorer.num_beams * num_return_sequences, dim=0)
|
||||
|
||||
# prevent flaky generation test failures
|
||||
logits_processor = LogitsProcessorList()
|
||||
logits_processor.append(InfNanRemoveLogitsProcessor())
|
||||
|
||||
torch.manual_seed(0)
|
||||
with torch.no_grad():
|
||||
output_beam_sample = model.beam_sample(
|
||||
@@ -432,6 +445,7 @@ class GenerationTesterMixin:
|
||||
max_length=max_length,
|
||||
attention_mask=attention_mask,
|
||||
logits_warper=logits_warper,
|
||||
logits_processor=logits_processor,
|
||||
output_scores=output_scores,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
@@ -465,6 +479,7 @@ class GenerationTesterMixin:
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict_in_generate=return_dict_in_generate,
|
||||
remove_invalid_values=True,
|
||||
**beam_kwargs,
|
||||
**logits_process_kwargs,
|
||||
)
|
||||
@@ -936,6 +951,7 @@ class GenerationTesterMixin:
|
||||
output_ids_generate = model.generate(
|
||||
do_sample=False,
|
||||
max_length=max_length,
|
||||
remove_invalid_values=True,
|
||||
)
|
||||
|
||||
self.assertIsNotNone(output_ids_generate)
|
||||
@@ -1309,7 +1325,12 @@ class GenerationIntegrationTests(unittest.TestCase):
|
||||
input_ids = bart_tokenizer(article, return_tensors="pt").input_ids.to(torch_device)
|
||||
|
||||
outputs = bart_model.generate(
|
||||
input_ids, num_beams=4, num_return_sequences=2, num_beam_groups=4, diversity_penalty=2.0
|
||||
input_ids,
|
||||
num_beams=4,
|
||||
num_return_sequences=2,
|
||||
num_beam_groups=4,
|
||||
diversity_penalty=2.0,
|
||||
remove_invalid_values=True,
|
||||
)
|
||||
|
||||
generated_text = bart_tokenizer.batch_decode(outputs, skip_special_tokens=True)
|
||||
@@ -1359,6 +1380,7 @@ class GenerationIntegrationTests(unittest.TestCase):
|
||||
decoder_start_token_id=bart_model.config.decoder_start_token_id,
|
||||
bos_token_id=bart_model.config.bos_token_id,
|
||||
)
|
||||
with torch.no_grad():
|
||||
bart_model.sample(
|
||||
input_ids,
|
||||
max_length=max_length,
|
||||
@@ -1463,6 +1485,7 @@ class GenerationIntegrationTests(unittest.TestCase):
|
||||
|
||||
# Sample
|
||||
with self.assertWarns(UserWarning):
|
||||
with torch.no_grad():
|
||||
bart_model.sample(
|
||||
input_ids,
|
||||
max_length=max_length,
|
||||
@@ -1480,6 +1503,7 @@ class GenerationIntegrationTests(unittest.TestCase):
|
||||
device=torch_device,
|
||||
)
|
||||
with self.assertWarns(UserWarning):
|
||||
with torch.no_grad():
|
||||
bart_model.beam_search(
|
||||
input_ids,
|
||||
num_beams=num_beams,
|
||||
|
||||
Reference in New Issue
Block a user