Add custom stop token ids for generation (#20727)
* Add StopIdStoppingCriteria
* add a working test for stop id criteria
* add to global scope
* add stop_ids to generate
* add pipeline test
* use tokenizer encode in test
* add test to generation utils
* reformat
* fixup
* make-fix-copies
* rename to stop_token_id
* use stop_tokens instead
* add to text to text generation
* make fixup
* make repo-consistency
* Add support for list of ints for eos_token_id inside generation/utils.py
* Instead of having if elses, cast the eos_token_id into a List[int]
* Add List[int] support for logits_process.py
* add List[int] for beam_search.py
* add List[int] for forced_eos_token_id
* revert stop token id stopping criteria changes
* make fixup
* fix tests
* add eos_token_id to generation/utils.py and added tests test_utils.py
* add eos_token_id type hints and fix for pad tokens
* add comments
* remove some prints and remove forced false test
* fix
* put back test_stop_sequence_stopping_criteria
* remove unused import and make fixup
* add a none check
* update docstring
* add more docstring for list ints
* make fixup
This commit is contained in:
@@ -16,7 +16,7 @@
|
||||
import warnings
|
||||
from abc import ABC, abstractmethod
|
||||
from collections import UserDict
|
||||
from typing import List, Optional, Tuple
|
||||
from typing import List, Optional, Tuple, Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
@@ -42,8 +42,8 @@ PROCESS_INPUTS_DOCSTRING = r"""
|
||||
Beam indices indicating to which beam hypothesis the `next_tokens` correspond.
|
||||
pad_token_id (`int`, *optional*):
|
||||
The id of the *padding* token.
|
||||
eos_token_id (`int`, *optional*):
|
||||
The id of the *end-of-sequence* token.
|
||||
eos_token_id (`Union[int, List[int]]`, *optional*):
|
||||
The id of the *end-of-sequence* token. Optionally, use a list to set multiple *end-of-sequence* tokens.
|
||||
|
||||
Return:
|
||||
`UserDict`: A dictionary composed of the fields as defined above:
|
||||
@@ -74,8 +74,8 @@ FINALIZE_INPUTS_DOCSTRING = r"""
|
||||
The beam indices indicating to which beam the `final_beam_tokens` shall be added.
|
||||
pad_token_id (`int`, *optional*):
|
||||
The id of the *padding* token.
|
||||
eos_token_id (`int`, *optional*):
|
||||
The id of the *end-of-sequence* token.
|
||||
eos_token_id (`Union[int, List[int]]`, *optional*):
|
||||
The id of the *end-of-sequence* token. Optionally, use a list to set multiple *end-of-sequence* tokens.
|
||||
|
||||
Return:
|
||||
`torch.LongTensor` of shape `(batch_size * num_return_sequences, sequence_length)`: The generated sequences.
|
||||
@@ -212,7 +212,7 @@ class BeamSearchScorer(BeamScorer):
|
||||
next_tokens: torch.LongTensor,
|
||||
next_indices: torch.LongTensor,
|
||||
pad_token_id: Optional[int] = None,
|
||||
eos_token_id: Optional[int] = None,
|
||||
eos_token_id: Optional[Union[int, List[int]]] = None,
|
||||
beam_indices: Optional[torch.LongTensor] = None,
|
||||
) -> Tuple[torch.Tensor]:
|
||||
cur_len = input_ids.shape[-1]
|
||||
@@ -234,6 +234,9 @@ class BeamSearchScorer(BeamScorer):
|
||||
next_beam_tokens = torch.zeros((batch_size, self.group_size), dtype=next_tokens.dtype, device=device)
|
||||
next_beam_indices = torch.zeros((batch_size, self.group_size), dtype=next_indices.dtype, device=device)
|
||||
|
||||
if isinstance(eos_token_id, int):
|
||||
eos_token_id = [eos_token_id]
|
||||
|
||||
for batch_idx, beam_hyp in enumerate(self._beam_hyps):
|
||||
if self._done[batch_idx]:
|
||||
if self.num_beams < len(beam_hyp):
|
||||
@@ -253,7 +256,7 @@ class BeamSearchScorer(BeamScorer):
|
||||
):
|
||||
batch_beam_idx = batch_idx * self.group_size + next_index
|
||||
# add to generated hypotheses if end of sentence
|
||||
if (eos_token_id is not None) and (next_token.item() == eos_token_id):
|
||||
if (eos_token_id is not None) and (next_token.item() in eos_token_id):
|
||||
# if beam_token does not belong to top num_beams tokens, it should not be added
|
||||
is_beam_token_worse_than_top_num_beams = beam_token_rank >= self.group_size
|
||||
if is_beam_token_worse_than_top_num_beams:
|
||||
@@ -307,11 +310,14 @@ class BeamSearchScorer(BeamScorer):
|
||||
final_beam_indices: torch.LongTensor,
|
||||
max_length: int,
|
||||
pad_token_id: Optional[int] = None,
|
||||
eos_token_id: Optional[int] = None,
|
||||
eos_token_id: Optional[Union[int, List[int]]] = None,
|
||||
beam_indices: Optional[torch.LongTensor] = None,
|
||||
) -> Tuple[torch.LongTensor]:
|
||||
batch_size = len(self._beam_hyps)
|
||||
|
||||
if isinstance(eos_token_id, int):
|
||||
eos_token_id = [eos_token_id]
|
||||
|
||||
# finalize all open beam hypotheses and add to generated hypotheses
|
||||
for batch_idx, beam_hyp in enumerate(self._beam_hyps):
|
||||
if self._done[batch_idx]:
|
||||
@@ -376,7 +382,8 @@ class BeamSearchScorer(BeamScorer):
|
||||
indices[i, : len(best_idx)] = torch.tensor(best_idx)
|
||||
|
||||
if sent_lengths[i] < sent_max_len:
|
||||
decoded[i, sent_lengths[i]] = eos_token_id
|
||||
# inserting only the first eos_token_id
|
||||
decoded[i, sent_lengths[i]] = eos_token_id[0]
|
||||
|
||||
return UserDict(
|
||||
{
|
||||
@@ -491,7 +498,7 @@ class ConstrainedBeamSearchScorer(BeamScorer):
|
||||
next_indices: torch.LongTensor,
|
||||
scores_for_all_vocab: torch.FloatTensor,
|
||||
pad_token_id: Optional[int] = None,
|
||||
eos_token_id: Optional[int] = None,
|
||||
eos_token_id: Optional[Union[int, List[int]]] = None,
|
||||
) -> Tuple[torch.Tensor]:
|
||||
r"""
|
||||
Args:
|
||||
@@ -512,8 +519,8 @@ class ConstrainedBeamSearchScorer(BeamScorer):
|
||||
The scores of all tokens in the vocabulary for each of the beam hypotheses.
|
||||
pad_token_id (`int`, *optional*):
|
||||
The id of the *padding* token.
|
||||
eos_token_id (`int`, *optional*):
|
||||
The id of the *end-of-sequence* token.
|
||||
eos_token_id (`Union[int, List[int]]`, *optional*):
|
||||
The id of the *end-of-sequence* token. Optionally, use a list to set multiple *end-of-sequence* tokens.
|
||||
|
||||
Return:
|
||||
`UserDict`: A dictionary composed of the fields as defined above:
|
||||
@@ -549,6 +556,9 @@ class ConstrainedBeamSearchScorer(BeamScorer):
|
||||
next_beam_tokens = torch.zeros((batch_size, self.group_size), dtype=next_tokens.dtype, device=device)
|
||||
next_beam_indices = torch.zeros((batch_size, self.group_size), dtype=next_indices.dtype, device=device)
|
||||
|
||||
if isinstance(eos_token_id, int):
|
||||
eos_token_id = [eos_token_id]
|
||||
|
||||
for batch_idx, beam_hyp in enumerate(self._beam_hyps):
|
||||
if self._done[batch_idx]:
|
||||
if self.num_beams < len(beam_hyp):
|
||||
@@ -568,7 +578,7 @@ class ConstrainedBeamSearchScorer(BeamScorer):
|
||||
):
|
||||
batch_beam_idx = batch_idx * self.group_size + next_index
|
||||
# add to generated hypotheses if end of sentence
|
||||
if (eos_token_id is not None) and (next_token.item() == eos_token_id):
|
||||
if (eos_token_id is not None) and (next_token.item() in eos_token_id):
|
||||
|
||||
# if beam_token does not belong to top num_beams tokens, it should not be added
|
||||
is_beam_token_worse_than_top_num_beams = beam_token_rank >= self.group_size
|
||||
@@ -773,10 +783,13 @@ class ConstrainedBeamSearchScorer(BeamScorer):
|
||||
final_beam_indices: torch.LongTensor,
|
||||
max_length: int,
|
||||
pad_token_id: Optional[int] = None,
|
||||
eos_token_id: Optional[int] = None,
|
||||
eos_token_id: Optional[Union[int, List[int]]] = None,
|
||||
) -> Tuple[torch.LongTensor]:
|
||||
batch_size = len(self._beam_hyps)
|
||||
|
||||
if isinstance(eos_token_id, int):
|
||||
eos_token_id = [eos_token_id]
|
||||
|
||||
# finalize all open beam hypotheses and add to generated hypotheses
|
||||
for batch_idx, beam_hyp in enumerate(self._beam_hyps):
|
||||
if self._done[batch_idx]:
|
||||
@@ -840,7 +853,8 @@ class ConstrainedBeamSearchScorer(BeamScorer):
|
||||
for i, hypo in enumerate(best):
|
||||
decoded[i, : sent_lengths[i]] = hypo
|
||||
if sent_lengths[i] < sent_max_len:
|
||||
decoded[i, sent_lengths[i]] = eos_token_id
|
||||
# inserting only the first eos_token_id
|
||||
decoded[i, sent_lengths[i]] = eos_token_id[0]
|
||||
|
||||
return UserDict(
|
||||
{
|
||||
|
||||
@@ -142,8 +142,9 @@ class GenerationConfig(PushToHubMixin):
|
||||
The id of the token to force as the first generated token after the `decoder_start_token_id`. Useful for
|
||||
multilingual models like [mBART](../model_doc/mbart) where the first generated token needs to be the target
|
||||
language token.
|
||||
forced_eos_token_id (`int`, *optional*, defaults to `model.config.forced_eos_token_id`):
|
||||
The id of the token to force as the last generated token when `max_length` is reached.
|
||||
forced_eos_token_id (`Union[int, List[int]]`, *optional*, defaults to `model.config.forced_eos_token_id`):
|
||||
The id of the token to force as the last generated token when `max_length` is reached. Optionally, use a
|
||||
list to set multiple *end-of-sequence* tokens.
|
||||
remove_invalid_values (`bool`, *optional*, defaults to `model.config.remove_invalid_values`):
|
||||
Whether to remove possible *nan* and *inf* outputs of the model to prevent the generation method to crash.
|
||||
Note that using `remove_invalid_values` can slow down generation.
|
||||
@@ -152,10 +153,10 @@ class GenerationConfig(PushToHubMixin):
|
||||
generated. The tuple shall consist of: `(start_index, decay_factor)` where `start_index` indicates where
|
||||
penalty starts and `decay_factor` represents the factor of exponential decay
|
||||
suppress_tokens (`List[int]`, *optional*):
|
||||
A list of tokens that will be supressed at generation. The `SupressTokens` logit processor will set their
|
||||
A list of tokens that will be suppressed at generation. The `SuppressTokensLogitsProcessor` logit processor will set their
|
||||
log probs to `-inf` so that they are not sampled.
|
||||
begin_suppress_tokens (`List[int]`, *optional*):
|
||||
A list of tokens that will be supressed at the begining of the generation. The `SupressBeginTokens` logit
|
||||
A list of tokens that will be suppressed at the beginning of the generation. The `SuppressTokensAtBeginLogitsProcessor` logit
|
||||
processor will set their log probs to `-inf` so that they are not sampled.
|
||||
forced_decoder_ids (`List[List[int]]`, *optional*):
|
||||
A list of pairs of integers which indicates a mapping from generation indices to token indices that will be
|
||||
@@ -183,8 +184,8 @@ class GenerationConfig(PushToHubMixin):
|
||||
The id of the *padding* token.
|
||||
bos_token_id (`int`, *optional*):
|
||||
The id of the *beginning-of-sequence* token.
|
||||
eos_token_id (`int`, *optional*):
|
||||
The id of the *end-of-sequence* token.
|
||||
eos_token_id (`Union[int, List[int]]`, *optional*):
|
||||
The id of the *end-of-sequence* token. Optionally, use a list to set multiple *end-of-sequence* tokens.
|
||||
|
||||
> Generation parameters exclusive to encoder-decoder models
|
||||
|
||||
|
||||
@@ -15,7 +15,7 @@
|
||||
|
||||
import inspect
|
||||
import math
|
||||
from typing import Callable, Iterable, List, Optional, Tuple
|
||||
from typing import Callable, Iterable, List, Optional, Tuple, Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
@@ -100,16 +100,18 @@ class MinLengthLogitsProcessor(LogitsProcessor):
|
||||
Args:
|
||||
min_length (`int`):
|
||||
The minimum length below which the score of `eos_token_id` is set to `-float("Inf")`.
|
||||
eos_token_id (`int`):
|
||||
The id of the *end-of-sequence* token.
|
||||
eos_token_id (`Union[int, List[int]]`):
|
||||
The id of the *end-of-sequence* token. Optionally, use a list to set multiple *end-of-sequence* tokens.
|
||||
"""
|
||||
|
||||
def __init__(self, min_length: int, eos_token_id: int):
|
||||
def __init__(self, min_length: int, eos_token_id: Union[int, List[int]]):
|
||||
if not isinstance(min_length, int) or min_length < 0:
|
||||
raise ValueError(f"`min_length` has to be a positive integer, but is {min_length}")
|
||||
|
||||
if not isinstance(eos_token_id, int) or eos_token_id < 0:
|
||||
raise ValueError(f"`eos_token_id` has to be a positive integer, but is {eos_token_id}")
|
||||
if isinstance(eos_token_id, int):
|
||||
eos_token_id = [eos_token_id]
|
||||
if not all([isinstance(i, int) for i in eos_token_id]) or any([i < 0 for i in eos_token_id]):
|
||||
raise ValueError(f"`eos_token_id` has to be a list of positive integers, but is {eos_token_id}")
|
||||
|
||||
self.min_length = min_length
|
||||
self.eos_token_id = eos_token_id
|
||||
@@ -117,7 +119,8 @@ class MinLengthLogitsProcessor(LogitsProcessor):
|
||||
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
|
||||
cur_len = input_ids.shape[-1]
|
||||
if cur_len < self.min_length:
|
||||
scores[:, self.eos_token_id] = -float("inf")
|
||||
for i in self.eos_token_id:
|
||||
scores[:, i] = -float("inf")
|
||||
return scores
|
||||
|
||||
|
||||
@@ -431,11 +434,11 @@ class NoBadWordsLogitsProcessor(LogitsProcessor):
|
||||
List of list of token ids that are not allowed to be generated. In order to get the token ids of the words
|
||||
that should not appear in the generated text, use `tokenizer(bad_words, add_prefix_space=True,
|
||||
add_special_tokens=False).input_ids`.
|
||||
eos_token_id (`int`):
|
||||
The id of the *end-of-sequence* token.
|
||||
eos_token_id (`Union[int, List[int]]`):
|
||||
The id of the *end-of-sequence* token. Optionally, use a list to set multiple *end-of-sequence* tokens.
|
||||
"""
|
||||
|
||||
def __init__(self, bad_words_ids: List[List[int]], eos_token_id: int):
|
||||
def __init__(self, bad_words_ids: List[List[int]], eos_token_id: Union[int, List[int]]):
|
||||
|
||||
if not isinstance(bad_words_ids, List) or len(bad_words_ids) == 0:
|
||||
raise ValueError(f"`bad_words_ids` has to be a non-empty list, but is {bad_words_ids}.")
|
||||
@@ -449,7 +452,14 @@ class NoBadWordsLogitsProcessor(LogitsProcessor):
|
||||
f"Each list in `bad_words_ids` has to be a list of positive integers, but is {bad_words_ids}."
|
||||
)
|
||||
|
||||
bad_words_ids = list(filter(lambda bad_token_seq: bad_token_seq != [eos_token_id], bad_words_ids))
|
||||
if eos_token_id is None:
|
||||
eos_token_id = []
|
||||
if isinstance(eos_token_id, int):
|
||||
eos_token_id = [eos_token_id]
|
||||
|
||||
bad_words_ids = list(
|
||||
filter(lambda bad_token_seq: all([bad_token_seq != [i] for i in eos_token_id]), bad_words_ids)
|
||||
)
|
||||
self.bad_words_id_length_1 = []
|
||||
self.bad_words_id_length_greater_than_1 = []
|
||||
for word in bad_words_ids:
|
||||
@@ -664,20 +674,24 @@ class ForcedEOSTokenLogitsProcessor(LogitsProcessor):
|
||||
Args:
|
||||
max_length (`int`):
|
||||
The maximum length of the sequence to be generated.
|
||||
eos_token_id (`int`):
|
||||
The id of the token to force as the last generated token when `max_length` is reached.
|
||||
eos_token_id (`Union[int, List[int]]`):
|
||||
The id of the token to force as the last generated token when `max_length` is reached. Optionally, use a
|
||||
list to set multiple *end-of-sequence* tokens.
|
||||
"""
|
||||
|
||||
def __init__(self, max_length: int, eos_token_id: int):
|
||||
def __init__(self, max_length: int, eos_token_id: Union[int, List[int]]):
|
||||
self.max_length = max_length
|
||||
if isinstance(eos_token_id, int):
|
||||
eos_token_id = [eos_token_id]
|
||||
self.eos_token_id = eos_token_id
|
||||
|
||||
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
|
||||
cur_len = input_ids.shape[-1]
|
||||
if cur_len == self.max_length - 1:
|
||||
num_tokens = scores.shape[1]
|
||||
scores[:, [i for i in range(num_tokens) if i != self.eos_token_id]] = -float("inf")
|
||||
scores[:, self.eos_token_id] = 0
|
||||
scores[:, [i for i in range(num_tokens) if i not in self.eos_token_id]] = -float("inf")
|
||||
for i in self.eos_token_id:
|
||||
scores[:, i] = 0
|
||||
return scores
|
||||
|
||||
|
||||
@@ -707,23 +721,26 @@ class ExponentialDecayLengthPenalty(LogitsProcessor):
|
||||
exponential_decay_length_penalty (`tuple(int, float)`, *optional*):
|
||||
This tuple shall consist of: `(start_index, decay_factor)` where `start_index` indicates where penalty
|
||||
starts and `decay_factor` represents the factor of exponential decay
|
||||
eos_token_id (`int`):
|
||||
The id of the *end-of-sequence* token.
|
||||
eos_token_id (`Union[int, List[int]]`):
|
||||
The id of the *end-of-sequence* token. Optionally, use a list to set multiple *end-of-sequence* tokens.
|
||||
input_ids_seq_length (`int`):
|
||||
The length of the input sequence.
|
||||
"""
|
||||
|
||||
def __init__(self, exponential_decay_length_penalty: Tuple, eos_token_id: int, input_ids_seq_length: int):
|
||||
def __init__(
|
||||
self, exponential_decay_length_penalty: Tuple, eos_token_id: Union[int, List[int]], input_ids_seq_length: int
|
||||
):
|
||||
self.regulation_start = exponential_decay_length_penalty[0] + input_ids_seq_length
|
||||
self.regulation_factor = exponential_decay_length_penalty[1]
|
||||
if isinstance(eos_token_id, int):
|
||||
eos_token_id = [eos_token_id]
|
||||
self.eos_token_id = eos_token_id
|
||||
|
||||
def __call__(self, input_ids: torch.Tensor, scores: torch.Tensor) -> torch.FloatTensor:
|
||||
cur_len = input_ids.shape[-1]
|
||||
if cur_len > self.regulation_start:
|
||||
scores[:, self.eos_token_id] = scores[:, self.eos_token_id] * pow(
|
||||
self.regulation_factor, cur_len - self.regulation_start
|
||||
)
|
||||
for i in self.eos_token_id:
|
||||
scores[:, i] = scores[:, i] * pow(self.regulation_factor, cur_len - self.regulation_start)
|
||||
return scores
|
||||
|
||||
|
||||
|
||||
@@ -575,11 +575,13 @@ class GenerationMixin:
|
||||
self,
|
||||
inputs: torch.Tensor,
|
||||
pad_token_id: Optional[int],
|
||||
eos_token_id: Optional[int],
|
||||
eos_token_id: Optional[Union[int, List[int]]],
|
||||
) -> torch.LongTensor:
|
||||
is_input_ids = len(inputs.shape) == 2 and inputs.dtype in [torch.int, torch.long]
|
||||
is_pad_token_in_inputs = (pad_token_id is not None) and (pad_token_id in inputs)
|
||||
is_pad_token_not_equal_to_eos_token_id = (eos_token_id is None) or (pad_token_id != eos_token_id)
|
||||
if isinstance(eos_token_id, int):
|
||||
eos_token_id = [eos_token_id]
|
||||
is_pad_token_not_equal_to_eos_token_id = (eos_token_id is None) or (pad_token_id not in eos_token_id)
|
||||
|
||||
# Check if input is input_ids and padded -> only then is attention_mask defined
|
||||
if is_input_ids and is_pad_token_in_inputs and is_pad_token_not_equal_to_eos_token_id:
|
||||
@@ -902,7 +904,7 @@ class GenerationMixin:
|
||||
sequences: torch.Tensor,
|
||||
scores: Tuple[torch.Tensor],
|
||||
beam_indices: torch.Tensor,
|
||||
eos_token_id: int = None,
|
||||
eos_token_id: Union[int, List[int]] = None,
|
||||
):
|
||||
"""compute the transition probabilities of sequences given generation
|
||||
scores and beam indices"""
|
||||
@@ -1165,10 +1167,11 @@ class GenerationMixin:
|
||||
"The attention mask and the pad token id were not set. As a consequence, you may observe "
|
||||
"unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results."
|
||||
)
|
||||
logger.warning(
|
||||
f"Setting `pad_token_id` to `eos_token_id`:{generation_config.eos_token_id} for open-end generation."
|
||||
)
|
||||
generation_config.pad_token_id = generation_config.eos_token_id
|
||||
eos_token_id = generation_config.eos_token_id
|
||||
if isinstance(eos_token_id, list):
|
||||
eos_token_id = eos_token_id[0]
|
||||
logger.warning(f"Setting `pad_token_id` to `eos_token_id`:{eos_token_id} for open-end generation.")
|
||||
generation_config.pad_token_id = eos_token_id
|
||||
|
||||
# 3. Define model inputs
|
||||
# inputs_tensor has to be defined
|
||||
@@ -1624,7 +1627,7 @@ class GenerationMixin:
|
||||
logits_warper: Optional[LogitsProcessorList] = None,
|
||||
stopping_criteria: Optional[StoppingCriteriaList] = None,
|
||||
pad_token_id: Optional[int] = None,
|
||||
eos_token_id: Optional[int] = None,
|
||||
eos_token_id: Optional[Union[int, List[int]]] = None,
|
||||
output_attentions: Optional[bool] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
output_scores: Optional[bool] = None,
|
||||
@@ -1708,6 +1711,8 @@ class GenerationMixin:
|
||||
stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList()
|
||||
pad_token_id = pad_token_id if pad_token_id is not None else self.generation_config.pad_token_id
|
||||
eos_token_id = eos_token_id if eos_token_id is not None else self.generation_config.eos_token_id
|
||||
if isinstance(eos_token_id, int):
|
||||
eos_token_id = [eos_token_id]
|
||||
output_scores = output_scores if output_scores is not None else self.generation_config.output_scores
|
||||
output_attentions = (
|
||||
output_attentions if output_attentions is not None else self.generation_config.output_attentions
|
||||
@@ -1930,7 +1935,7 @@ class GenerationMixin:
|
||||
|
||||
# if eos_token was found in one sentence, set sentence to finished
|
||||
if eos_token_id is not None:
|
||||
unfinished_sequences = unfinished_sequences.mul((next_tokens != eos_token_id).long())
|
||||
unfinished_sequences = unfinished_sequences.mul((sum(next_tokens != i for i in eos_token_id)).long())
|
||||
|
||||
# stop when each sentence is finished, or if we exceed the maximum length
|
||||
if unfinished_sequences.max() == 0 or stopping_criteria(input_ids, scores):
|
||||
@@ -1967,7 +1972,7 @@ class GenerationMixin:
|
||||
stopping_criteria: Optional[StoppingCriteriaList] = None,
|
||||
max_length: Optional[int] = None,
|
||||
pad_token_id: Optional[int] = None,
|
||||
eos_token_id: Optional[int] = None,
|
||||
eos_token_id: Optional[Union[int, List[int]]] = None,
|
||||
output_attentions: Optional[bool] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
output_scores: Optional[bool] = None,
|
||||
@@ -2067,6 +2072,8 @@ class GenerationMixin:
|
||||
stopping_criteria = validate_stopping_criteria(stopping_criteria, max_length)
|
||||
pad_token_id = pad_token_id if pad_token_id is not None else self.generation_config.pad_token_id
|
||||
eos_token_id = eos_token_id if eos_token_id is not None else self.generation_config.eos_token_id
|
||||
if isinstance(eos_token_id, int):
|
||||
eos_token_id = [eos_token_id]
|
||||
output_scores = output_scores if output_scores is not None else self.generation_config.output_scores
|
||||
output_attentions = (
|
||||
output_attentions if output_attentions is not None else self.generation_config.output_attentions
|
||||
@@ -2162,7 +2169,7 @@ class GenerationMixin:
|
||||
|
||||
# if eos_token was found in one sentence, set sentence to finished
|
||||
if eos_token_id is not None:
|
||||
unfinished_sequences = unfinished_sequences.mul((next_tokens != eos_token_id).long())
|
||||
unfinished_sequences = unfinished_sequences.mul((sum(next_tokens != i for i in eos_token_id)).long())
|
||||
|
||||
# stop when each sentence is finished, or if we exceed the maximum length
|
||||
if unfinished_sequences.max() == 0 or stopping_criteria(input_ids, scores):
|
||||
@@ -2200,7 +2207,7 @@ class GenerationMixin:
|
||||
logits_warper: Optional[LogitsProcessorList] = None,
|
||||
max_length: Optional[int] = None,
|
||||
pad_token_id: Optional[int] = None,
|
||||
eos_token_id: Optional[int] = None,
|
||||
eos_token_id: Optional[Union[int, List[int]]] = None,
|
||||
output_attentions: Optional[bool] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
output_scores: Optional[bool] = None,
|
||||
@@ -2320,6 +2327,8 @@ class GenerationMixin:
|
||||
logits_warper = logits_warper if logits_warper is not None else LogitsProcessorList()
|
||||
pad_token_id = pad_token_id if pad_token_id is not None else self.generation_config.pad_token_id
|
||||
eos_token_id = eos_token_id if eos_token_id is not None else self.generation_config.eos_token_id
|
||||
if isinstance(eos_token_id, int):
|
||||
eos_token_id = [eos_token_id]
|
||||
output_scores = output_scores if output_scores is not None else self.generation_config.output_scores
|
||||
output_attentions = (
|
||||
output_attentions if output_attentions is not None else self.generation_config.output_attentions
|
||||
@@ -2418,7 +2427,7 @@ class GenerationMixin:
|
||||
|
||||
# if eos_token was found in one sentence, set sentence to finished
|
||||
if eos_token_id is not None:
|
||||
unfinished_sequences = unfinished_sequences.mul((next_tokens != eos_token_id).long())
|
||||
unfinished_sequences = unfinished_sequences.mul((sum(next_tokens != i for i in eos_token_id)).long())
|
||||
|
||||
# stop when each sentence is finished, or if we exceed the maximum length
|
||||
if unfinished_sequences.max() == 0 or stopping_criteria(input_ids, scores):
|
||||
@@ -2456,7 +2465,7 @@ class GenerationMixin:
|
||||
stopping_criteria: Optional[StoppingCriteriaList] = None,
|
||||
max_length: Optional[int] = None,
|
||||
pad_token_id: Optional[int] = None,
|
||||
eos_token_id: Optional[int] = None,
|
||||
eos_token_id: Optional[Union[int, List[int]]] = None,
|
||||
output_attentions: Optional[bool] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
output_scores: Optional[bool] = None,
|
||||
@@ -2576,6 +2585,8 @@ class GenerationMixin:
|
||||
warnings.warn("You don't have defined any stopping_criteria, this will likely loop forever", UserWarning)
|
||||
pad_token_id = pad_token_id if pad_token_id is not None else self.generation_config.pad_token_id
|
||||
eos_token_id = eos_token_id if eos_token_id is not None else self.generation_config.eos_token_id
|
||||
if isinstance(eos_token_id, int):
|
||||
eos_token_id = [eos_token_id]
|
||||
output_scores = output_scores if output_scores is not None else self.generation_config.output_scores
|
||||
output_attentions = (
|
||||
output_attentions if output_attentions is not None else self.generation_config.output_attentions
|
||||
@@ -2770,7 +2781,7 @@ class GenerationMixin:
|
||||
logits_warper: Optional[LogitsProcessorList] = None,
|
||||
max_length: Optional[int] = None,
|
||||
pad_token_id: Optional[int] = None,
|
||||
eos_token_id: Optional[int] = None,
|
||||
eos_token_id: Optional[Union[int, List[int]]] = None,
|
||||
output_attentions: Optional[bool] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
output_scores: Optional[bool] = None,
|
||||
@@ -2900,6 +2911,8 @@ class GenerationMixin:
|
||||
stopping_criteria = validate_stopping_criteria(stopping_criteria, max_length)
|
||||
pad_token_id = pad_token_id if pad_token_id is not None else self.generation_config.pad_token_id
|
||||
eos_token_id = eos_token_id if eos_token_id is not None else self.generation_config.eos_token_id
|
||||
if isinstance(eos_token_id, int):
|
||||
eos_token_id = [eos_token_id]
|
||||
output_scores = output_scores if output_scores is not None else self.generation_config.output_scores
|
||||
output_attentions = (
|
||||
output_attentions if output_attentions is not None else self.generation_config.output_attentions
|
||||
@@ -3089,7 +3102,7 @@ class GenerationMixin:
|
||||
stopping_criteria: Optional[StoppingCriteriaList] = None,
|
||||
max_length: Optional[int] = None,
|
||||
pad_token_id: Optional[int] = None,
|
||||
eos_token_id: Optional[int] = None,
|
||||
eos_token_id: Optional[Union[int, List[int]]] = None,
|
||||
output_attentions: Optional[bool] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
output_scores: Optional[bool] = None,
|
||||
@@ -3213,6 +3226,8 @@ class GenerationMixin:
|
||||
stopping_criteria = validate_stopping_criteria(stopping_criteria, max_length)
|
||||
pad_token_id = pad_token_id if pad_token_id is not None else self.generation_config.pad_token_id
|
||||
eos_token_id = eos_token_id if eos_token_id is not None else self.generation_config.eos_token_id
|
||||
if isinstance(eos_token_id, int):
|
||||
eos_token_id = [eos_token_id]
|
||||
output_scores = output_scores if output_scores is not None else self.generation_config.output_scores
|
||||
output_attentions = (
|
||||
output_attentions if output_attentions is not None else self.generation_config.output_attentions
|
||||
@@ -3455,7 +3470,7 @@ class GenerationMixin:
|
||||
stopping_criteria: Optional[StoppingCriteriaList] = None,
|
||||
max_length: Optional[int] = None,
|
||||
pad_token_id: Optional[int] = None,
|
||||
eos_token_id: Optional[int] = None,
|
||||
eos_token_id: Optional[Union[int, List[int]]] = None,
|
||||
output_attentions: Optional[bool] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
output_scores: Optional[bool] = None,
|
||||
@@ -3586,6 +3601,8 @@ class GenerationMixin:
|
||||
warnings.warn("You don't have defined any stopping_criteria, this will likely loop forever", UserWarning)
|
||||
pad_token_id = pad_token_id if pad_token_id is not None else self.generation_config.pad_token_id
|
||||
eos_token_id = eos_token_id if eos_token_id is not None else self.generation_config.eos_token_id
|
||||
if isinstance(eos_token_id, int):
|
||||
eos_token_id = [eos_token_id]
|
||||
output_scores = output_scores if output_scores is not None else self.generation_config.output_scores
|
||||
output_attentions = (
|
||||
output_attentions if output_attentions is not None else self.generation_config.output_attentions
|
||||
|
||||
@@ -17,7 +17,7 @@
|
||||
import inspect
|
||||
import unittest
|
||||
|
||||
from transformers import is_torch_available
|
||||
from transformers import is_torch_available, pipeline
|
||||
from transformers.testing_utils import require_torch, slow, torch_device
|
||||
|
||||
from ..test_modeling_common import floats_tensor, ids_tensor
|
||||
@@ -39,7 +39,6 @@ if is_torch_available():
|
||||
SpeechEncoderDecoderModel,
|
||||
T5ForConditionalGeneration,
|
||||
VisionEncoderDecoderModel,
|
||||
pipeline,
|
||||
top_k_top_p_filtering,
|
||||
)
|
||||
from transformers.generation import (
|
||||
@@ -91,8 +90,9 @@ class GenerationTesterMixin:
|
||||
max_length = input_ids.shape[-1] + 3
|
||||
if config.eos_token_id is not None and config.pad_token_id is None:
|
||||
# hack to allow generate for models such as GPT2 as is done in `generate()`
|
||||
config.pad_token_id = config.eos_token_id
|
||||
|
||||
if isinstance(config.eos_token_id, int):
|
||||
config.eos_token_id = [config.eos_token_id]
|
||||
config.pad_token_id = config.eos_token_id[0]
|
||||
# TransfoXL has no attention mask
|
||||
if "transfoxl" in config.__class__.__name__.lower():
|
||||
attention_mask = None
|
||||
@@ -3025,3 +3025,100 @@ class GenerationIntegrationTests(unittest.TestCase):
|
||||
# However, valid model_kwargs are accepted
|
||||
valid_model_kwargs = {"attention_mask": torch.zeros_like(input_ids)}
|
||||
model.generate(input_ids, **valid_model_kwargs)
|
||||
|
||||
def test_eos_token_id_int_and_list_greedy_search(self):
    """Greedy search stops at the same length for an int `eos_token_id` and its one-element list form."""
    generation_kwargs = {
        "do_sample": False,
        "num_beams": 1,
    }
    expectation = 13

    tokenizer = GPT2Tokenizer.from_pretrained("hf-internal-testing/tiny-random-gpt2")
    text = """Hello, my dog is cute and"""
    # The model is moved to `torch_device` below, so the encoded inputs are moved there too.
    tokens = tokenizer(text, return_tensors="pt").to(torch_device)

    model = GPT2LMHeadModel.from_pretrained("hf-internal-testing/tiny-random-gpt2").to(torch_device)

    # Same token id, once as a bare int and once wrapped in a list.
    for eos_token_id in (873, [873]):
        with self.subTest(eos_token_id=eos_token_id):
            # Re-seed before every call so both variants start from the same RNG state.
            torch.manual_seed(0)
            generated_tokens = model.generate(**tokens, eos_token_id=eos_token_id, **generation_kwargs)
            # assertEqual reports both values on failure, unlike assertTrue(a == b).
            self.assertEqual(expectation, len(generated_tokens[0]))
|
||||
|
||||
def test_eos_token_id_int_and_list_contrastive_search(self):
    """Contrastive search stops at the same length for an int `eos_token_id` and its one-element list form."""
    generation_kwargs = {
        "do_sample": False,
        "num_beams": 1,
        "penalty_alpha": 0.6,
        "top_k": 4,
    }
    expectation = 17

    tokenizer = GPT2Tokenizer.from_pretrained("hf-internal-testing/tiny-random-gpt2")
    text = """Hello, my dog is cute and"""
    # The model is moved to `torch_device` below, so the encoded inputs are moved there too.
    tokens = tokenizer(text, return_tensors="pt").to(torch_device)

    model = GPT2LMHeadModel.from_pretrained("hf-internal-testing/tiny-random-gpt2").to(torch_device)

    # Same token id, once as a bare int and once wrapped in a list.
    for eos_token_id in (225, [225]):
        with self.subTest(eos_token_id=eos_token_id):
            # Re-seed before every call so both variants start from the same RNG state.
            torch.manual_seed(0)
            generated_tokens = model.generate(**tokens, eos_token_id=eos_token_id, **generation_kwargs)
            # assertEqual reports both values on failure, unlike assertTrue(a == b).
            self.assertEqual(expectation, len(generated_tokens[0]))
|
||||
|
||||
def test_eos_token_id_int_and_list_top_k_top_sampling(self):
    """Top-k/top-p sampling stops at the same length for an int `eos_token_id` and its one-element list form."""
    generation_kwargs = {
        "do_sample": True,
        "num_beams": 1,
        "top_p": 0.7,
        "top_k": 10,
        "temperature": 0.7,
    }
    expectation = 15

    tokenizer = GPT2Tokenizer.from_pretrained("hf-internal-testing/tiny-random-gpt2")
    text = """Hello, my dog is cute and"""
    # The model is moved to `torch_device` below, so the encoded inputs are moved there too.
    tokens = tokenizer(text, return_tensors="pt").to(torch_device)

    model = GPT2LMHeadModel.from_pretrained("hf-internal-testing/tiny-random-gpt2").to(torch_device)

    # Same token id, once as a bare int and once wrapped in a list.
    for eos_token_id in (846, [846]):
        with self.subTest(eos_token_id=eos_token_id):
            # Sampling is stochastic: re-seed so both variants draw the same random numbers.
            torch.manual_seed(0)
            generated_tokens = model.generate(**tokens, eos_token_id=eos_token_id, **generation_kwargs)
            # assertEqual reports both values on failure, unlike assertTrue(a == b).
            self.assertEqual(expectation, len(generated_tokens[0]))
|
||||
|
||||
def test_eos_token_id_int_and_list_beam_search(self):
    """Beam search stops at the same length for an int `eos_token_id` and its one-element list form."""
    generation_kwargs = {
        "do_sample": False,
        "num_beams": 3,
    }
    expectation = 13

    tokenizer = GPT2Tokenizer.from_pretrained("hf-internal-testing/tiny-random-gpt2")
    text = """Hello, my dog is cute and"""
    # The model is moved to `torch_device` below, so the encoded inputs are moved there too.
    tokens = tokenizer(text, return_tensors="pt").to(torch_device)

    model = GPT2LMHeadModel.from_pretrained("hf-internal-testing/tiny-random-gpt2").to(torch_device)

    # Same token id, once as a bare int and once wrapped in a list.
    for eos_token_id in (873, [873]):
        with self.subTest(eos_token_id=eos_token_id):
            # Re-seed before every call so both variants start from the same RNG state.
            torch.manual_seed(0)
            generated_tokens = model.generate(**tokens, eos_token_id=eos_token_id, **generation_kwargs)
            # assertEqual reports both values on failure, unlike assertTrue(a == b).
            self.assertEqual(expectation, len(generated_tokens[0]))
|
||||
|
||||
Reference in New Issue
Block a user