[Generate] Add save mode logits processor to remove nans and infs if necessary (#10769)
* push * finish * finish * make fix copies * change name
This commit is contained in:
committed by
GitHub
parent
9f8fa4e973
commit
77bf3fe787
@@ -151,6 +151,16 @@ generation.
|
|||||||
.. autoclass:: transformers.HammingDiversityLogitsProcessor
|
.. autoclass:: transformers.HammingDiversityLogitsProcessor
|
||||||
:members: __call__
|
:members: __call__
|
||||||
|
|
||||||
|
.. autoclass:: transformers.ForcedBOSTokenLogitsProcessor
|
||||||
|
:members: __call__
|
||||||
|
|
||||||
|
.. autoclass:: transformers.ForcedEOSTokenLogitsProcessor
|
||||||
|
:members: __call__
|
||||||
|
|
||||||
|
.. autoclass:: transformers.InfNanRemoveLogitsProcessor
|
||||||
|
:members: __call__
|
||||||
|
|
||||||
|
|
||||||
StoppingCriteria
|
StoppingCriteria
|
||||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||||
|
|
||||||
|
|||||||
@@ -369,7 +369,10 @@ if is_torch_available():
|
|||||||
]
|
]
|
||||||
_import_structure["generation_beam_search"] = ["BeamScorer", "BeamSearchScorer"]
|
_import_structure["generation_beam_search"] = ["BeamScorer", "BeamSearchScorer"]
|
||||||
_import_structure["generation_logits_process"] = [
|
_import_structure["generation_logits_process"] = [
|
||||||
|
"ForcedBOSTokenLogitsProcessor",
|
||||||
|
"ForcedEOSTokenLogitsProcessor",
|
||||||
"HammingDiversityLogitsProcessor",
|
"HammingDiversityLogitsProcessor",
|
||||||
|
"InfNanRemoveLogitsProcessor",
|
||||||
"LogitsProcessor",
|
"LogitsProcessor",
|
||||||
"LogitsProcessorList",
|
"LogitsProcessorList",
|
||||||
"LogitsWarper",
|
"LogitsWarper",
|
||||||
@@ -1560,7 +1563,10 @@ if TYPE_CHECKING:
|
|||||||
)
|
)
|
||||||
from .generation_beam_search import BeamScorer, BeamSearchScorer
|
from .generation_beam_search import BeamScorer, BeamSearchScorer
|
||||||
from .generation_logits_process import (
|
from .generation_logits_process import (
|
||||||
|
ForcedBOSTokenLogitsProcessor,
|
||||||
|
ForcedEOSTokenLogitsProcessor,
|
||||||
HammingDiversityLogitsProcessor,
|
HammingDiversityLogitsProcessor,
|
||||||
|
InfNanRemoveLogitsProcessor,
|
||||||
LogitsProcessor,
|
LogitsProcessor,
|
||||||
LogitsProcessorList,
|
LogitsProcessorList,
|
||||||
LogitsWarper,
|
LogitsWarper,
|
||||||
|
|||||||
@@ -134,6 +134,9 @@ class PretrainedConfig(object):
|
|||||||
<../model_doc/mbart>` where the first generated token needs to be the target language token.
|
<../model_doc/mbart>` where the first generated token needs to be the target language token.
|
||||||
- **forced_eos_token_id** (:obj:`int`, `optional`) -- The id of the token to force as the last generated token
|
- **forced_eos_token_id** (:obj:`int`, `optional`) -- The id of the token to force as the last generated token
|
||||||
when :obj:`max_length` is reached.
|
when :obj:`max_length` is reached.
|
||||||
|
- **remove_invalid_values** (:obj:`bool`, `optional`) -- Whether to remove possible `nan` and `inf` outputs of
|
||||||
|
the model to prevent the generation method to crash. Note that using ``remove_invalid_values`` can slow down
|
||||||
|
generation.
|
||||||
|
|
||||||
|
|
||||||
Parameters for fine-tuning tasks
|
Parameters for fine-tuning tasks
|
||||||
@@ -219,6 +222,7 @@ class PretrainedConfig(object):
|
|||||||
self.return_dict_in_generate = kwargs.pop("return_dict_in_generate", False)
|
self.return_dict_in_generate = kwargs.pop("return_dict_in_generate", False)
|
||||||
self.forced_bos_token_id = kwargs.pop("forced_bos_token_id", None)
|
self.forced_bos_token_id = kwargs.pop("forced_bos_token_id", None)
|
||||||
self.forced_eos_token_id = kwargs.pop("forced_eos_token_id", None)
|
self.forced_eos_token_id = kwargs.pop("forced_eos_token_id", None)
|
||||||
|
self.remove_invalid_values = kwargs.pop("remove_invalid_values", False)
|
||||||
|
|
||||||
# Fine-tuning task arguments
|
# Fine-tuning task arguments
|
||||||
self.architectures = kwargs.pop("architectures", None)
|
self.architectures = kwargs.pop("architectures", None)
|
||||||
|
|||||||
@@ -566,3 +566,20 @@ class ForcedEOSTokenLogitsProcessor(LogitsProcessor):
|
|||||||
scores[:, [i for i in range(num_tokens) if i != self.eos_token_id]] = -float("inf")
|
scores[:, [i for i in range(num_tokens) if i != self.eos_token_id]] = -float("inf")
|
||||||
scores[:, self.eos_token_id] = 0
|
scores[:, self.eos_token_id] = 0
|
||||||
return scores
|
return scores
|
||||||
|
|
||||||
|
|
||||||
|
class InfNanRemoveLogitsProcessor(LogitsProcessor):
|
||||||
|
r"""
|
||||||
|
:class:`~transformers.LogitsProcessor` that removes all :obj:`nan` and :obj:`inf` values to avoid the generation
|
||||||
|
method to fail. Note that using the logits processor should only be used if necessary since it can slow down the
|
||||||
|
generation method. :obj:`max_length` is reached.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
|
||||||
|
# set all nan values to 0.0
|
||||||
|
scores[scores != scores] = 0.0
|
||||||
|
|
||||||
|
# set all inf values to max possible value
|
||||||
|
scores[scores == float("inf")] = torch.finfo(scores.dtype).max
|
||||||
|
|
||||||
|
return scores
|
||||||
|
|||||||
@@ -27,6 +27,7 @@ from .generation_logits_process import (
|
|||||||
ForcedBOSTokenLogitsProcessor,
|
ForcedBOSTokenLogitsProcessor,
|
||||||
ForcedEOSTokenLogitsProcessor,
|
ForcedEOSTokenLogitsProcessor,
|
||||||
HammingDiversityLogitsProcessor,
|
HammingDiversityLogitsProcessor,
|
||||||
|
InfNanRemoveLogitsProcessor,
|
||||||
LogitsProcessorList,
|
LogitsProcessorList,
|
||||||
MinLengthLogitsProcessor,
|
MinLengthLogitsProcessor,
|
||||||
NoBadWordsLogitsProcessor,
|
NoBadWordsLogitsProcessor,
|
||||||
@@ -581,6 +582,7 @@ class GenerationMixin:
|
|||||||
num_beams: int,
|
num_beams: int,
|
||||||
num_beam_groups: int,
|
num_beam_groups: int,
|
||||||
diversity_penalty: float,
|
diversity_penalty: float,
|
||||||
|
remove_invalid_values: bool,
|
||||||
) -> LogitsProcessorList:
|
) -> LogitsProcessorList:
|
||||||
"""
|
"""
|
||||||
This class returns a :obj:`~transformers.LogitsProcessorList` list object that contains all relevant
|
This class returns a :obj:`~transformers.LogitsProcessorList` list object that contains all relevant
|
||||||
@@ -607,6 +609,9 @@ class GenerationMixin:
|
|||||||
forced_eos_token_id = (
|
forced_eos_token_id = (
|
||||||
forced_eos_token_id if forced_eos_token_id is not None else self.config.forced_eos_token_id
|
forced_eos_token_id if forced_eos_token_id is not None else self.config.forced_eos_token_id
|
||||||
)
|
)
|
||||||
|
remove_invalid_values = (
|
||||||
|
remove_invalid_values if remove_invalid_values is not None else self.config.remove_invalid_values
|
||||||
|
)
|
||||||
# instantiate processors list
|
# instantiate processors list
|
||||||
processors = LogitsProcessorList()
|
processors = LogitsProcessorList()
|
||||||
|
|
||||||
@@ -639,6 +644,8 @@ class GenerationMixin:
|
|||||||
processors.append(ForcedBOSTokenLogitsProcessor(forced_bos_token_id))
|
processors.append(ForcedBOSTokenLogitsProcessor(forced_bos_token_id))
|
||||||
if forced_eos_token_id is not None:
|
if forced_eos_token_id is not None:
|
||||||
processors.append(ForcedEOSTokenLogitsProcessor(max_length, forced_eos_token_id))
|
processors.append(ForcedEOSTokenLogitsProcessor(max_length, forced_eos_token_id))
|
||||||
|
if remove_invalid_values is True:
|
||||||
|
processors.append(InfNanRemoveLogitsProcessor())
|
||||||
return processors
|
return processors
|
||||||
|
|
||||||
def _get_stopping_criteria(
|
def _get_stopping_criteria(
|
||||||
@@ -687,6 +694,7 @@ class GenerationMixin:
|
|||||||
return_dict_in_generate: Optional[bool] = None,
|
return_dict_in_generate: Optional[bool] = None,
|
||||||
forced_bos_token_id: Optional[int] = None,
|
forced_bos_token_id: Optional[int] = None,
|
||||||
forced_eos_token_id: Optional[int] = None,
|
forced_eos_token_id: Optional[int] = None,
|
||||||
|
remove_invalid_values: Optional[bool] = None,
|
||||||
**model_kwargs,
|
**model_kwargs,
|
||||||
) -> Union[GreedySearchOutput, SampleOutput, BeamSearchOutput, BeamSampleOutput, torch.LongTensor]:
|
) -> Union[GreedySearchOutput, SampleOutput, BeamSearchOutput, BeamSampleOutput, torch.LongTensor]:
|
||||||
r"""
|
r"""
|
||||||
@@ -789,6 +797,9 @@ class GenerationMixin:
|
|||||||
needs to be the target language token.
|
needs to be the target language token.
|
||||||
forced_eos_token_id (:obj:`int`, `optional`):
|
forced_eos_token_id (:obj:`int`, `optional`):
|
||||||
The id of the token to force as the last generated token when :obj:`max_length` is reached.
|
The id of the token to force as the last generated token when :obj:`max_length` is reached.
|
||||||
|
remove_invalid_values (:obj:`bool`, `optional`):
|
||||||
|
Whether to remove possible `nan` and `inf` outputs of the model to prevent the generation method to
|
||||||
|
crash. Note that using ``remove_invalid_values`` can slow down generation.
|
||||||
|
|
||||||
model_kwargs:
|
model_kwargs:
|
||||||
Additional model specific kwargs will be forwarded to the :obj:`forward` function of the model. If the
|
Additional model specific kwargs will be forwarded to the :obj:`forward` function of the model. If the
|
||||||
@@ -965,6 +976,7 @@ class GenerationMixin:
|
|||||||
num_beams=num_beams,
|
num_beams=num_beams,
|
||||||
num_beam_groups=num_beam_groups,
|
num_beam_groups=num_beam_groups,
|
||||||
diversity_penalty=diversity_penalty,
|
diversity_penalty=diversity_penalty,
|
||||||
|
remove_invalid_values=remove_invalid_values,
|
||||||
)
|
)
|
||||||
|
|
||||||
stopping_criteria = self._get_stopping_criteria(
|
stopping_criteria = self._get_stopping_criteria(
|
||||||
@@ -1511,6 +1523,7 @@ class GenerationMixin:
|
|||||||
|
|
||||||
# sample
|
# sample
|
||||||
probs = F.softmax(next_token_scores, dim=-1)
|
probs = F.softmax(next_token_scores, dim=-1)
|
||||||
|
|
||||||
next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1)
|
next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1)
|
||||||
|
|
||||||
# add code that transfomers next_tokens to tokens_to_add
|
# add code that transfomers next_tokens to tokens_to_add
|
||||||
@@ -2026,6 +2039,7 @@ class GenerationMixin:
|
|||||||
next_token_scores = next_token_scores.view(batch_size, num_beams * vocab_size)
|
next_token_scores = next_token_scores.view(batch_size, num_beams * vocab_size)
|
||||||
|
|
||||||
probs = F.softmax(next_token_scores, dim=-1)
|
probs = F.softmax(next_token_scores, dim=-1)
|
||||||
|
|
||||||
next_tokens = torch.multinomial(probs, num_samples=2 * num_beams)
|
next_tokens = torch.multinomial(probs, num_samples=2 * num_beams)
|
||||||
next_token_scores = torch.gather(next_token_scores, -1, next_tokens)
|
next_token_scores = torch.gather(next_token_scores, -1, next_tokens)
|
||||||
|
|
||||||
|
|||||||
@@ -1316,6 +1316,7 @@ class RagTokenForGeneration(RagPreTrainedModel):
|
|||||||
prefix_allowed_tokens_fn: Callable[[int, torch.Tensor], List[int]] = None,
|
prefix_allowed_tokens_fn: Callable[[int, torch.Tensor], List[int]] = None,
|
||||||
forced_bos_token_id: Optional[int] = None,
|
forced_bos_token_id: Optional[int] = None,
|
||||||
forced_eos_token_id: Optional[int] = None,
|
forced_eos_token_id: Optional[int] = None,
|
||||||
|
remove_invalid_values: Optional[bool] = None,
|
||||||
**model_kwargs
|
**model_kwargs
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
@@ -1412,6 +1413,9 @@ class RagTokenForGeneration(RagPreTrainedModel):
|
|||||||
needs to be the target language token.
|
needs to be the target language token.
|
||||||
forced_eos_token_id (:obj:`int`, `optional`):
|
forced_eos_token_id (:obj:`int`, `optional`):
|
||||||
The id of the token to force as the last generated token when :obj:`max_length` is reached.
|
The id of the token to force as the last generated token when :obj:`max_length` is reached.
|
||||||
|
remove_invalid_values (:obj:`bool`, `optional`):
|
||||||
|
Whether to remove possible `nan` and `inf` outputs of the model to prevent the generation method to
|
||||||
|
crash. Note that using ``remove_invalid_values`` can slow down generation.
|
||||||
|
|
||||||
Return:
|
Return:
|
||||||
:obj:`torch.LongTensor` of shape :obj:`(batch_size * num_return_sequences, sequence_length)`: The generated
|
:obj:`torch.LongTensor` of shape :obj:`(batch_size * num_return_sequences, sequence_length)`: The generated
|
||||||
@@ -1435,6 +1439,9 @@ class RagTokenForGeneration(RagPreTrainedModel):
|
|||||||
if decoder_start_token_id is not None
|
if decoder_start_token_id is not None
|
||||||
else self.config.generator.decoder_start_token_id
|
else self.config.generator.decoder_start_token_id
|
||||||
)
|
)
|
||||||
|
remove_invalid_values = (
|
||||||
|
remove_invalid_values if remove_invalid_values is not None else self.config.remove_invalid_values
|
||||||
|
)
|
||||||
|
|
||||||
# retrieve docs
|
# retrieve docs
|
||||||
if self.retriever is not None and context_input_ids is None:
|
if self.retriever is not None and context_input_ids is None:
|
||||||
@@ -1515,6 +1522,7 @@ class RagTokenForGeneration(RagPreTrainedModel):
|
|||||||
num_beams=num_beams,
|
num_beams=num_beams,
|
||||||
num_beam_groups=num_beam_groups,
|
num_beam_groups=num_beam_groups,
|
||||||
diversity_penalty=diversity_penalty,
|
diversity_penalty=diversity_penalty,
|
||||||
|
remove_invalid_values=remove_invalid_values,
|
||||||
)
|
)
|
||||||
|
|
||||||
if num_beams == 1:
|
if num_beams == 1:
|
||||||
|
|||||||
@@ -123,11 +123,26 @@ class BeamSearchScorer:
|
|||||||
requires_pytorch(self)
|
requires_pytorch(self)
|
||||||
|
|
||||||
|
|
||||||
|
class ForcedBOSTokenLogitsProcessor:
|
||||||
|
def __init__(self, *args, **kwargs):
|
||||||
|
requires_pytorch(self)
|
||||||
|
|
||||||
|
|
||||||
|
class ForcedEOSTokenLogitsProcessor:
|
||||||
|
def __init__(self, *args, **kwargs):
|
||||||
|
requires_pytorch(self)
|
||||||
|
|
||||||
|
|
||||||
class HammingDiversityLogitsProcessor:
|
class HammingDiversityLogitsProcessor:
|
||||||
def __init__(self, *args, **kwargs):
|
def __init__(self, *args, **kwargs):
|
||||||
requires_pytorch(self)
|
requires_pytorch(self)
|
||||||
|
|
||||||
|
|
||||||
|
class InfNanRemoveLogitsProcessor:
|
||||||
|
def __init__(self, *args, **kwargs):
|
||||||
|
requires_pytorch(self)
|
||||||
|
|
||||||
|
|
||||||
class LogitsProcessor:
|
class LogitsProcessor:
|
||||||
def __init__(self, *args, **kwargs):
|
def __init__(self, *args, **kwargs):
|
||||||
requires_pytorch(self)
|
requires_pytorch(self)
|
||||||
|
|||||||
@@ -31,6 +31,7 @@ if is_torch_available():
|
|||||||
ForcedBOSTokenLogitsProcessor,
|
ForcedBOSTokenLogitsProcessor,
|
||||||
ForcedEOSTokenLogitsProcessor,
|
ForcedEOSTokenLogitsProcessor,
|
||||||
HammingDiversityLogitsProcessor,
|
HammingDiversityLogitsProcessor,
|
||||||
|
InfNanRemoveLogitsProcessor,
|
||||||
LogitsProcessorList,
|
LogitsProcessorList,
|
||||||
MinLengthLogitsProcessor,
|
MinLengthLogitsProcessor,
|
||||||
NoBadWordsLogitsProcessor,
|
NoBadWordsLogitsProcessor,
|
||||||
@@ -436,3 +437,24 @@ class LogitsProcessorTest(unittest.TestCase):
|
|||||||
scores = self._get_uniform_logits(batch_size, vocab_size)
|
scores = self._get_uniform_logits(batch_size, vocab_size)
|
||||||
scores = logits_processor(input_ids, scores)
|
scores = logits_processor(input_ids, scores)
|
||||||
self.assertFalse(torch.isinf(scores).any())
|
self.assertFalse(torch.isinf(scores).any())
|
||||||
|
|
||||||
|
def test_remove_nan_inf_logits_processor(self):
|
||||||
|
scores = torch.tensor(
|
||||||
|
[[0.0, 0.7, 0.8, float("nan")], [0.1, float("inf"), 0.3, float("-inf")]], device=torch_device
|
||||||
|
)
|
||||||
|
input_ids = ids_tensor((2, 4), vocab_size=20)
|
||||||
|
|
||||||
|
logits_processor = InfNanRemoveLogitsProcessor()
|
||||||
|
|
||||||
|
scores = logits_processor(input_ids, scores)
|
||||||
|
|
||||||
|
self.assertTrue(
|
||||||
|
torch.allclose(
|
||||||
|
scores,
|
||||||
|
torch.tensor(
|
||||||
|
[[0.0, 0.7, 0.8, 0.0], [0.1, torch.finfo(scores.dtype).max, 0.3, float("-inf")]],
|
||||||
|
device=torch_device,
|
||||||
|
),
|
||||||
|
atol=1e-6,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|||||||
@@ -29,6 +29,7 @@ if is_torch_available():
|
|||||||
ForcedBOSTokenLogitsProcessor,
|
ForcedBOSTokenLogitsProcessor,
|
||||||
ForcedEOSTokenLogitsProcessor,
|
ForcedEOSTokenLogitsProcessor,
|
||||||
HammingDiversityLogitsProcessor,
|
HammingDiversityLogitsProcessor,
|
||||||
|
InfNanRemoveLogitsProcessor,
|
||||||
LogitsProcessorList,
|
LogitsProcessorList,
|
||||||
MinLengthLogitsProcessor,
|
MinLengthLogitsProcessor,
|
||||||
NoBadWordsLogitsProcessor,
|
NoBadWordsLogitsProcessor,
|
||||||
@@ -229,6 +230,7 @@ class GenerationTesterMixin:
|
|||||||
output_hidden_states=output_hidden_states,
|
output_hidden_states=output_hidden_states,
|
||||||
output_scores=output_scores,
|
output_scores=output_scores,
|
||||||
return_dict_in_generate=return_dict_in_generate,
|
return_dict_in_generate=return_dict_in_generate,
|
||||||
|
remove_invalid_values=True,
|
||||||
**logits_process_kwargs,
|
**logits_process_kwargs,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -284,6 +286,7 @@ class GenerationTesterMixin:
|
|||||||
output_attentions=output_attentions,
|
output_attentions=output_attentions,
|
||||||
output_hidden_states=output_hidden_states,
|
output_hidden_states=output_hidden_states,
|
||||||
return_dict_in_generate=return_dict_in_generate,
|
return_dict_in_generate=return_dict_in_generate,
|
||||||
|
remove_invalid_values=True,
|
||||||
**logits_warper_kwargs,
|
**logits_warper_kwargs,
|
||||||
**process_kwargs,
|
**process_kwargs,
|
||||||
)
|
)
|
||||||
@@ -305,19 +308,23 @@ class GenerationTesterMixin:
|
|||||||
attention_mask_clone = attention_mask.repeat_interleave(num_return_sequences, dim=0)
|
attention_mask_clone = attention_mask.repeat_interleave(num_return_sequences, dim=0)
|
||||||
input_ids_clone = input_ids.repeat_interleave(num_return_sequences, dim=0)
|
input_ids_clone = input_ids.repeat_interleave(num_return_sequences, dim=0)
|
||||||
|
|
||||||
|
# prevent flaky generation test failures
|
||||||
|
logits_processor.append(InfNanRemoveLogitsProcessor())
|
||||||
|
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
output_sample = model.sample(
|
with torch.no_grad():
|
||||||
input_ids_clone,
|
output_sample = model.sample(
|
||||||
attention_mask=attention_mask_clone,
|
input_ids_clone,
|
||||||
max_length=max_length,
|
attention_mask=attention_mask_clone,
|
||||||
logits_processor=logits_processor,
|
max_length=max_length,
|
||||||
logits_warper=logits_warper,
|
logits_processor=logits_processor,
|
||||||
output_scores=output_scores,
|
logits_warper=logits_warper,
|
||||||
output_attentions=output_attentions,
|
output_scores=output_scores,
|
||||||
output_hidden_states=output_hidden_states,
|
output_attentions=output_attentions,
|
||||||
return_dict_in_generate=return_dict_in_generate,
|
output_hidden_states=output_hidden_states,
|
||||||
**kwargs,
|
return_dict_in_generate=return_dict_in_generate,
|
||||||
)
|
**kwargs,
|
||||||
|
)
|
||||||
return output_sample, output_generate
|
return output_sample, output_generate
|
||||||
|
|
||||||
def _beam_search_generate(
|
def _beam_search_generate(
|
||||||
@@ -344,6 +351,7 @@ class GenerationTesterMixin:
|
|||||||
output_attentions=output_attentions,
|
output_attentions=output_attentions,
|
||||||
output_hidden_states=output_hidden_states,
|
output_hidden_states=output_hidden_states,
|
||||||
return_dict_in_generate=return_dict_in_generate,
|
return_dict_in_generate=return_dict_in_generate,
|
||||||
|
remove_invalid_values=True,
|
||||||
**beam_kwargs,
|
**beam_kwargs,
|
||||||
**logits_process_kwargs,
|
**logits_process_kwargs,
|
||||||
)
|
)
|
||||||
@@ -406,6 +414,7 @@ class GenerationTesterMixin:
|
|||||||
output_attentions=output_attentions,
|
output_attentions=output_attentions,
|
||||||
output_hidden_states=output_hidden_states,
|
output_hidden_states=output_hidden_states,
|
||||||
return_dict_in_generate=return_dict_in_generate,
|
return_dict_in_generate=return_dict_in_generate,
|
||||||
|
remove_invalid_values=True,
|
||||||
**beam_kwargs,
|
**beam_kwargs,
|
||||||
**logits_warper_kwargs,
|
**logits_warper_kwargs,
|
||||||
)
|
)
|
||||||
@@ -424,6 +433,10 @@ class GenerationTesterMixin:
|
|||||||
else:
|
else:
|
||||||
attention_mask = attention_mask.repeat_interleave(beam_scorer.num_beams * num_return_sequences, dim=0)
|
attention_mask = attention_mask.repeat_interleave(beam_scorer.num_beams * num_return_sequences, dim=0)
|
||||||
|
|
||||||
|
# prevent flaky generation test failures
|
||||||
|
logits_processor = LogitsProcessorList()
|
||||||
|
logits_processor.append(InfNanRemoveLogitsProcessor())
|
||||||
|
|
||||||
torch.manual_seed(0)
|
torch.manual_seed(0)
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
output_beam_sample = model.beam_sample(
|
output_beam_sample = model.beam_sample(
|
||||||
@@ -432,6 +445,7 @@ class GenerationTesterMixin:
|
|||||||
max_length=max_length,
|
max_length=max_length,
|
||||||
attention_mask=attention_mask,
|
attention_mask=attention_mask,
|
||||||
logits_warper=logits_warper,
|
logits_warper=logits_warper,
|
||||||
|
logits_processor=logits_processor,
|
||||||
output_scores=output_scores,
|
output_scores=output_scores,
|
||||||
output_attentions=output_attentions,
|
output_attentions=output_attentions,
|
||||||
output_hidden_states=output_hidden_states,
|
output_hidden_states=output_hidden_states,
|
||||||
@@ -465,6 +479,7 @@ class GenerationTesterMixin:
|
|||||||
output_attentions=output_attentions,
|
output_attentions=output_attentions,
|
||||||
output_hidden_states=output_hidden_states,
|
output_hidden_states=output_hidden_states,
|
||||||
return_dict_in_generate=return_dict_in_generate,
|
return_dict_in_generate=return_dict_in_generate,
|
||||||
|
remove_invalid_values=True,
|
||||||
**beam_kwargs,
|
**beam_kwargs,
|
||||||
**logits_process_kwargs,
|
**logits_process_kwargs,
|
||||||
)
|
)
|
||||||
@@ -936,6 +951,7 @@ class GenerationTesterMixin:
|
|||||||
output_ids_generate = model.generate(
|
output_ids_generate = model.generate(
|
||||||
do_sample=False,
|
do_sample=False,
|
||||||
max_length=max_length,
|
max_length=max_length,
|
||||||
|
remove_invalid_values=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
self.assertIsNotNone(output_ids_generate)
|
self.assertIsNotNone(output_ids_generate)
|
||||||
@@ -1309,7 +1325,12 @@ class GenerationIntegrationTests(unittest.TestCase):
|
|||||||
input_ids = bart_tokenizer(article, return_tensors="pt").input_ids.to(torch_device)
|
input_ids = bart_tokenizer(article, return_tensors="pt").input_ids.to(torch_device)
|
||||||
|
|
||||||
outputs = bart_model.generate(
|
outputs = bart_model.generate(
|
||||||
input_ids, num_beams=4, num_return_sequences=2, num_beam_groups=4, diversity_penalty=2.0
|
input_ids,
|
||||||
|
num_beams=4,
|
||||||
|
num_return_sequences=2,
|
||||||
|
num_beam_groups=4,
|
||||||
|
diversity_penalty=2.0,
|
||||||
|
remove_invalid_values=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
generated_text = bart_tokenizer.batch_decode(outputs, skip_special_tokens=True)
|
generated_text = bart_tokenizer.batch_decode(outputs, skip_special_tokens=True)
|
||||||
@@ -1359,13 +1380,14 @@ class GenerationIntegrationTests(unittest.TestCase):
|
|||||||
decoder_start_token_id=bart_model.config.decoder_start_token_id,
|
decoder_start_token_id=bart_model.config.decoder_start_token_id,
|
||||||
bos_token_id=bart_model.config.bos_token_id,
|
bos_token_id=bart_model.config.bos_token_id,
|
||||||
)
|
)
|
||||||
bart_model.sample(
|
with torch.no_grad():
|
||||||
input_ids,
|
bart_model.sample(
|
||||||
max_length=max_length,
|
input_ids,
|
||||||
pad_token_id=bart_model.config.pad_token_id,
|
max_length=max_length,
|
||||||
eos_token_id=bart_model.config.eos_token_id,
|
pad_token_id=bart_model.config.pad_token_id,
|
||||||
**model_kwargs,
|
eos_token_id=bart_model.config.eos_token_id,
|
||||||
)
|
**model_kwargs,
|
||||||
|
)
|
||||||
|
|
||||||
def test_max_length_backward_compat_beam_search(self):
|
def test_max_length_backward_compat_beam_search(self):
|
||||||
article = """Justin Timberlake and Jessica Biel, welcome to parenthood."""
|
article = """Justin Timberlake and Jessica Biel, welcome to parenthood."""
|
||||||
@@ -1463,14 +1485,15 @@ class GenerationIntegrationTests(unittest.TestCase):
|
|||||||
|
|
||||||
# Sample
|
# Sample
|
||||||
with self.assertWarns(UserWarning):
|
with self.assertWarns(UserWarning):
|
||||||
bart_model.sample(
|
with torch.no_grad():
|
||||||
input_ids,
|
bart_model.sample(
|
||||||
max_length=max_length,
|
input_ids,
|
||||||
stopping_criteria=stopping_criteria,
|
max_length=max_length,
|
||||||
pad_token_id=bart_model.config.pad_token_id,
|
stopping_criteria=stopping_criteria,
|
||||||
eos_token_id=bart_model.config.eos_token_id,
|
pad_token_id=bart_model.config.pad_token_id,
|
||||||
**model_kwargs,
|
eos_token_id=bart_model.config.eos_token_id,
|
||||||
)
|
**model_kwargs,
|
||||||
|
)
|
||||||
|
|
||||||
# Beam
|
# Beam
|
||||||
beam_scorer = BeamSearchScorer(
|
beam_scorer = BeamSearchScorer(
|
||||||
@@ -1480,14 +1503,15 @@ class GenerationIntegrationTests(unittest.TestCase):
|
|||||||
device=torch_device,
|
device=torch_device,
|
||||||
)
|
)
|
||||||
with self.assertWarns(UserWarning):
|
with self.assertWarns(UserWarning):
|
||||||
bart_model.beam_search(
|
with torch.no_grad():
|
||||||
input_ids,
|
bart_model.beam_search(
|
||||||
num_beams=num_beams,
|
input_ids,
|
||||||
stopping_criteria=stopping_criteria,
|
num_beams=num_beams,
|
||||||
max_length=max_length,
|
stopping_criteria=stopping_criteria,
|
||||||
beam_scorer=beam_scorer,
|
max_length=max_length,
|
||||||
**model_kwargs,
|
beam_scorer=beam_scorer,
|
||||||
)
|
**model_kwargs,
|
||||||
|
)
|
||||||
|
|
||||||
# Grouped beam search
|
# Grouped beam search
|
||||||
diverse_beam_scorer = BeamSearchScorer(
|
diverse_beam_scorer = BeamSearchScorer(
|
||||||
|
|||||||
Reference in New Issue
Block a user