Removed max_length from being mandatory within generate. (#11314)
* Removed `max_length` from being mandatory within `generate`. - Moving on to fully using `StoppingCriteria` for `greedy` and `sample` modes. - `max_length` still used for `beam_search` and `group_beam_search` (Follow up PR) - Fixes a bug with MaxLengthStoppingCriteria (we should stop as soon a we hit the max_length, the comparison needs to be or equal, that affects the tests). - Added options to use `logits_processor` and `stopping_criteria` directly within `generate` function (so some users can define their own `logits_processor` and `stopping_criteria`). - Modified the backward compat tests to make sure we issue a warning. * Fix `max_length` argument in `generate`. * Moving validate to being functional. - Renamed `smax_length` to `stoppping_max_length`. * Removing `logits_processor` and `stopping_criteria` from `generate` arguments. * Deepcopy. * Fix global variable name.
This commit is contained in:
@@ -1,6 +1,7 @@
|
||||
import time
|
||||
import warnings
|
||||
from abc import ABC
|
||||
from copy import deepcopy
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
@@ -8,7 +9,7 @@ import torch
|
||||
from .file_utils import add_start_docstrings
|
||||
|
||||
|
||||
LOGITS_PROCESSOR_INPUTS_DOCSTRING = r"""
|
||||
STOPPING_CRITERIA_INPUTS_DOCSTRING = r"""
|
||||
Args:
|
||||
input_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`):
|
||||
Indices of input sequence tokens in the vocabulary.
|
||||
@@ -33,7 +34,7 @@ LOGITS_PROCESSOR_INPUTS_DOCSTRING = r"""
|
||||
class StoppingCriteria(ABC):
|
||||
"""Abstract base class for all stopping criteria that can be applied during generation."""
|
||||
|
||||
@add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING)
|
||||
@add_start_docstrings(STOPPING_CRITERIA_INPUTS_DOCSTRING)
|
||||
def __call__(self, input_ids: torch.LongTensor, score: torch.FloatTensor, **kwargs) -> bool:
|
||||
raise NotImplementedError("StoppingCriteria needs to be subclassed")
|
||||
|
||||
@@ -51,9 +52,9 @@ class MaxLengthCriteria(StoppingCriteria):
|
||||
def __init__(self, max_length: int):
|
||||
self.max_length = max_length
|
||||
|
||||
@add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING)
|
||||
@add_start_docstrings(STOPPING_CRITERIA_INPUTS_DOCSTRING)
|
||||
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
|
||||
return input_ids.shape[-1] > self.max_length
|
||||
return input_ids.shape[-1] >= self.max_length
|
||||
|
||||
|
||||
class MaxTimeCriteria(StoppingCriteria):
|
||||
@@ -73,25 +74,29 @@ class MaxTimeCriteria(StoppingCriteria):
|
||||
self.max_time = max_time
|
||||
self.initial_timestamp = time.time() if initial_timestamp is None else initial_timestamp
|
||||
|
||||
@add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING)
|
||||
@add_start_docstrings(STOPPING_CRITERIA_INPUTS_DOCSTRING)
|
||||
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
|
||||
return time.time() - self.initial_timestamp > self.max_time
|
||||
|
||||
|
||||
class StoppingCriteriaList(list):
|
||||
@add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING)
|
||||
@add_start_docstrings(STOPPING_CRITERIA_INPUTS_DOCSTRING)
|
||||
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
|
||||
return any(criteria(input_ids, scores) for criteria in self)
|
||||
|
||||
@property
|
||||
def max_length(self) -> Optional[int]:
|
||||
for stopping_criterium in self:
|
||||
if isinstance(stopping_criterium, MaxLengthCriteria):
|
||||
return stopping_criterium.max_length
|
||||
return None
|
||||
|
||||
def validate_stopping_criteria(stopping_criteria: StoppingCriteriaList, max_length: int):
|
||||
found = False
|
||||
for stopping_criterium in stopping_criteria:
|
||||
if isinstance(stopping_criterium, MaxLengthCriteria):
|
||||
found = True
|
||||
if stopping_criterium.max_length != max_length:
|
||||
warnings.warn(
|
||||
"You set different `max_length` for stopping criteria and `max_length` parameter", UserWarning
|
||||
)
|
||||
if not found:
|
||||
stopping_criteria.append(MaxLengthCriteria(max_length=max_length))
|
||||
|
||||
def validate_stopping_criteria(stopping_criteria: StoppingCriteriaList, max_length: int) -> StoppingCriteriaList:
|
||||
stopping_max_length = stopping_criteria.max_length
|
||||
new_stopping_criteria = deepcopy(stopping_criteria)
|
||||
if stopping_max_length is not None and stopping_max_length != max_length:
|
||||
warnings.warn("You set different `max_length` for stopping criteria and `max_length` parameter", UserWarning)
|
||||
elif stopping_max_length is None:
|
||||
new_stopping_criteria.append(MaxLengthCriteria(max_length=max_length))
|
||||
return new_stopping_criteria
|
||||
|
||||
@@ -14,6 +14,7 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import warnings
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union
|
||||
|
||||
@@ -564,6 +565,7 @@ class GenerationMixin:
|
||||
This class returns a :obj:`~transformers.LogitsProcessorList` list object that contains all relevant
|
||||
:obj:`~transformers.LogitsProcessor` instances used to modify the scores of the language model head.
|
||||
"""
|
||||
processors = LogitsProcessorList()
|
||||
|
||||
# init warp parameters
|
||||
repetition_penalty = repetition_penalty if repetition_penalty is not None else self.config.repetition_penalty
|
||||
@@ -589,7 +591,6 @@ class GenerationMixin:
|
||||
remove_invalid_values if remove_invalid_values is not None else self.config.remove_invalid_values
|
||||
)
|
||||
# instantiate processors list
|
||||
processors = LogitsProcessorList()
|
||||
|
||||
# the following idea is largely copied from this PR: https://github.com/huggingface/transformers/pull/5420/files
|
||||
# all samplers can be found in `generation_utils_samplers.py`
|
||||
@@ -629,7 +630,6 @@ class GenerationMixin:
|
||||
max_length: Optional[int],
|
||||
max_time: Optional[float],
|
||||
) -> StoppingCriteriaList:
|
||||
|
||||
stopping_criteria = StoppingCriteriaList()
|
||||
if max_length is not None:
|
||||
stopping_criteria.append(MaxLengthCriteria(max_length=max_length))
|
||||
@@ -859,9 +859,9 @@ class GenerationMixin:
|
||||
"""
|
||||
|
||||
# set init values
|
||||
max_length = max_length if max_length is not None else self.config.max_length
|
||||
num_beams = num_beams if num_beams is not None else self.config.num_beams
|
||||
num_beam_groups = num_beam_groups if num_beam_groups is not None else self.config.num_beam_groups
|
||||
max_length = max_length if max_length is not None else self.config.max_length
|
||||
do_sample = do_sample if do_sample is not None else self.config.do_sample
|
||||
num_return_sequences = (
|
||||
num_return_sequences if num_return_sequences is not None else self.config.num_return_sequences
|
||||
@@ -958,10 +958,13 @@ class GenerationMixin:
|
||||
remove_invalid_values=remove_invalid_values,
|
||||
)
|
||||
|
||||
stopping_criteria = self._get_stopping_criteria(
|
||||
max_length=max_length,
|
||||
max_time=max_time,
|
||||
)
|
||||
stopping_criteria = self._get_stopping_criteria(max_length=max_length, max_time=max_time)
|
||||
if max_length is not None:
|
||||
warnings.warn(
|
||||
"`max_length` is deprecated in this function, use `stopping_criteria=StoppingCriteriaList(MaxLengthCriteria(max_length=max_length))` instead.",
|
||||
UserWarning,
|
||||
)
|
||||
stopping_criteria = validate_stopping_criteria(stopping_criteria, max_length)
|
||||
|
||||
if is_greedy_gen_mode:
|
||||
if num_return_sequences > 1:
|
||||
@@ -974,7 +977,6 @@ class GenerationMixin:
|
||||
input_ids,
|
||||
logits_processor=logits_processor,
|
||||
stopping_criteria=stopping_criteria,
|
||||
max_length=max_length,
|
||||
pad_token_id=pad_token_id,
|
||||
eos_token_id=eos_token_id,
|
||||
output_scores=output_scores,
|
||||
@@ -1003,7 +1005,6 @@ class GenerationMixin:
|
||||
logits_processor=logits_processor,
|
||||
logits_warper=logits_warper,
|
||||
stopping_criteria=stopping_criteria,
|
||||
max_length=max_length,
|
||||
pad_token_id=pad_token_id,
|
||||
eos_token_id=eos_token_id,
|
||||
output_scores=output_scores,
|
||||
@@ -1021,9 +1022,12 @@ class GenerationMixin:
|
||||
if num_return_sequences > num_beams:
|
||||
raise ValueError("`num_return_sequences` has to be smaller or equal to `num_beams`.")
|
||||
|
||||
if stopping_criteria.max_length is None:
|
||||
raise ValueError("`max_length` needs to be a stopping_criteria for now.")
|
||||
|
||||
beam_scorer = BeamSearchScorer(
|
||||
batch_size=batch_size,
|
||||
max_length=max_length,
|
||||
max_length=stopping_criteria.max_length,
|
||||
num_beams=num_beams,
|
||||
device=self.device,
|
||||
length_penalty=length_penalty,
|
||||
@@ -1039,7 +1043,6 @@ class GenerationMixin:
|
||||
beam_scorer,
|
||||
logits_processor=logits_processor,
|
||||
stopping_criteria=stopping_criteria,
|
||||
max_length=max_length,
|
||||
pad_token_id=pad_token_id,
|
||||
eos_token_id=eos_token_id,
|
||||
output_scores=output_scores,
|
||||
@@ -1056,9 +1059,11 @@ class GenerationMixin:
|
||||
batch_size = input_ids.shape[0] * num_return_sequences
|
||||
|
||||
length_penalty = length_penalty if length_penalty is not None else self.config.length_penalty
|
||||
if stopping_criteria.max_length is None:
|
||||
raise ValueError("`max_length` needs to be a stopping_criteria for now.")
|
||||
beam_scorer = BeamSearchScorer(
|
||||
batch_size=batch_size,
|
||||
max_length=max_length,
|
||||
max_length=stopping_criteria.max_length,
|
||||
num_beams=num_beams,
|
||||
device=self.device,
|
||||
length_penalty=length_penalty,
|
||||
@@ -1079,7 +1084,6 @@ class GenerationMixin:
|
||||
logits_processor=logits_processor,
|
||||
logits_warper=logits_warper,
|
||||
stopping_criteria=stopping_criteria,
|
||||
max_length=max_length,
|
||||
pad_token_id=pad_token_id,
|
||||
eos_token_id=eos_token_id,
|
||||
output_scores=output_scores,
|
||||
@@ -1100,10 +1104,13 @@ class GenerationMixin:
|
||||
if num_beams % num_beam_groups != 0:
|
||||
raise ValueError("`num_beams` should be divisible by `num_beam_groups` for group beam search.")
|
||||
|
||||
if stopping_criteria.max_length is None:
|
||||
raise ValueError("`max_length` needs to be a stopping_criteria for now.")
|
||||
|
||||
diverse_beam_scorer = BeamSearchScorer(
|
||||
batch_size=batch_size,
|
||||
max_length=max_length,
|
||||
num_beams=num_beams,
|
||||
max_length=stopping_criteria.max_length,
|
||||
device=self.device,
|
||||
length_penalty=length_penalty,
|
||||
do_early_stopping=early_stopping,
|
||||
@@ -1119,7 +1126,6 @@ class GenerationMixin:
|
||||
diverse_beam_scorer,
|
||||
logits_processor=logits_processor,
|
||||
stopping_criteria=stopping_criteria,
|
||||
max_length=max_length,
|
||||
pad_token_id=pad_token_id,
|
||||
eos_token_id=eos_token_id,
|
||||
output_scores=output_scores,
|
||||
@@ -1160,7 +1166,8 @@ class GenerationMixin:
|
||||
:class:`~transformers.StoppingCriteria` used to tell if the generation loop should stop.
|
||||
|
||||
max_length (:obj:`int`, `optional`, defaults to 20):
|
||||
The maximum length of the sequence to be generated.
|
||||
**DEPRECATED**. Use :obj:`logits_processor` or :obj:`stopping_criteria` directly to cap the number of
|
||||
generated tokens. The maximum length of the sequence to be generated.
|
||||
pad_token_id (:obj:`int`, `optional`):
|
||||
The id of the `padding` token.
|
||||
eos_token_id (:obj:`int`, `optional`):
|
||||
@@ -1220,8 +1227,12 @@ class GenerationMixin:
|
||||
# init values
|
||||
logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList()
|
||||
stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList()
|
||||
max_length = max_length if max_length is not None else self.config.max_length
|
||||
validate_stopping_criteria(stopping_criteria, max_length)
|
||||
if max_length is not None:
|
||||
warnings.warn(
|
||||
"`max_length` is deprecated in this function, use `stopping_criteria=StoppingCriteriaList(MaxLengthCriteria(max_length=max_length))` instead.",
|
||||
UserWarning,
|
||||
)
|
||||
stopping_criteria = validate_stopping_criteria(stopping_criteria, max_length)
|
||||
pad_token_id = pad_token_id if pad_token_id is not None else self.config.pad_token_id
|
||||
eos_token_id = eos_token_id if eos_token_id is not None else self.config.eos_token_id
|
||||
output_scores = output_scores if output_scores is not None else self.config.output_scores
|
||||
@@ -1251,7 +1262,7 @@ class GenerationMixin:
|
||||
cur_len = input_ids.shape[-1]
|
||||
|
||||
this_peer_finished = False # used by synced_gpus only
|
||||
while cur_len < max_length:
|
||||
while True:
|
||||
|
||||
if synced_gpus:
|
||||
# Under synced_gpus the `forward` call must continue until all gpus complete their sequence.
|
||||
@@ -1384,7 +1395,8 @@ class GenerationMixin:
|
||||
:class:`~transformers.LogitsWarper` used to warp the prediction score distribution of the language
|
||||
modeling head applied before multinomial sampling at each generation step.
|
||||
max_length (:obj:`int`, `optional`, defaults to 20):
|
||||
The maximum length of the sequence to be generated.
|
||||
**DEPRECATED**. Use :obj:`logits_processor` or :obj:`stopping_criteria` directly to cap the number of
|
||||
generated tokens. The maximum length of the sequence to be generated.
|
||||
pad_token_id (:obj:`int`, `optional`):
|
||||
The id of the `padding` token.
|
||||
eos_token_id (:obj:`int`, `optional`):
|
||||
@@ -1452,8 +1464,12 @@ class GenerationMixin:
|
||||
# init values
|
||||
logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList()
|
||||
stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList()
|
||||
max_length = max_length if max_length is not None else self.config.max_length
|
||||
validate_stopping_criteria(stopping_criteria, max_length)
|
||||
if max_length is not None:
|
||||
warnings.warn(
|
||||
"`max_length` is deprecated in this function, use `stopping_criteria=StoppingCriteriaList(MaxLengthCriteria(max_length=max_length))` instead.",
|
||||
UserWarning,
|
||||
)
|
||||
stopping_criteria = validate_stopping_criteria(stopping_criteria, max_length)
|
||||
logits_warper = logits_warper if logits_warper is not None else LogitsProcessorList()
|
||||
pad_token_id = pad_token_id if pad_token_id is not None else self.config.pad_token_id
|
||||
eos_token_id = eos_token_id if eos_token_id is not None else self.config.eos_token_id
|
||||
@@ -1485,7 +1501,7 @@ class GenerationMixin:
|
||||
|
||||
this_peer_finished = False # used by synced_gpus only
|
||||
# auto-regressive generation
|
||||
while cur_len < max_length:
|
||||
while True:
|
||||
|
||||
if synced_gpus:
|
||||
# Under synced_gpus the `forward` call must continue until all gpus complete their sequence.
|
||||
@@ -1620,7 +1636,8 @@ class GenerationMixin:
|
||||
An instance of :class:`~transformers.StoppingCriteriaList`. List of instances of class derived from
|
||||
:class:`~transformers.StoppingCriteria` used to tell if the generation loop should stop.
|
||||
max_length (:obj:`int`, `optional`, defaults to 20):
|
||||
The maximum length of the sequence to be generated.
|
||||
**DEPRECATED**. Use :obj:`logits_processor` or :obj:`stopping_criteria` directly to cap the number of
|
||||
generated tokens. The maximum length of the sequence to be generated.
|
||||
pad_token_id (:obj:`int`, `optional`):
|
||||
The id of the `padding` token.
|
||||
eos_token_id (:obj:`int`, `optional`):
|
||||
@@ -1700,8 +1717,14 @@ class GenerationMixin:
|
||||
# init values
|
||||
logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList()
|
||||
stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList()
|
||||
max_length = max_length if max_length is not None else self.config.max_length
|
||||
validate_stopping_criteria(stopping_criteria, max_length)
|
||||
if max_length is not None:
|
||||
warnings.warn(
|
||||
"`max_length` is deprecated in this function, use `stopping_criteria=StoppingCriteriaList(MaxLengthCriteria(max_length=max_length))` instead.",
|
||||
UserWarning,
|
||||
)
|
||||
stopping_criteria = validate_stopping_criteria(stopping_criteria, max_length)
|
||||
if len(stopping_criteria) == 0:
|
||||
warnings.warn("You don't have defined any stopping_criteria, this will likely loop forever", UserWarning)
|
||||
pad_token_id = pad_token_id if pad_token_id is not None else self.config.pad_token_id
|
||||
eos_token_id = eos_token_id if eos_token_id is not None else self.config.eos_token_id
|
||||
output_scores = output_scores if output_scores is not None else self.config.output_scores
|
||||
@@ -1740,7 +1763,7 @@ class GenerationMixin:
|
||||
beam_scores = beam_scores.view((batch_size * num_beams,))
|
||||
|
||||
this_peer_finished = False # used by synced_gpus only
|
||||
while cur_len < max_length:
|
||||
while True:
|
||||
|
||||
if synced_gpus:
|
||||
# Under synced_gpus the `forward` call must continue until all gpus complete their sequence.
|
||||
@@ -1770,7 +1793,7 @@ class GenerationMixin:
|
||||
# hack: adjust tokens for Marian. For Marian we have to make sure that the `pad_token_id`
|
||||
# cannot be generated both before and after the `F.log_softmax` operation.
|
||||
next_token_logits = self.adjust_logits_during_generation(
|
||||
next_token_logits, cur_len=cur_len, max_length=max_length
|
||||
next_token_logits, cur_len=cur_len, max_length=None
|
||||
)
|
||||
|
||||
next_token_scores = F.log_softmax(next_token_logits, dim=-1) # (batch_size * num_beams, vocab_size)
|
||||
@@ -1907,7 +1930,8 @@ class GenerationMixin:
|
||||
:class:`~transformers.LogitsWarper` used to warp the prediction score distribution of the language
|
||||
modeling head applied before multinomial sampling at each generation step.
|
||||
max_length (:obj:`int`, `optional`, defaults to 20):
|
||||
The maximum length of the sequence to be generated.
|
||||
**DEPRECATED**. Use :obj:`logits_processor` or :obj:`stopping_criteria` directly to cap the number of
|
||||
generated tokens. The maximum length of the sequence to be generated.
|
||||
pad_token_id (:obj:`int`, `optional`):
|
||||
The id of the `padding` token.
|
||||
eos_token_id (:obj:`int`, `optional`):
|
||||
@@ -1994,7 +2018,12 @@ class GenerationMixin:
|
||||
# init values
|
||||
logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList()
|
||||
stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList()
|
||||
max_length = max_length if max_length is not None else self.config.max_length
|
||||
if max_length is not None:
|
||||
warnings.warn(
|
||||
"`max_length` is deprecated in this function, use `stopping_criteria=StoppingCriteriaList(MaxLengthCriteria(max_length=max_length))` instead.",
|
||||
UserWarning,
|
||||
)
|
||||
stopping_criteria = validate_stopping_criteria(stopping_criteria, max_length)
|
||||
pad_token_id = pad_token_id if pad_token_id is not None else self.config.pad_token_id
|
||||
eos_token_id = eos_token_id if eos_token_id is not None else self.config.eos_token_id
|
||||
output_scores = output_scores if output_scores is not None else self.config.output_scores
|
||||
@@ -2028,7 +2057,7 @@ class GenerationMixin:
|
||||
beam_scores = beam_scores.view((batch_size * num_beams,))
|
||||
|
||||
this_peer_finished = False # used by synced_gpus only
|
||||
while cur_len < max_length:
|
||||
while True:
|
||||
|
||||
if synced_gpus:
|
||||
# Under synced_gpus the `forward` call must continue until all gpus complete their sequence.
|
||||
@@ -2058,7 +2087,7 @@ class GenerationMixin:
|
||||
# hack: adjust tokens for Marian. For Marian we have to make sure that the `pad_token_id`
|
||||
# cannot be generated both before and after the `F.log_softmax` operation.
|
||||
next_token_logits = self.adjust_logits_during_generation(
|
||||
next_token_logits, cur_len=cur_len, max_length=max_length
|
||||
next_token_logits, cur_len=cur_len, max_length=None
|
||||
)
|
||||
|
||||
next_token_scores = F.log_softmax(next_token_logits, dim=-1) # (batch_size * num_beams, vocab_size)
|
||||
@@ -2195,7 +2224,8 @@ class GenerationMixin:
|
||||
An instance of :class:`~transformers.StoppingCriteriaList`. List of instances of class derived from
|
||||
:class:`~transformers.StoppingCriteria` used to tell if the generation loop should stop.
|
||||
max_length (:obj:`int`, `optional`, defaults to 20):
|
||||
The maximum length of the sequence to be generated.
|
||||
**DEPRECATED**. Use :obj:`logits_processor` or :obj:`stopping_criteria` directly to cap the number of
|
||||
generated tokens. The maximum length of the sequence to be generated.
|
||||
pad_token_id (:obj:`int`, `optional`):
|
||||
The id of the `padding` token.
|
||||
eos_token_id (:obj:`int`, `optional`):
|
||||
@@ -2279,8 +2309,12 @@ class GenerationMixin:
|
||||
# init values
|
||||
logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList()
|
||||
stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList()
|
||||
max_length = max_length if max_length is not None else self.config.max_length
|
||||
validate_stopping_criteria(stopping_criteria, max_length)
|
||||
if max_length is not None:
|
||||
warnings.warn(
|
||||
"`max_length` is deprecated in this function, use `stopping_criteria=StoppingCriteriaList(MaxLengthCriteria(max_length=max_length))` instead.",
|
||||
UserWarning,
|
||||
)
|
||||
stopping_criteria = validate_stopping_criteria(stopping_criteria, max_length)
|
||||
pad_token_id = pad_token_id if pad_token_id is not None else self.config.pad_token_id
|
||||
eos_token_id = eos_token_id if eos_token_id is not None else self.config.eos_token_id
|
||||
output_scores = output_scores if output_scores is not None else self.config.output_scores
|
||||
@@ -2324,7 +2358,7 @@ class GenerationMixin:
|
||||
beam_scores = beam_scores.view((batch_size * num_beams,))
|
||||
|
||||
this_peer_finished = False # used by synced_gpus only
|
||||
while cur_len < max_length:
|
||||
while True:
|
||||
|
||||
if synced_gpus:
|
||||
# Under synced_gpus the `forward` call must continue until all gpus complete their sequence.
|
||||
@@ -2378,7 +2412,7 @@ class GenerationMixin:
|
||||
# hack: adjust tokens for Marian. For Marian we have to make sure that the `pad_token_id`
|
||||
# cannot be generated both before and after the `F.log_softmax` operation.
|
||||
next_token_logits = self.adjust_logits_during_generation(
|
||||
next_token_logits, cur_len=cur_len, max_length=max_length
|
||||
next_token_logits, cur_len=cur_len, max_length=None
|
||||
)
|
||||
|
||||
next_token_scores = F.log_softmax(next_token_logits, dim=-1) # (batch_size * group_size, vocab_size)
|
||||
|
||||
@@ -40,10 +40,10 @@ class StoppingCriteriaTestCase(unittest.TestCase):
|
||||
|
||||
self.assertFalse(criteria(input_ids, scores))
|
||||
|
||||
input_ids, scores = self._get_tensors(10)
|
||||
input_ids, scores = self._get_tensors(9)
|
||||
self.assertFalse(criteria(input_ids, scores))
|
||||
|
||||
input_ids, scores = self._get_tensors(11)
|
||||
input_ids, scores = self._get_tensors(10)
|
||||
self.assertTrue(criteria(input_ids, scores))
|
||||
|
||||
def test_max_length_criteria(self):
|
||||
@@ -52,10 +52,10 @@ class StoppingCriteriaTestCase(unittest.TestCase):
|
||||
input_ids, scores = self._get_tensors(5)
|
||||
self.assertFalse(criteria(input_ids, scores))
|
||||
|
||||
input_ids, scores = self._get_tensors(10)
|
||||
input_ids, scores = self._get_tensors(9)
|
||||
self.assertFalse(criteria(input_ids, scores))
|
||||
|
||||
input_ids, scores = self._get_tensors(11)
|
||||
input_ids, scores = self._get_tensors(10)
|
||||
self.assertTrue(criteria(input_ids, scores))
|
||||
|
||||
def test_max_time_criteria(self):
|
||||
@@ -73,7 +73,6 @@ class StoppingCriteriaTestCase(unittest.TestCase):
|
||||
with self.assertWarns(UserWarning):
|
||||
validate_stopping_criteria(StoppingCriteriaList([MaxLengthCriteria(10)]), 11)
|
||||
|
||||
stopping_criteria = StoppingCriteriaList()
|
||||
validate_stopping_criteria(stopping_criteria, 11)
|
||||
stopping_criteria = validate_stopping_criteria(StoppingCriteriaList(), 11)
|
||||
|
||||
self.assertEqual(len(stopping_criteria), 1)
|
||||
|
||||
@@ -1358,13 +1358,14 @@ class GenerationIntegrationTests(unittest.TestCase):
|
||||
bos_token_id=bart_model.config.bos_token_id,
|
||||
)
|
||||
|
||||
bart_model.greedy_search(
|
||||
input_ids,
|
||||
max_length=max_length,
|
||||
pad_token_id=bart_model.config.pad_token_id,
|
||||
eos_token_id=bart_model.config.eos_token_id,
|
||||
**model_kwargs,
|
||||
)
|
||||
with self.assertWarns(UserWarning):
|
||||
bart_model.greedy_search(
|
||||
input_ids,
|
||||
max_length=max_length,
|
||||
pad_token_id=bart_model.config.pad_token_id,
|
||||
eos_token_id=bart_model.config.eos_token_id,
|
||||
**model_kwargs,
|
||||
)
|
||||
|
||||
def test_max_length_backward_compat_sample(self):
|
||||
article = """Justin Timberlake and Jessica Biel, welcome to parenthood."""
|
||||
@@ -1381,13 +1382,14 @@ class GenerationIntegrationTests(unittest.TestCase):
|
||||
bos_token_id=bart_model.config.bos_token_id,
|
||||
)
|
||||
with torch.no_grad():
|
||||
bart_model.sample(
|
||||
input_ids,
|
||||
max_length=max_length,
|
||||
pad_token_id=bart_model.config.pad_token_id,
|
||||
eos_token_id=bart_model.config.eos_token_id,
|
||||
**model_kwargs,
|
||||
)
|
||||
with self.assertWarns(UserWarning):
|
||||
bart_model.sample(
|
||||
input_ids,
|
||||
max_length=max_length,
|
||||
pad_token_id=bart_model.config.pad_token_id,
|
||||
eos_token_id=bart_model.config.eos_token_id,
|
||||
**model_kwargs,
|
||||
)
|
||||
|
||||
def test_max_length_backward_compat_beam_search(self):
|
||||
article = """Justin Timberlake and Jessica Biel, welcome to parenthood."""
|
||||
@@ -1413,9 +1415,10 @@ class GenerationIntegrationTests(unittest.TestCase):
|
||||
num_beams=num_beams,
|
||||
device=torch_device,
|
||||
)
|
||||
_ = bart_model.beam_search(
|
||||
input_ids, num_beams=num_beams, max_length=max_length, beam_scorer=beam_scorer, **model_kwargs
|
||||
)
|
||||
with self.assertWarns(UserWarning):
|
||||
_ = bart_model.beam_search(
|
||||
input_ids, num_beams=num_beams, max_length=max_length, beam_scorer=beam_scorer, **model_kwargs
|
||||
)
|
||||
|
||||
def test_max_length_backward_compat_group_beam_search(self):
|
||||
article = """Justin Timberlake and Jessica Biel, welcome to parenthood."""
|
||||
@@ -1445,9 +1448,10 @@ class GenerationIntegrationTests(unittest.TestCase):
|
||||
num_beam_hyps_to_keep=num_return_sequences,
|
||||
num_beam_groups=num_beam_groups,
|
||||
)
|
||||
bart_model.group_beam_search(
|
||||
input_ids, diverse_beam_scorer, num_beams=num_beams, max_length=max_length, **model_kwargs
|
||||
)
|
||||
with self.assertWarns(UserWarning):
|
||||
bart_model.group_beam_search(
|
||||
input_ids, diverse_beam_scorer, num_beams=num_beams, max_length=max_length, **model_kwargs
|
||||
)
|
||||
|
||||
def test_max_length_warning_if_different(self):
|
||||
article = """Justin Timberlake and Jessica Biel, welcome to parenthood."""
|
||||
|
||||
Reference in New Issue
Block a user