Removed max_length from being mandatory within generate. (#11314)
* Removed `max_length` from being mandatory within `generate`. - Moving on to fully using `StoppingCriteria` for `greedy` and `sample` modes. - `max_length` still used for `beam_search` and `group_beam_search` (Follow up PR) - Fixes a bug with MaxLengthStoppingCriteria (we should stop as soon a we hit the max_length, the comparison needs to be or equal, that affects the tests). - Added options to use `logits_processor` and `stopping_criteria` directly within `generate` function (so some users can define their own `logits_processor` and `stopping_criteria`). - Modified the backward compat tests to make sure we issue a warning. * Fix `max_length` argument in `generate`. * Moving validate to being functional. - Renamed `smax_length` to `stoppping_max_length`. * Removing `logits_processor` and `stopping_criteria` from `generate` arguments. * Deepcopy. * Fix global variable name.
This commit is contained in:
@@ -1,6 +1,7 @@
|
|||||||
import time
|
import time
|
||||||
import warnings
|
import warnings
|
||||||
from abc import ABC
|
from abc import ABC
|
||||||
|
from copy import deepcopy
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
@@ -8,7 +9,7 @@ import torch
|
|||||||
from .file_utils import add_start_docstrings
|
from .file_utils import add_start_docstrings
|
||||||
|
|
||||||
|
|
||||||
LOGITS_PROCESSOR_INPUTS_DOCSTRING = r"""
|
STOPPING_CRITERIA_INPUTS_DOCSTRING = r"""
|
||||||
Args:
|
Args:
|
||||||
input_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`):
|
input_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`):
|
||||||
Indices of input sequence tokens in the vocabulary.
|
Indices of input sequence tokens in the vocabulary.
|
||||||
@@ -33,7 +34,7 @@ LOGITS_PROCESSOR_INPUTS_DOCSTRING = r"""
|
|||||||
class StoppingCriteria(ABC):
|
class StoppingCriteria(ABC):
|
||||||
"""Abstract base class for all stopping criteria that can be applied during generation."""
|
"""Abstract base class for all stopping criteria that can be applied during generation."""
|
||||||
|
|
||||||
@add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING)
|
@add_start_docstrings(STOPPING_CRITERIA_INPUTS_DOCSTRING)
|
||||||
def __call__(self, input_ids: torch.LongTensor, score: torch.FloatTensor, **kwargs) -> bool:
|
def __call__(self, input_ids: torch.LongTensor, score: torch.FloatTensor, **kwargs) -> bool:
|
||||||
raise NotImplementedError("StoppingCriteria needs to be subclassed")
|
raise NotImplementedError("StoppingCriteria needs to be subclassed")
|
||||||
|
|
||||||
@@ -51,9 +52,9 @@ class MaxLengthCriteria(StoppingCriteria):
|
|||||||
def __init__(self, max_length: int):
|
def __init__(self, max_length: int):
|
||||||
self.max_length = max_length
|
self.max_length = max_length
|
||||||
|
|
||||||
@add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING)
|
@add_start_docstrings(STOPPING_CRITERIA_INPUTS_DOCSTRING)
|
||||||
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
|
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
|
||||||
return input_ids.shape[-1] > self.max_length
|
return input_ids.shape[-1] >= self.max_length
|
||||||
|
|
||||||
|
|
||||||
class MaxTimeCriteria(StoppingCriteria):
|
class MaxTimeCriteria(StoppingCriteria):
|
||||||
@@ -73,25 +74,29 @@ class MaxTimeCriteria(StoppingCriteria):
|
|||||||
self.max_time = max_time
|
self.max_time = max_time
|
||||||
self.initial_timestamp = time.time() if initial_timestamp is None else initial_timestamp
|
self.initial_timestamp = time.time() if initial_timestamp is None else initial_timestamp
|
||||||
|
|
||||||
@add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING)
|
@add_start_docstrings(STOPPING_CRITERIA_INPUTS_DOCSTRING)
|
||||||
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
|
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
|
||||||
return time.time() - self.initial_timestamp > self.max_time
|
return time.time() - self.initial_timestamp > self.max_time
|
||||||
|
|
||||||
|
|
||||||
class StoppingCriteriaList(list):
|
class StoppingCriteriaList(list):
|
||||||
@add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING)
|
@add_start_docstrings(STOPPING_CRITERIA_INPUTS_DOCSTRING)
|
||||||
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
|
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
|
||||||
return any(criteria(input_ids, scores) for criteria in self)
|
return any(criteria(input_ids, scores) for criteria in self)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def max_length(self) -> Optional[int]:
|
||||||
|
for stopping_criterium in self:
|
||||||
|
if isinstance(stopping_criterium, MaxLengthCriteria):
|
||||||
|
return stopping_criterium.max_length
|
||||||
|
return None
|
||||||
|
|
||||||
def validate_stopping_criteria(stopping_criteria: StoppingCriteriaList, max_length: int):
|
|
||||||
found = False
|
def validate_stopping_criteria(stopping_criteria: StoppingCriteriaList, max_length: int) -> StoppingCriteriaList:
|
||||||
for stopping_criterium in stopping_criteria:
|
stopping_max_length = stopping_criteria.max_length
|
||||||
if isinstance(stopping_criterium, MaxLengthCriteria):
|
new_stopping_criteria = deepcopy(stopping_criteria)
|
||||||
found = True
|
if stopping_max_length is not None and stopping_max_length != max_length:
|
||||||
if stopping_criterium.max_length != max_length:
|
warnings.warn("You set different `max_length` for stopping criteria and `max_length` parameter", UserWarning)
|
||||||
warnings.warn(
|
elif stopping_max_length is None:
|
||||||
"You set different `max_length` for stopping criteria and `max_length` parameter", UserWarning
|
new_stopping_criteria.append(MaxLengthCriteria(max_length=max_length))
|
||||||
)
|
return new_stopping_criteria
|
||||||
if not found:
|
|
||||||
stopping_criteria.append(MaxLengthCriteria(max_length=max_length))
|
|
||||||
|
|||||||
@@ -14,6 +14,7 @@
|
|||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
|
import warnings
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union
|
from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union
|
||||||
|
|
||||||
@@ -564,6 +565,7 @@ class GenerationMixin:
|
|||||||
This class returns a :obj:`~transformers.LogitsProcessorList` list object that contains all relevant
|
This class returns a :obj:`~transformers.LogitsProcessorList` list object that contains all relevant
|
||||||
:obj:`~transformers.LogitsProcessor` instances used to modify the scores of the language model head.
|
:obj:`~transformers.LogitsProcessor` instances used to modify the scores of the language model head.
|
||||||
"""
|
"""
|
||||||
|
processors = LogitsProcessorList()
|
||||||
|
|
||||||
# init warp parameters
|
# init warp parameters
|
||||||
repetition_penalty = repetition_penalty if repetition_penalty is not None else self.config.repetition_penalty
|
repetition_penalty = repetition_penalty if repetition_penalty is not None else self.config.repetition_penalty
|
||||||
@@ -589,7 +591,6 @@ class GenerationMixin:
|
|||||||
remove_invalid_values if remove_invalid_values is not None else self.config.remove_invalid_values
|
remove_invalid_values if remove_invalid_values is not None else self.config.remove_invalid_values
|
||||||
)
|
)
|
||||||
# instantiate processors list
|
# instantiate processors list
|
||||||
processors = LogitsProcessorList()
|
|
||||||
|
|
||||||
# the following idea is largely copied from this PR: https://github.com/huggingface/transformers/pull/5420/files
|
# the following idea is largely copied from this PR: https://github.com/huggingface/transformers/pull/5420/files
|
||||||
# all samplers can be found in `generation_utils_samplers.py`
|
# all samplers can be found in `generation_utils_samplers.py`
|
||||||
@@ -629,7 +630,6 @@ class GenerationMixin:
|
|||||||
max_length: Optional[int],
|
max_length: Optional[int],
|
||||||
max_time: Optional[float],
|
max_time: Optional[float],
|
||||||
) -> StoppingCriteriaList:
|
) -> StoppingCriteriaList:
|
||||||
|
|
||||||
stopping_criteria = StoppingCriteriaList()
|
stopping_criteria = StoppingCriteriaList()
|
||||||
if max_length is not None:
|
if max_length is not None:
|
||||||
stopping_criteria.append(MaxLengthCriteria(max_length=max_length))
|
stopping_criteria.append(MaxLengthCriteria(max_length=max_length))
|
||||||
@@ -859,9 +859,9 @@ class GenerationMixin:
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
# set init values
|
# set init values
|
||||||
|
max_length = max_length if max_length is not None else self.config.max_length
|
||||||
num_beams = num_beams if num_beams is not None else self.config.num_beams
|
num_beams = num_beams if num_beams is not None else self.config.num_beams
|
||||||
num_beam_groups = num_beam_groups if num_beam_groups is not None else self.config.num_beam_groups
|
num_beam_groups = num_beam_groups if num_beam_groups is not None else self.config.num_beam_groups
|
||||||
max_length = max_length if max_length is not None else self.config.max_length
|
|
||||||
do_sample = do_sample if do_sample is not None else self.config.do_sample
|
do_sample = do_sample if do_sample is not None else self.config.do_sample
|
||||||
num_return_sequences = (
|
num_return_sequences = (
|
||||||
num_return_sequences if num_return_sequences is not None else self.config.num_return_sequences
|
num_return_sequences if num_return_sequences is not None else self.config.num_return_sequences
|
||||||
@@ -958,10 +958,13 @@ class GenerationMixin:
|
|||||||
remove_invalid_values=remove_invalid_values,
|
remove_invalid_values=remove_invalid_values,
|
||||||
)
|
)
|
||||||
|
|
||||||
stopping_criteria = self._get_stopping_criteria(
|
stopping_criteria = self._get_stopping_criteria(max_length=max_length, max_time=max_time)
|
||||||
max_length=max_length,
|
if max_length is not None:
|
||||||
max_time=max_time,
|
warnings.warn(
|
||||||
)
|
"`max_length` is deprecated in this function, use `stopping_criteria=StoppingCriteriaList(MaxLengthCriteria(max_length=max_length))` instead.",
|
||||||
|
UserWarning,
|
||||||
|
)
|
||||||
|
stopping_criteria = validate_stopping_criteria(stopping_criteria, max_length)
|
||||||
|
|
||||||
if is_greedy_gen_mode:
|
if is_greedy_gen_mode:
|
||||||
if num_return_sequences > 1:
|
if num_return_sequences > 1:
|
||||||
@@ -974,7 +977,6 @@ class GenerationMixin:
|
|||||||
input_ids,
|
input_ids,
|
||||||
logits_processor=logits_processor,
|
logits_processor=logits_processor,
|
||||||
stopping_criteria=stopping_criteria,
|
stopping_criteria=stopping_criteria,
|
||||||
max_length=max_length,
|
|
||||||
pad_token_id=pad_token_id,
|
pad_token_id=pad_token_id,
|
||||||
eos_token_id=eos_token_id,
|
eos_token_id=eos_token_id,
|
||||||
output_scores=output_scores,
|
output_scores=output_scores,
|
||||||
@@ -1003,7 +1005,6 @@ class GenerationMixin:
|
|||||||
logits_processor=logits_processor,
|
logits_processor=logits_processor,
|
||||||
logits_warper=logits_warper,
|
logits_warper=logits_warper,
|
||||||
stopping_criteria=stopping_criteria,
|
stopping_criteria=stopping_criteria,
|
||||||
max_length=max_length,
|
|
||||||
pad_token_id=pad_token_id,
|
pad_token_id=pad_token_id,
|
||||||
eos_token_id=eos_token_id,
|
eos_token_id=eos_token_id,
|
||||||
output_scores=output_scores,
|
output_scores=output_scores,
|
||||||
@@ -1021,9 +1022,12 @@ class GenerationMixin:
|
|||||||
if num_return_sequences > num_beams:
|
if num_return_sequences > num_beams:
|
||||||
raise ValueError("`num_return_sequences` has to be smaller or equal to `num_beams`.")
|
raise ValueError("`num_return_sequences` has to be smaller or equal to `num_beams`.")
|
||||||
|
|
||||||
|
if stopping_criteria.max_length is None:
|
||||||
|
raise ValueError("`max_length` needs to be a stopping_criteria for now.")
|
||||||
|
|
||||||
beam_scorer = BeamSearchScorer(
|
beam_scorer = BeamSearchScorer(
|
||||||
batch_size=batch_size,
|
batch_size=batch_size,
|
||||||
max_length=max_length,
|
max_length=stopping_criteria.max_length,
|
||||||
num_beams=num_beams,
|
num_beams=num_beams,
|
||||||
device=self.device,
|
device=self.device,
|
||||||
length_penalty=length_penalty,
|
length_penalty=length_penalty,
|
||||||
@@ -1039,7 +1043,6 @@ class GenerationMixin:
|
|||||||
beam_scorer,
|
beam_scorer,
|
||||||
logits_processor=logits_processor,
|
logits_processor=logits_processor,
|
||||||
stopping_criteria=stopping_criteria,
|
stopping_criteria=stopping_criteria,
|
||||||
max_length=max_length,
|
|
||||||
pad_token_id=pad_token_id,
|
pad_token_id=pad_token_id,
|
||||||
eos_token_id=eos_token_id,
|
eos_token_id=eos_token_id,
|
||||||
output_scores=output_scores,
|
output_scores=output_scores,
|
||||||
@@ -1056,9 +1059,11 @@ class GenerationMixin:
|
|||||||
batch_size = input_ids.shape[0] * num_return_sequences
|
batch_size = input_ids.shape[0] * num_return_sequences
|
||||||
|
|
||||||
length_penalty = length_penalty if length_penalty is not None else self.config.length_penalty
|
length_penalty = length_penalty if length_penalty is not None else self.config.length_penalty
|
||||||
|
if stopping_criteria.max_length is None:
|
||||||
|
raise ValueError("`max_length` needs to be a stopping_criteria for now.")
|
||||||
beam_scorer = BeamSearchScorer(
|
beam_scorer = BeamSearchScorer(
|
||||||
batch_size=batch_size,
|
batch_size=batch_size,
|
||||||
max_length=max_length,
|
max_length=stopping_criteria.max_length,
|
||||||
num_beams=num_beams,
|
num_beams=num_beams,
|
||||||
device=self.device,
|
device=self.device,
|
||||||
length_penalty=length_penalty,
|
length_penalty=length_penalty,
|
||||||
@@ -1079,7 +1084,6 @@ class GenerationMixin:
|
|||||||
logits_processor=logits_processor,
|
logits_processor=logits_processor,
|
||||||
logits_warper=logits_warper,
|
logits_warper=logits_warper,
|
||||||
stopping_criteria=stopping_criteria,
|
stopping_criteria=stopping_criteria,
|
||||||
max_length=max_length,
|
|
||||||
pad_token_id=pad_token_id,
|
pad_token_id=pad_token_id,
|
||||||
eos_token_id=eos_token_id,
|
eos_token_id=eos_token_id,
|
||||||
output_scores=output_scores,
|
output_scores=output_scores,
|
||||||
@@ -1100,10 +1104,13 @@ class GenerationMixin:
|
|||||||
if num_beams % num_beam_groups != 0:
|
if num_beams % num_beam_groups != 0:
|
||||||
raise ValueError("`num_beams` should be divisible by `num_beam_groups` for group beam search.")
|
raise ValueError("`num_beams` should be divisible by `num_beam_groups` for group beam search.")
|
||||||
|
|
||||||
|
if stopping_criteria.max_length is None:
|
||||||
|
raise ValueError("`max_length` needs to be a stopping_criteria for now.")
|
||||||
|
|
||||||
diverse_beam_scorer = BeamSearchScorer(
|
diverse_beam_scorer = BeamSearchScorer(
|
||||||
batch_size=batch_size,
|
batch_size=batch_size,
|
||||||
max_length=max_length,
|
|
||||||
num_beams=num_beams,
|
num_beams=num_beams,
|
||||||
|
max_length=stopping_criteria.max_length,
|
||||||
device=self.device,
|
device=self.device,
|
||||||
length_penalty=length_penalty,
|
length_penalty=length_penalty,
|
||||||
do_early_stopping=early_stopping,
|
do_early_stopping=early_stopping,
|
||||||
@@ -1119,7 +1126,6 @@ class GenerationMixin:
|
|||||||
diverse_beam_scorer,
|
diverse_beam_scorer,
|
||||||
logits_processor=logits_processor,
|
logits_processor=logits_processor,
|
||||||
stopping_criteria=stopping_criteria,
|
stopping_criteria=stopping_criteria,
|
||||||
max_length=max_length,
|
|
||||||
pad_token_id=pad_token_id,
|
pad_token_id=pad_token_id,
|
||||||
eos_token_id=eos_token_id,
|
eos_token_id=eos_token_id,
|
||||||
output_scores=output_scores,
|
output_scores=output_scores,
|
||||||
@@ -1160,7 +1166,8 @@ class GenerationMixin:
|
|||||||
:class:`~transformers.StoppingCriteria` used to tell if the generation loop should stop.
|
:class:`~transformers.StoppingCriteria` used to tell if the generation loop should stop.
|
||||||
|
|
||||||
max_length (:obj:`int`, `optional`, defaults to 20):
|
max_length (:obj:`int`, `optional`, defaults to 20):
|
||||||
The maximum length of the sequence to be generated.
|
**DEPRECATED**. Use :obj:`logits_processor` or :obj:`stopping_criteria` directly to cap the number of
|
||||||
|
generated tokens. The maximum length of the sequence to be generated.
|
||||||
pad_token_id (:obj:`int`, `optional`):
|
pad_token_id (:obj:`int`, `optional`):
|
||||||
The id of the `padding` token.
|
The id of the `padding` token.
|
||||||
eos_token_id (:obj:`int`, `optional`):
|
eos_token_id (:obj:`int`, `optional`):
|
||||||
@@ -1220,8 +1227,12 @@ class GenerationMixin:
|
|||||||
# init values
|
# init values
|
||||||
logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList()
|
logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList()
|
||||||
stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList()
|
stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList()
|
||||||
max_length = max_length if max_length is not None else self.config.max_length
|
if max_length is not None:
|
||||||
validate_stopping_criteria(stopping_criteria, max_length)
|
warnings.warn(
|
||||||
|
"`max_length` is deprecated in this function, use `stopping_criteria=StoppingCriteriaList(MaxLengthCriteria(max_length=max_length))` instead.",
|
||||||
|
UserWarning,
|
||||||
|
)
|
||||||
|
stopping_criteria = validate_stopping_criteria(stopping_criteria, max_length)
|
||||||
pad_token_id = pad_token_id if pad_token_id is not None else self.config.pad_token_id
|
pad_token_id = pad_token_id if pad_token_id is not None else self.config.pad_token_id
|
||||||
eos_token_id = eos_token_id if eos_token_id is not None else self.config.eos_token_id
|
eos_token_id = eos_token_id if eos_token_id is not None else self.config.eos_token_id
|
||||||
output_scores = output_scores if output_scores is not None else self.config.output_scores
|
output_scores = output_scores if output_scores is not None else self.config.output_scores
|
||||||
@@ -1251,7 +1262,7 @@ class GenerationMixin:
|
|||||||
cur_len = input_ids.shape[-1]
|
cur_len = input_ids.shape[-1]
|
||||||
|
|
||||||
this_peer_finished = False # used by synced_gpus only
|
this_peer_finished = False # used by synced_gpus only
|
||||||
while cur_len < max_length:
|
while True:
|
||||||
|
|
||||||
if synced_gpus:
|
if synced_gpus:
|
||||||
# Under synced_gpus the `forward` call must continue until all gpus complete their sequence.
|
# Under synced_gpus the `forward` call must continue until all gpus complete their sequence.
|
||||||
@@ -1384,7 +1395,8 @@ class GenerationMixin:
|
|||||||
:class:`~transformers.LogitsWarper` used to warp the prediction score distribution of the language
|
:class:`~transformers.LogitsWarper` used to warp the prediction score distribution of the language
|
||||||
modeling head applied before multinomial sampling at each generation step.
|
modeling head applied before multinomial sampling at each generation step.
|
||||||
max_length (:obj:`int`, `optional`, defaults to 20):
|
max_length (:obj:`int`, `optional`, defaults to 20):
|
||||||
The maximum length of the sequence to be generated.
|
**DEPRECATED**. Use :obj:`logits_processor` or :obj:`stopping_criteria` directly to cap the number of
|
||||||
|
generated tokens. The maximum length of the sequence to be generated.
|
||||||
pad_token_id (:obj:`int`, `optional`):
|
pad_token_id (:obj:`int`, `optional`):
|
||||||
The id of the `padding` token.
|
The id of the `padding` token.
|
||||||
eos_token_id (:obj:`int`, `optional`):
|
eos_token_id (:obj:`int`, `optional`):
|
||||||
@@ -1452,8 +1464,12 @@ class GenerationMixin:
|
|||||||
# init values
|
# init values
|
||||||
logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList()
|
logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList()
|
||||||
stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList()
|
stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList()
|
||||||
max_length = max_length if max_length is not None else self.config.max_length
|
if max_length is not None:
|
||||||
validate_stopping_criteria(stopping_criteria, max_length)
|
warnings.warn(
|
||||||
|
"`max_length` is deprecated in this function, use `stopping_criteria=StoppingCriteriaList(MaxLengthCriteria(max_length=max_length))` instead.",
|
||||||
|
UserWarning,
|
||||||
|
)
|
||||||
|
stopping_criteria = validate_stopping_criteria(stopping_criteria, max_length)
|
||||||
logits_warper = logits_warper if logits_warper is not None else LogitsProcessorList()
|
logits_warper = logits_warper if logits_warper is not None else LogitsProcessorList()
|
||||||
pad_token_id = pad_token_id if pad_token_id is not None else self.config.pad_token_id
|
pad_token_id = pad_token_id if pad_token_id is not None else self.config.pad_token_id
|
||||||
eos_token_id = eos_token_id if eos_token_id is not None else self.config.eos_token_id
|
eos_token_id = eos_token_id if eos_token_id is not None else self.config.eos_token_id
|
||||||
@@ -1485,7 +1501,7 @@ class GenerationMixin:
|
|||||||
|
|
||||||
this_peer_finished = False # used by synced_gpus only
|
this_peer_finished = False # used by synced_gpus only
|
||||||
# auto-regressive generation
|
# auto-regressive generation
|
||||||
while cur_len < max_length:
|
while True:
|
||||||
|
|
||||||
if synced_gpus:
|
if synced_gpus:
|
||||||
# Under synced_gpus the `forward` call must continue until all gpus complete their sequence.
|
# Under synced_gpus the `forward` call must continue until all gpus complete their sequence.
|
||||||
@@ -1620,7 +1636,8 @@ class GenerationMixin:
|
|||||||
An instance of :class:`~transformers.StoppingCriteriaList`. List of instances of class derived from
|
An instance of :class:`~transformers.StoppingCriteriaList`. List of instances of class derived from
|
||||||
:class:`~transformers.StoppingCriteria` used to tell if the generation loop should stop.
|
:class:`~transformers.StoppingCriteria` used to tell if the generation loop should stop.
|
||||||
max_length (:obj:`int`, `optional`, defaults to 20):
|
max_length (:obj:`int`, `optional`, defaults to 20):
|
||||||
The maximum length of the sequence to be generated.
|
**DEPRECATED**. Use :obj:`logits_processor` or :obj:`stopping_criteria` directly to cap the number of
|
||||||
|
generated tokens. The maximum length of the sequence to be generated.
|
||||||
pad_token_id (:obj:`int`, `optional`):
|
pad_token_id (:obj:`int`, `optional`):
|
||||||
The id of the `padding` token.
|
The id of the `padding` token.
|
||||||
eos_token_id (:obj:`int`, `optional`):
|
eos_token_id (:obj:`int`, `optional`):
|
||||||
@@ -1700,8 +1717,14 @@ class GenerationMixin:
|
|||||||
# init values
|
# init values
|
||||||
logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList()
|
logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList()
|
||||||
stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList()
|
stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList()
|
||||||
max_length = max_length if max_length is not None else self.config.max_length
|
if max_length is not None:
|
||||||
validate_stopping_criteria(stopping_criteria, max_length)
|
warnings.warn(
|
||||||
|
"`max_length` is deprecated in this function, use `stopping_criteria=StoppingCriteriaList(MaxLengthCriteria(max_length=max_length))` instead.",
|
||||||
|
UserWarning,
|
||||||
|
)
|
||||||
|
stopping_criteria = validate_stopping_criteria(stopping_criteria, max_length)
|
||||||
|
if len(stopping_criteria) == 0:
|
||||||
|
warnings.warn("You don't have defined any stopping_criteria, this will likely loop forever", UserWarning)
|
||||||
pad_token_id = pad_token_id if pad_token_id is not None else self.config.pad_token_id
|
pad_token_id = pad_token_id if pad_token_id is not None else self.config.pad_token_id
|
||||||
eos_token_id = eos_token_id if eos_token_id is not None else self.config.eos_token_id
|
eos_token_id = eos_token_id if eos_token_id is not None else self.config.eos_token_id
|
||||||
output_scores = output_scores if output_scores is not None else self.config.output_scores
|
output_scores = output_scores if output_scores is not None else self.config.output_scores
|
||||||
@@ -1740,7 +1763,7 @@ class GenerationMixin:
|
|||||||
beam_scores = beam_scores.view((batch_size * num_beams,))
|
beam_scores = beam_scores.view((batch_size * num_beams,))
|
||||||
|
|
||||||
this_peer_finished = False # used by synced_gpus only
|
this_peer_finished = False # used by synced_gpus only
|
||||||
while cur_len < max_length:
|
while True:
|
||||||
|
|
||||||
if synced_gpus:
|
if synced_gpus:
|
||||||
# Under synced_gpus the `forward` call must continue until all gpus complete their sequence.
|
# Under synced_gpus the `forward` call must continue until all gpus complete their sequence.
|
||||||
@@ -1770,7 +1793,7 @@ class GenerationMixin:
|
|||||||
# hack: adjust tokens for Marian. For Marian we have to make sure that the `pad_token_id`
|
# hack: adjust tokens for Marian. For Marian we have to make sure that the `pad_token_id`
|
||||||
# cannot be generated both before and after the `F.log_softmax` operation.
|
# cannot be generated both before and after the `F.log_softmax` operation.
|
||||||
next_token_logits = self.adjust_logits_during_generation(
|
next_token_logits = self.adjust_logits_during_generation(
|
||||||
next_token_logits, cur_len=cur_len, max_length=max_length
|
next_token_logits, cur_len=cur_len, max_length=None
|
||||||
)
|
)
|
||||||
|
|
||||||
next_token_scores = F.log_softmax(next_token_logits, dim=-1) # (batch_size * num_beams, vocab_size)
|
next_token_scores = F.log_softmax(next_token_logits, dim=-1) # (batch_size * num_beams, vocab_size)
|
||||||
@@ -1907,7 +1930,8 @@ class GenerationMixin:
|
|||||||
:class:`~transformers.LogitsWarper` used to warp the prediction score distribution of the language
|
:class:`~transformers.LogitsWarper` used to warp the prediction score distribution of the language
|
||||||
modeling head applied before multinomial sampling at each generation step.
|
modeling head applied before multinomial sampling at each generation step.
|
||||||
max_length (:obj:`int`, `optional`, defaults to 20):
|
max_length (:obj:`int`, `optional`, defaults to 20):
|
||||||
The maximum length of the sequence to be generated.
|
**DEPRECATED**. Use :obj:`logits_processor` or :obj:`stopping_criteria` directly to cap the number of
|
||||||
|
generated tokens. The maximum length of the sequence to be generated.
|
||||||
pad_token_id (:obj:`int`, `optional`):
|
pad_token_id (:obj:`int`, `optional`):
|
||||||
The id of the `padding` token.
|
The id of the `padding` token.
|
||||||
eos_token_id (:obj:`int`, `optional`):
|
eos_token_id (:obj:`int`, `optional`):
|
||||||
@@ -1994,7 +2018,12 @@ class GenerationMixin:
|
|||||||
# init values
|
# init values
|
||||||
logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList()
|
logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList()
|
||||||
stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList()
|
stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList()
|
||||||
max_length = max_length if max_length is not None else self.config.max_length
|
if max_length is not None:
|
||||||
|
warnings.warn(
|
||||||
|
"`max_length` is deprecated in this function, use `stopping_criteria=StoppingCriteriaList(MaxLengthCriteria(max_length=max_length))` instead.",
|
||||||
|
UserWarning,
|
||||||
|
)
|
||||||
|
stopping_criteria = validate_stopping_criteria(stopping_criteria, max_length)
|
||||||
pad_token_id = pad_token_id if pad_token_id is not None else self.config.pad_token_id
|
pad_token_id = pad_token_id if pad_token_id is not None else self.config.pad_token_id
|
||||||
eos_token_id = eos_token_id if eos_token_id is not None else self.config.eos_token_id
|
eos_token_id = eos_token_id if eos_token_id is not None else self.config.eos_token_id
|
||||||
output_scores = output_scores if output_scores is not None else self.config.output_scores
|
output_scores = output_scores if output_scores is not None else self.config.output_scores
|
||||||
@@ -2028,7 +2057,7 @@ class GenerationMixin:
|
|||||||
beam_scores = beam_scores.view((batch_size * num_beams,))
|
beam_scores = beam_scores.view((batch_size * num_beams,))
|
||||||
|
|
||||||
this_peer_finished = False # used by synced_gpus only
|
this_peer_finished = False # used by synced_gpus only
|
||||||
while cur_len < max_length:
|
while True:
|
||||||
|
|
||||||
if synced_gpus:
|
if synced_gpus:
|
||||||
# Under synced_gpus the `forward` call must continue until all gpus complete their sequence.
|
# Under synced_gpus the `forward` call must continue until all gpus complete their sequence.
|
||||||
@@ -2058,7 +2087,7 @@ class GenerationMixin:
|
|||||||
# hack: adjust tokens for Marian. For Marian we have to make sure that the `pad_token_id`
|
# hack: adjust tokens for Marian. For Marian we have to make sure that the `pad_token_id`
|
||||||
# cannot be generated both before and after the `F.log_softmax` operation.
|
# cannot be generated both before and after the `F.log_softmax` operation.
|
||||||
next_token_logits = self.adjust_logits_during_generation(
|
next_token_logits = self.adjust_logits_during_generation(
|
||||||
next_token_logits, cur_len=cur_len, max_length=max_length
|
next_token_logits, cur_len=cur_len, max_length=None
|
||||||
)
|
)
|
||||||
|
|
||||||
next_token_scores = F.log_softmax(next_token_logits, dim=-1) # (batch_size * num_beams, vocab_size)
|
next_token_scores = F.log_softmax(next_token_logits, dim=-1) # (batch_size * num_beams, vocab_size)
|
||||||
@@ -2195,7 +2224,8 @@ class GenerationMixin:
|
|||||||
An instance of :class:`~transformers.StoppingCriteriaList`. List of instances of class derived from
|
An instance of :class:`~transformers.StoppingCriteriaList`. List of instances of class derived from
|
||||||
:class:`~transformers.StoppingCriteria` used to tell if the generation loop should stop.
|
:class:`~transformers.StoppingCriteria` used to tell if the generation loop should stop.
|
||||||
max_length (:obj:`int`, `optional`, defaults to 20):
|
max_length (:obj:`int`, `optional`, defaults to 20):
|
||||||
The maximum length of the sequence to be generated.
|
**DEPRECATED**. Use :obj:`logits_processor` or :obj:`stopping_criteria` directly to cap the number of
|
||||||
|
generated tokens. The maximum length of the sequence to be generated.
|
||||||
pad_token_id (:obj:`int`, `optional`):
|
pad_token_id (:obj:`int`, `optional`):
|
||||||
The id of the `padding` token.
|
The id of the `padding` token.
|
||||||
eos_token_id (:obj:`int`, `optional`):
|
eos_token_id (:obj:`int`, `optional`):
|
||||||
@@ -2279,8 +2309,12 @@ class GenerationMixin:
|
|||||||
# init values
|
# init values
|
||||||
logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList()
|
logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList()
|
||||||
stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList()
|
stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList()
|
||||||
max_length = max_length if max_length is not None else self.config.max_length
|
if max_length is not None:
|
||||||
validate_stopping_criteria(stopping_criteria, max_length)
|
warnings.warn(
|
||||||
|
"`max_length` is deprecated in this function, use `stopping_criteria=StoppingCriteriaList(MaxLengthCriteria(max_length=max_length))` instead.",
|
||||||
|
UserWarning,
|
||||||
|
)
|
||||||
|
stopping_criteria = validate_stopping_criteria(stopping_criteria, max_length)
|
||||||
pad_token_id = pad_token_id if pad_token_id is not None else self.config.pad_token_id
|
pad_token_id = pad_token_id if pad_token_id is not None else self.config.pad_token_id
|
||||||
eos_token_id = eos_token_id if eos_token_id is not None else self.config.eos_token_id
|
eos_token_id = eos_token_id if eos_token_id is not None else self.config.eos_token_id
|
||||||
output_scores = output_scores if output_scores is not None else self.config.output_scores
|
output_scores = output_scores if output_scores is not None else self.config.output_scores
|
||||||
@@ -2324,7 +2358,7 @@ class GenerationMixin:
|
|||||||
beam_scores = beam_scores.view((batch_size * num_beams,))
|
beam_scores = beam_scores.view((batch_size * num_beams,))
|
||||||
|
|
||||||
this_peer_finished = False # used by synced_gpus only
|
this_peer_finished = False # used by synced_gpus only
|
||||||
while cur_len < max_length:
|
while True:
|
||||||
|
|
||||||
if synced_gpus:
|
if synced_gpus:
|
||||||
# Under synced_gpus the `forward` call must continue until all gpus complete their sequence.
|
# Under synced_gpus the `forward` call must continue until all gpus complete their sequence.
|
||||||
@@ -2378,7 +2412,7 @@ class GenerationMixin:
|
|||||||
# hack: adjust tokens for Marian. For Marian we have to make sure that the `pad_token_id`
|
# hack: adjust tokens for Marian. For Marian we have to make sure that the `pad_token_id`
|
||||||
# cannot be generated both before and after the `F.log_softmax` operation.
|
# cannot be generated both before and after the `F.log_softmax` operation.
|
||||||
next_token_logits = self.adjust_logits_during_generation(
|
next_token_logits = self.adjust_logits_during_generation(
|
||||||
next_token_logits, cur_len=cur_len, max_length=max_length
|
next_token_logits, cur_len=cur_len, max_length=None
|
||||||
)
|
)
|
||||||
|
|
||||||
next_token_scores = F.log_softmax(next_token_logits, dim=-1) # (batch_size * group_size, vocab_size)
|
next_token_scores = F.log_softmax(next_token_logits, dim=-1) # (batch_size * group_size, vocab_size)
|
||||||
|
|||||||
@@ -40,10 +40,10 @@ class StoppingCriteriaTestCase(unittest.TestCase):
|
|||||||
|
|
||||||
self.assertFalse(criteria(input_ids, scores))
|
self.assertFalse(criteria(input_ids, scores))
|
||||||
|
|
||||||
input_ids, scores = self._get_tensors(10)
|
input_ids, scores = self._get_tensors(9)
|
||||||
self.assertFalse(criteria(input_ids, scores))
|
self.assertFalse(criteria(input_ids, scores))
|
||||||
|
|
||||||
input_ids, scores = self._get_tensors(11)
|
input_ids, scores = self._get_tensors(10)
|
||||||
self.assertTrue(criteria(input_ids, scores))
|
self.assertTrue(criteria(input_ids, scores))
|
||||||
|
|
||||||
def test_max_length_criteria(self):
|
def test_max_length_criteria(self):
|
||||||
@@ -52,10 +52,10 @@ class StoppingCriteriaTestCase(unittest.TestCase):
|
|||||||
input_ids, scores = self._get_tensors(5)
|
input_ids, scores = self._get_tensors(5)
|
||||||
self.assertFalse(criteria(input_ids, scores))
|
self.assertFalse(criteria(input_ids, scores))
|
||||||
|
|
||||||
input_ids, scores = self._get_tensors(10)
|
input_ids, scores = self._get_tensors(9)
|
||||||
self.assertFalse(criteria(input_ids, scores))
|
self.assertFalse(criteria(input_ids, scores))
|
||||||
|
|
||||||
input_ids, scores = self._get_tensors(11)
|
input_ids, scores = self._get_tensors(10)
|
||||||
self.assertTrue(criteria(input_ids, scores))
|
self.assertTrue(criteria(input_ids, scores))
|
||||||
|
|
||||||
def test_max_time_criteria(self):
|
def test_max_time_criteria(self):
|
||||||
@@ -73,7 +73,6 @@ class StoppingCriteriaTestCase(unittest.TestCase):
|
|||||||
with self.assertWarns(UserWarning):
|
with self.assertWarns(UserWarning):
|
||||||
validate_stopping_criteria(StoppingCriteriaList([MaxLengthCriteria(10)]), 11)
|
validate_stopping_criteria(StoppingCriteriaList([MaxLengthCriteria(10)]), 11)
|
||||||
|
|
||||||
stopping_criteria = StoppingCriteriaList()
|
stopping_criteria = validate_stopping_criteria(StoppingCriteriaList(), 11)
|
||||||
validate_stopping_criteria(stopping_criteria, 11)
|
|
||||||
|
|
||||||
self.assertEqual(len(stopping_criteria), 1)
|
self.assertEqual(len(stopping_criteria), 1)
|
||||||
|
|||||||
@@ -1358,13 +1358,14 @@ class GenerationIntegrationTests(unittest.TestCase):
|
|||||||
bos_token_id=bart_model.config.bos_token_id,
|
bos_token_id=bart_model.config.bos_token_id,
|
||||||
)
|
)
|
||||||
|
|
||||||
bart_model.greedy_search(
|
with self.assertWarns(UserWarning):
|
||||||
input_ids,
|
bart_model.greedy_search(
|
||||||
max_length=max_length,
|
input_ids,
|
||||||
pad_token_id=bart_model.config.pad_token_id,
|
max_length=max_length,
|
||||||
eos_token_id=bart_model.config.eos_token_id,
|
pad_token_id=bart_model.config.pad_token_id,
|
||||||
**model_kwargs,
|
eos_token_id=bart_model.config.eos_token_id,
|
||||||
)
|
**model_kwargs,
|
||||||
|
)
|
||||||
|
|
||||||
def test_max_length_backward_compat_sample(self):
|
def test_max_length_backward_compat_sample(self):
|
||||||
article = """Justin Timberlake and Jessica Biel, welcome to parenthood."""
|
article = """Justin Timberlake and Jessica Biel, welcome to parenthood."""
|
||||||
@@ -1381,13 +1382,14 @@ class GenerationIntegrationTests(unittest.TestCase):
|
|||||||
bos_token_id=bart_model.config.bos_token_id,
|
bos_token_id=bart_model.config.bos_token_id,
|
||||||
)
|
)
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
bart_model.sample(
|
with self.assertWarns(UserWarning):
|
||||||
input_ids,
|
bart_model.sample(
|
||||||
max_length=max_length,
|
input_ids,
|
||||||
pad_token_id=bart_model.config.pad_token_id,
|
max_length=max_length,
|
||||||
eos_token_id=bart_model.config.eos_token_id,
|
pad_token_id=bart_model.config.pad_token_id,
|
||||||
**model_kwargs,
|
eos_token_id=bart_model.config.eos_token_id,
|
||||||
)
|
**model_kwargs,
|
||||||
|
)
|
||||||
|
|
||||||
def test_max_length_backward_compat_beam_search(self):
|
def test_max_length_backward_compat_beam_search(self):
|
||||||
article = """Justin Timberlake and Jessica Biel, welcome to parenthood."""
|
article = """Justin Timberlake and Jessica Biel, welcome to parenthood."""
|
||||||
@@ -1413,9 +1415,10 @@ class GenerationIntegrationTests(unittest.TestCase):
|
|||||||
num_beams=num_beams,
|
num_beams=num_beams,
|
||||||
device=torch_device,
|
device=torch_device,
|
||||||
)
|
)
|
||||||
_ = bart_model.beam_search(
|
with self.assertWarns(UserWarning):
|
||||||
input_ids, num_beams=num_beams, max_length=max_length, beam_scorer=beam_scorer, **model_kwargs
|
_ = bart_model.beam_search(
|
||||||
)
|
input_ids, num_beams=num_beams, max_length=max_length, beam_scorer=beam_scorer, **model_kwargs
|
||||||
|
)
|
||||||
|
|
||||||
def test_max_length_backward_compat_group_beam_search(self):
|
def test_max_length_backward_compat_group_beam_search(self):
|
||||||
article = """Justin Timberlake and Jessica Biel, welcome to parenthood."""
|
article = """Justin Timberlake and Jessica Biel, welcome to parenthood."""
|
||||||
@@ -1445,9 +1448,10 @@ class GenerationIntegrationTests(unittest.TestCase):
|
|||||||
num_beam_hyps_to_keep=num_return_sequences,
|
num_beam_hyps_to_keep=num_return_sequences,
|
||||||
num_beam_groups=num_beam_groups,
|
num_beam_groups=num_beam_groups,
|
||||||
)
|
)
|
||||||
bart_model.group_beam_search(
|
with self.assertWarns(UserWarning):
|
||||||
input_ids, diverse_beam_scorer, num_beams=num_beams, max_length=max_length, **model_kwargs
|
bart_model.group_beam_search(
|
||||||
)
|
input_ids, diverse_beam_scorer, num_beams=num_beams, max_length=max_length, **model_kwargs
|
||||||
|
)
|
||||||
|
|
||||||
def test_max_length_warning_if_different(self):
|
def test_max_length_warning_if_different(self):
|
||||||
article = """Justin Timberlake and Jessica Biel, welcome to parenthood."""
|
article = """Justin Timberlake and Jessica Biel, welcome to parenthood."""
|
||||||
|
|||||||
Reference in New Issue
Block a user