Add custom stop token ids for generation (#20727)
* Add StopIdStoppingCriteria
* add a working test for stop id criteria
* add to global scope
* add stop_ids to generate
* add pipeline test
* use tokenizer encode in test
* add test to generation utils
* reformat
* fixup
* make-fix-copies
* rename to stop_token_id
* use stop_tokens instead
* add to text to text generation
* make fixup
* make repo-consistency
* Add support for list of ints for eos_token_id inside generation/utils.py
* Instead of having if elses, cast the eos_token_id into a List[int]
* Add List[int] support for logits_process.py
* add List[int] for beam_search.py
* add List[int] for forced_eos_token_id
* revert stop token id stopping criteria changes
* make fixup
* fix tests
* add eos_token_id to generation/utils.py and added tests test_utils.py
* add eos_token_id type hints and fix for pad tokens
* add comments
* remove some prints and remove forced false test
* fix
* put back test_stop_sequence_stopping_criteria
* remove unused import and make fixup
* add a none check
* update docstring
* add more docstring for list ints
* make fixup
This commit is contained in:
@@ -16,7 +16,7 @@
|
|||||||
import warnings
|
import warnings
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from collections import UserDict
|
from collections import UserDict
|
||||||
from typing import List, Optional, Tuple
|
from typing import List, Optional, Tuple, Union
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
@@ -42,8 +42,8 @@ PROCESS_INPUTS_DOCSTRING = r"""
|
|||||||
Beam indices indicating to which beam hypothesis the `next_tokens` correspond.
|
Beam indices indicating to which beam hypothesis the `next_tokens` correspond.
|
||||||
pad_token_id (`int`, *optional*):
|
pad_token_id (`int`, *optional*):
|
||||||
The id of the *padding* token.
|
The id of the *padding* token.
|
||||||
eos_token_id (`int`, *optional*):
|
eos_token_id (`Union[int, List[int]]`, *optional*):
|
||||||
The id of the *end-of-sequence* token.
|
The id of the *end-of-sequence* token. Optionally, use a list to set multiple *end-of-sequence* tokens.
|
||||||
|
|
||||||
Return:
|
Return:
|
||||||
`UserDict`: A dictionary composed of the fields as defined above:
|
`UserDict`: A dictionary composed of the fields as defined above:
|
||||||
@@ -74,8 +74,8 @@ FINALIZE_INPUTS_DOCSTRING = r"""
|
|||||||
The beam indices indicating to which beam the `final_beam_tokens` shall be added.
|
The beam indices indicating to which beam the `final_beam_tokens` shall be added.
|
||||||
pad_token_id (`int`, *optional*):
|
pad_token_id (`int`, *optional*):
|
||||||
The id of the *padding* token.
|
The id of the *padding* token.
|
||||||
eos_token_id (`int`, *optional*):
|
eos_token_id (`Union[int, List[int]]`, *optional*):
|
||||||
The id of the *end-of-sequence* token.
|
The id of the *end-of-sequence* token. Optionally, use a list to set multiple *end-of-sequence* tokens.
|
||||||
|
|
||||||
Return:
|
Return:
|
||||||
`torch.LongTensor` of shape `(batch_size * num_return_sequences, sequence_length)`: The generated sequences.
|
`torch.LongTensor` of shape `(batch_size * num_return_sequences, sequence_length)`: The generated sequences.
|
||||||
@@ -212,7 +212,7 @@ class BeamSearchScorer(BeamScorer):
|
|||||||
next_tokens: torch.LongTensor,
|
next_tokens: torch.LongTensor,
|
||||||
next_indices: torch.LongTensor,
|
next_indices: torch.LongTensor,
|
||||||
pad_token_id: Optional[int] = None,
|
pad_token_id: Optional[int] = None,
|
||||||
eos_token_id: Optional[int] = None,
|
eos_token_id: Optional[Union[int, List[int]]] = None,
|
||||||
beam_indices: Optional[torch.LongTensor] = None,
|
beam_indices: Optional[torch.LongTensor] = None,
|
||||||
) -> Tuple[torch.Tensor]:
|
) -> Tuple[torch.Tensor]:
|
||||||
cur_len = input_ids.shape[-1]
|
cur_len = input_ids.shape[-1]
|
||||||
@@ -234,6 +234,9 @@ class BeamSearchScorer(BeamScorer):
|
|||||||
next_beam_tokens = torch.zeros((batch_size, self.group_size), dtype=next_tokens.dtype, device=device)
|
next_beam_tokens = torch.zeros((batch_size, self.group_size), dtype=next_tokens.dtype, device=device)
|
||||||
next_beam_indices = torch.zeros((batch_size, self.group_size), dtype=next_indices.dtype, device=device)
|
next_beam_indices = torch.zeros((batch_size, self.group_size), dtype=next_indices.dtype, device=device)
|
||||||
|
|
||||||
|
if isinstance(eos_token_id, int):
|
||||||
|
eos_token_id = [eos_token_id]
|
||||||
|
|
||||||
for batch_idx, beam_hyp in enumerate(self._beam_hyps):
|
for batch_idx, beam_hyp in enumerate(self._beam_hyps):
|
||||||
if self._done[batch_idx]:
|
if self._done[batch_idx]:
|
||||||
if self.num_beams < len(beam_hyp):
|
if self.num_beams < len(beam_hyp):
|
||||||
@@ -253,7 +256,7 @@ class BeamSearchScorer(BeamScorer):
|
|||||||
):
|
):
|
||||||
batch_beam_idx = batch_idx * self.group_size + next_index
|
batch_beam_idx = batch_idx * self.group_size + next_index
|
||||||
# add to generated hypotheses if end of sentence
|
# add to generated hypotheses if end of sentence
|
||||||
if (eos_token_id is not None) and (next_token.item() == eos_token_id):
|
if (eos_token_id is not None) and (next_token.item() in eos_token_id):
|
||||||
# if beam_token does not belong to top num_beams tokens, it should not be added
|
# if beam_token does not belong to top num_beams tokens, it should not be added
|
||||||
is_beam_token_worse_than_top_num_beams = beam_token_rank >= self.group_size
|
is_beam_token_worse_than_top_num_beams = beam_token_rank >= self.group_size
|
||||||
if is_beam_token_worse_than_top_num_beams:
|
if is_beam_token_worse_than_top_num_beams:
|
||||||
@@ -307,11 +310,14 @@ class BeamSearchScorer(BeamScorer):
|
|||||||
final_beam_indices: torch.LongTensor,
|
final_beam_indices: torch.LongTensor,
|
||||||
max_length: int,
|
max_length: int,
|
||||||
pad_token_id: Optional[int] = None,
|
pad_token_id: Optional[int] = None,
|
||||||
eos_token_id: Optional[int] = None,
|
eos_token_id: Optional[Union[int, List[int]]] = None,
|
||||||
beam_indices: Optional[torch.LongTensor] = None,
|
beam_indices: Optional[torch.LongTensor] = None,
|
||||||
) -> Tuple[torch.LongTensor]:
|
) -> Tuple[torch.LongTensor]:
|
||||||
batch_size = len(self._beam_hyps)
|
batch_size = len(self._beam_hyps)
|
||||||
|
|
||||||
|
if isinstance(eos_token_id, int):
|
||||||
|
eos_token_id = [eos_token_id]
|
||||||
|
|
||||||
# finalize all open beam hypotheses and add to generated hypotheses
|
# finalize all open beam hypotheses and add to generated hypotheses
|
||||||
for batch_idx, beam_hyp in enumerate(self._beam_hyps):
|
for batch_idx, beam_hyp in enumerate(self._beam_hyps):
|
||||||
if self._done[batch_idx]:
|
if self._done[batch_idx]:
|
||||||
@@ -376,7 +382,8 @@ class BeamSearchScorer(BeamScorer):
|
|||||||
indices[i, : len(best_idx)] = torch.tensor(best_idx)
|
indices[i, : len(best_idx)] = torch.tensor(best_idx)
|
||||||
|
|
||||||
if sent_lengths[i] < sent_max_len:
|
if sent_lengths[i] < sent_max_len:
|
||||||
decoded[i, sent_lengths[i]] = eos_token_id
|
# inserting only the first eos_token_id
|
||||||
|
decoded[i, sent_lengths[i]] = eos_token_id[0]
|
||||||
|
|
||||||
return UserDict(
|
return UserDict(
|
||||||
{
|
{
|
||||||
@@ -491,7 +498,7 @@ class ConstrainedBeamSearchScorer(BeamScorer):
|
|||||||
next_indices: torch.LongTensor,
|
next_indices: torch.LongTensor,
|
||||||
scores_for_all_vocab: torch.FloatTensor,
|
scores_for_all_vocab: torch.FloatTensor,
|
||||||
pad_token_id: Optional[int] = None,
|
pad_token_id: Optional[int] = None,
|
||||||
eos_token_id: Optional[int] = None,
|
eos_token_id: Optional[Union[int, List[int]]] = None,
|
||||||
) -> Tuple[torch.Tensor]:
|
) -> Tuple[torch.Tensor]:
|
||||||
r"""
|
r"""
|
||||||
Args:
|
Args:
|
||||||
@@ -512,8 +519,8 @@ class ConstrainedBeamSearchScorer(BeamScorer):
|
|||||||
The scores of all tokens in the vocabulary for each of the beam hypotheses.
|
The scores of all tokens in the vocabulary for each of the beam hypotheses.
|
||||||
pad_token_id (`int`, *optional*):
|
pad_token_id (`int`, *optional*):
|
||||||
The id of the *padding* token.
|
The id of the *padding* token.
|
||||||
eos_token_id (`int`, *optional*):
|
eos_token_id (`Union[int, List[int]]`, *optional*):
|
||||||
The id of the *end-of-sequence* token.
|
The id of the *end-of-sequence* token. Optionally, use a list to set multiple *end-of-sequence* tokens.
|
||||||
|
|
||||||
Return:
|
Return:
|
||||||
`UserDict`: A dictionary composed of the fields as defined above:
|
`UserDict`: A dictionary composed of the fields as defined above:
|
||||||
@@ -549,6 +556,9 @@ class ConstrainedBeamSearchScorer(BeamScorer):
|
|||||||
next_beam_tokens = torch.zeros((batch_size, self.group_size), dtype=next_tokens.dtype, device=device)
|
next_beam_tokens = torch.zeros((batch_size, self.group_size), dtype=next_tokens.dtype, device=device)
|
||||||
next_beam_indices = torch.zeros((batch_size, self.group_size), dtype=next_indices.dtype, device=device)
|
next_beam_indices = torch.zeros((batch_size, self.group_size), dtype=next_indices.dtype, device=device)
|
||||||
|
|
||||||
|
if isinstance(eos_token_id, int):
|
||||||
|
eos_token_id = [eos_token_id]
|
||||||
|
|
||||||
for batch_idx, beam_hyp in enumerate(self._beam_hyps):
|
for batch_idx, beam_hyp in enumerate(self._beam_hyps):
|
||||||
if self._done[batch_idx]:
|
if self._done[batch_idx]:
|
||||||
if self.num_beams < len(beam_hyp):
|
if self.num_beams < len(beam_hyp):
|
||||||
@@ -568,7 +578,7 @@ class ConstrainedBeamSearchScorer(BeamScorer):
|
|||||||
):
|
):
|
||||||
batch_beam_idx = batch_idx * self.group_size + next_index
|
batch_beam_idx = batch_idx * self.group_size + next_index
|
||||||
# add to generated hypotheses if end of sentence
|
# add to generated hypotheses if end of sentence
|
||||||
if (eos_token_id is not None) and (next_token.item() == eos_token_id):
|
if (eos_token_id is not None) and (next_token.item() in eos_token_id):
|
||||||
|
|
||||||
# if beam_token does not belong to top num_beams tokens, it should not be added
|
# if beam_token does not belong to top num_beams tokens, it should not be added
|
||||||
is_beam_token_worse_than_top_num_beams = beam_token_rank >= self.group_size
|
is_beam_token_worse_than_top_num_beams = beam_token_rank >= self.group_size
|
||||||
@@ -773,10 +783,13 @@ class ConstrainedBeamSearchScorer(BeamScorer):
|
|||||||
final_beam_indices: torch.LongTensor,
|
final_beam_indices: torch.LongTensor,
|
||||||
max_length: int,
|
max_length: int,
|
||||||
pad_token_id: Optional[int] = None,
|
pad_token_id: Optional[int] = None,
|
||||||
eos_token_id: Optional[int] = None,
|
eos_token_id: Optional[Union[int, List[int]]] = None,
|
||||||
) -> Tuple[torch.LongTensor]:
|
) -> Tuple[torch.LongTensor]:
|
||||||
batch_size = len(self._beam_hyps)
|
batch_size = len(self._beam_hyps)
|
||||||
|
|
||||||
|
if isinstance(eos_token_id, int):
|
||||||
|
eos_token_id = [eos_token_id]
|
||||||
|
|
||||||
# finalize all open beam hypotheses and add to generated hypotheses
|
# finalize all open beam hypotheses and add to generated hypotheses
|
||||||
for batch_idx, beam_hyp in enumerate(self._beam_hyps):
|
for batch_idx, beam_hyp in enumerate(self._beam_hyps):
|
||||||
if self._done[batch_idx]:
|
if self._done[batch_idx]:
|
||||||
@@ -840,7 +853,8 @@ class ConstrainedBeamSearchScorer(BeamScorer):
|
|||||||
for i, hypo in enumerate(best):
|
for i, hypo in enumerate(best):
|
||||||
decoded[i, : sent_lengths[i]] = hypo
|
decoded[i, : sent_lengths[i]] = hypo
|
||||||
if sent_lengths[i] < sent_max_len:
|
if sent_lengths[i] < sent_max_len:
|
||||||
decoded[i, sent_lengths[i]] = eos_token_id
|
# inserting only the first eos_token_id
|
||||||
|
decoded[i, sent_lengths[i]] = eos_token_id[0]
|
||||||
|
|
||||||
return UserDict(
|
return UserDict(
|
||||||
{
|
{
|
||||||
|
|||||||
@@ -142,8 +142,9 @@ class GenerationConfig(PushToHubMixin):
|
|||||||
The id of the token to force as the first generated token after the `decoder_start_token_id`. Useful for
|
The id of the token to force as the first generated token after the `decoder_start_token_id`. Useful for
|
||||||
multilingual models like [mBART](../model_doc/mbart) where the first generated token needs to be the target
|
multilingual models like [mBART](../model_doc/mbart) where the first generated token needs to be the target
|
||||||
language token.
|
language token.
|
||||||
forced_eos_token_id (`int`, *optional*, defaults to `model.config.forced_eos_token_id`):
|
forced_eos_token_id (`Union[int, List[int]]`, *optional*, defaults to `model.config.forced_eos_token_id`):
|
||||||
The id of the token to force as the last generated token when `max_length` is reached.
|
The id of the token to force as the last generated token when `max_length` is reached. Optionally, use a
|
||||||
|
list to set multiple *end-of-sequence* tokens.
|
||||||
remove_invalid_values (`bool`, *optional*, defaults to `model.config.remove_invalid_values`):
|
remove_invalid_values (`bool`, *optional*, defaults to `model.config.remove_invalid_values`):
|
||||||
Whether to remove possible *nan* and *inf* outputs of the model to prevent the generation method to crash.
|
Whether to remove possible *nan* and *inf* outputs of the model to prevent the generation method to crash.
|
||||||
Note that using `remove_invalid_values` can slow down generation.
|
Note that using `remove_invalid_values` can slow down generation.
|
||||||
@@ -152,10 +153,10 @@ class GenerationConfig(PushToHubMixin):
|
|||||||
generated. The tuple shall consist of: `(start_index, decay_factor)` where `start_index` indicates where
|
generated. The tuple shall consist of: `(start_index, decay_factor)` where `start_index` indicates where
|
||||||
penalty starts and `decay_factor` represents the factor of exponential decay
|
penalty starts and `decay_factor` represents the factor of exponential decay
|
||||||
suppress_tokens (`List[int]`, *optional*):
|
suppress_tokens (`List[int]`, *optional*):
|
||||||
A list of tokens that will be supressed at generation. The `SupressTokens` logit processor will set their
|
A list of tokens that will be suppressed at generation. The `SupressTokens` logit processor will set their
|
||||||
log probs to `-inf` so that they are not sampled.
|
log probs to `-inf` so that they are not sampled.
|
||||||
begin_suppress_tokens (`List[int]`, *optional*):
|
begin_suppress_tokens (`List[int]`, *optional*):
|
||||||
A list of tokens that will be supressed at the begining of the generation. The `SupressBeginTokens` logit
|
A list of tokens that will be suppressed at the beginning of the generation. The `SupressBeginTokens` logit
|
||||||
processor will set their log probs to `-inf` so that they are not sampled.
|
processor will set their log probs to `-inf` so that they are not sampled.
|
||||||
forced_decoder_ids (`List[List[int]]`, *optional*):
|
forced_decoder_ids (`List[List[int]]`, *optional*):
|
||||||
A list of pairs of integers which indicates a mapping from generation indices to token indices that will be
|
A list of pairs of integers which indicates a mapping from generation indices to token indices that will be
|
||||||
@@ -183,8 +184,8 @@ class GenerationConfig(PushToHubMixin):
|
|||||||
The id of the *padding* token.
|
The id of the *padding* token.
|
||||||
bos_token_id (`int`, *optional*):
|
bos_token_id (`int`, *optional*):
|
||||||
The id of the *beginning-of-sequence* token.
|
The id of the *beginning-of-sequence* token.
|
||||||
eos_token_id (`int`, *optional*):
|
eos_token_id (`Union[int, List[int]]`, *optional*):
|
||||||
The id of the *end-of-sequence* token.
|
The id of the *end-of-sequence* token. Optionally, use a list to set multiple *end-of-sequence* tokens.
|
||||||
|
|
||||||
> Generation parameters exclusive to encoder-decoder models
|
> Generation parameters exclusive to encoder-decoder models
|
||||||
|
|
||||||
|
|||||||
@@ -15,7 +15,7 @@
|
|||||||
|
|
||||||
import inspect
|
import inspect
|
||||||
import math
|
import math
|
||||||
from typing import Callable, Iterable, List, Optional, Tuple
|
from typing import Callable, Iterable, List, Optional, Tuple, Union
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
@@ -100,16 +100,18 @@ class MinLengthLogitsProcessor(LogitsProcessor):
|
|||||||
Args:
|
Args:
|
||||||
min_length (`int`):
|
min_length (`int`):
|
||||||
The minimum length below which the score of `eos_token_id` is set to `-float("Inf")`.
|
The minimum length below which the score of `eos_token_id` is set to `-float("Inf")`.
|
||||||
eos_token_id (`int`):
|
eos_token_id (`Union[int, List[int]]`):
|
||||||
The id of the *end-of-sequence* token.
|
The id of the *end-of-sequence* token. Optionally, use a list to set multiple *end-of-sequence* tokens.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, min_length: int, eos_token_id: int):
|
def __init__(self, min_length: int, eos_token_id: Union[int, List[int]]):
|
||||||
if not isinstance(min_length, int) or min_length < 0:
|
if not isinstance(min_length, int) or min_length < 0:
|
||||||
raise ValueError(f"`min_length` has to be a positive integer, but is {min_length}")
|
raise ValueError(f"`min_length` has to be a positive integer, but is {min_length}")
|
||||||
|
|
||||||
if not isinstance(eos_token_id, int) or eos_token_id < 0:
|
if isinstance(eos_token_id, int):
|
||||||
raise ValueError(f"`eos_token_id` has to be a positive integer, but is {eos_token_id}")
|
eos_token_id = [eos_token_id]
|
||||||
|
if not all([isinstance(i, int) for i in eos_token_id]) or any([i < 0 for i in eos_token_id]):
|
||||||
|
raise ValueError(f"`eos_token_id` has to be a list of positive integers, but is {eos_token_id}")
|
||||||
|
|
||||||
self.min_length = min_length
|
self.min_length = min_length
|
||||||
self.eos_token_id = eos_token_id
|
self.eos_token_id = eos_token_id
|
||||||
@@ -117,7 +119,8 @@ class MinLengthLogitsProcessor(LogitsProcessor):
|
|||||||
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
|
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
|
||||||
cur_len = input_ids.shape[-1]
|
cur_len = input_ids.shape[-1]
|
||||||
if cur_len < self.min_length:
|
if cur_len < self.min_length:
|
||||||
scores[:, self.eos_token_id] = -float("inf")
|
for i in self.eos_token_id:
|
||||||
|
scores[:, i] = -float("inf")
|
||||||
return scores
|
return scores
|
||||||
|
|
||||||
|
|
||||||
@@ -431,11 +434,11 @@ class NoBadWordsLogitsProcessor(LogitsProcessor):
|
|||||||
List of list of token ids that are not allowed to be generated. In order to get the token ids of the words
|
List of list of token ids that are not allowed to be generated. In order to get the token ids of the words
|
||||||
that should not appear in the generated text, use `tokenizer(bad_words, add_prefix_space=True,
|
that should not appear in the generated text, use `tokenizer(bad_words, add_prefix_space=True,
|
||||||
add_special_tokens=False).input_ids`.
|
add_special_tokens=False).input_ids`.
|
||||||
eos_token_id (`int`):
|
eos_token_id (`Union[int, List[int]]`):
|
||||||
The id of the *end-of-sequence* token.
|
The id of the *end-of-sequence* token. Optionally, use a list to set multiple *end-of-sequence* tokens.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, bad_words_ids: List[List[int]], eos_token_id: int):
|
def __init__(self, bad_words_ids: List[List[int]], eos_token_id: Union[int, List[int]]):
|
||||||
|
|
||||||
if not isinstance(bad_words_ids, List) or len(bad_words_ids) == 0:
|
if not isinstance(bad_words_ids, List) or len(bad_words_ids) == 0:
|
||||||
raise ValueError(f"`bad_words_ids` has to be a non-empty list, but is {bad_words_ids}.")
|
raise ValueError(f"`bad_words_ids` has to be a non-empty list, but is {bad_words_ids}.")
|
||||||
@@ -449,7 +452,14 @@ class NoBadWordsLogitsProcessor(LogitsProcessor):
|
|||||||
f"Each list in `bad_words_ids` has to be a list of positive integers, but is {bad_words_ids}."
|
f"Each list in `bad_words_ids` has to be a list of positive integers, but is {bad_words_ids}."
|
||||||
)
|
)
|
||||||
|
|
||||||
bad_words_ids = list(filter(lambda bad_token_seq: bad_token_seq != [eos_token_id], bad_words_ids))
|
if eos_token_id is None:
|
||||||
|
eos_token_id = []
|
||||||
|
if isinstance(eos_token_id, int):
|
||||||
|
eos_token_id = [eos_token_id]
|
||||||
|
|
||||||
|
bad_words_ids = list(
|
||||||
|
filter(lambda bad_token_seq: all([bad_token_seq != [i] for i in eos_token_id]), bad_words_ids)
|
||||||
|
)
|
||||||
self.bad_words_id_length_1 = []
|
self.bad_words_id_length_1 = []
|
||||||
self.bad_words_id_length_greater_than_1 = []
|
self.bad_words_id_length_greater_than_1 = []
|
||||||
for word in bad_words_ids:
|
for word in bad_words_ids:
|
||||||
@@ -664,20 +674,24 @@ class ForcedEOSTokenLogitsProcessor(LogitsProcessor):
|
|||||||
Args:
|
Args:
|
||||||
max_length (`int`):
|
max_length (`int`):
|
||||||
The maximum length of the sequence to be generated.
|
The maximum length of the sequence to be generated.
|
||||||
eos_token_id (`int`):
|
eos_token_id (`Union[int, List[int]]`):
|
||||||
The id of the token to force as the last generated token when `max_length` is reached.
|
The id of the token to force as the last generated token when `max_length` is reached. Optionally, use a
|
||||||
|
list to set multiple *end-of-sequence* tokens.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, max_length: int, eos_token_id: int):
|
def __init__(self, max_length: int, eos_token_id: Union[int, List[int]]):
|
||||||
self.max_length = max_length
|
self.max_length = max_length
|
||||||
|
if isinstance(eos_token_id, int):
|
||||||
|
eos_token_id = [eos_token_id]
|
||||||
self.eos_token_id = eos_token_id
|
self.eos_token_id = eos_token_id
|
||||||
|
|
||||||
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
|
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
|
||||||
cur_len = input_ids.shape[-1]
|
cur_len = input_ids.shape[-1]
|
||||||
if cur_len == self.max_length - 1:
|
if cur_len == self.max_length - 1:
|
||||||
num_tokens = scores.shape[1]
|
num_tokens = scores.shape[1]
|
||||||
scores[:, [i for i in range(num_tokens) if i != self.eos_token_id]] = -float("inf")
|
scores[:, [i for i in range(num_tokens) if i not in self.eos_token_id]] = -float("inf")
|
||||||
scores[:, self.eos_token_id] = 0
|
for i in self.eos_token_id:
|
||||||
|
scores[:, i] = 0
|
||||||
return scores
|
return scores
|
||||||
|
|
||||||
|
|
||||||
@@ -707,23 +721,26 @@ class ExponentialDecayLengthPenalty(LogitsProcessor):
|
|||||||
exponential_decay_length_penalty (`tuple(int, float)`, *optional*):
|
exponential_decay_length_penalty (`tuple(int, float)`, *optional*):
|
||||||
This tuple shall consist of: `(start_index, decay_factor)` where `start_index` indicates where penalty
|
This tuple shall consist of: `(start_index, decay_factor)` where `start_index` indicates where penalty
|
||||||
starts and `decay_factor` represents the factor of exponential decay
|
starts and `decay_factor` represents the factor of exponential decay
|
||||||
eos_token_id (`int`):
|
eos_token_id (`Union[int, List[int]]`):
|
||||||
The id of the *end-of-sequence* token.
|
The id of the *end-of-sequence* token. Optionally, use a list to set multiple *end-of-sequence* tokens.
|
||||||
input_ids_seq_length (`int`):
|
input_ids_seq_length (`int`):
|
||||||
The length of the input sequence.
|
The length of the input sequence.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, exponential_decay_length_penalty: Tuple, eos_token_id: int, input_ids_seq_length: int):
|
def __init__(
|
||||||
|
self, exponential_decay_length_penalty: Tuple, eos_token_id: Union[int, List[int]], input_ids_seq_length: int
|
||||||
|
):
|
||||||
self.regulation_start = exponential_decay_length_penalty[0] + input_ids_seq_length
|
self.regulation_start = exponential_decay_length_penalty[0] + input_ids_seq_length
|
||||||
self.regulation_factor = exponential_decay_length_penalty[1]
|
self.regulation_factor = exponential_decay_length_penalty[1]
|
||||||
|
if isinstance(eos_token_id, int):
|
||||||
|
eos_token_id = [eos_token_id]
|
||||||
self.eos_token_id = eos_token_id
|
self.eos_token_id = eos_token_id
|
||||||
|
|
||||||
def __call__(self, input_ids: torch.Tensor, scores: torch.Tensor) -> torch.FloatTensor:
|
def __call__(self, input_ids: torch.Tensor, scores: torch.Tensor) -> torch.FloatTensor:
|
||||||
cur_len = input_ids.shape[-1]
|
cur_len = input_ids.shape[-1]
|
||||||
if cur_len > self.regulation_start:
|
if cur_len > self.regulation_start:
|
||||||
scores[:, self.eos_token_id] = scores[:, self.eos_token_id] * pow(
|
for i in self.eos_token_id:
|
||||||
self.regulation_factor, cur_len - self.regulation_start
|
scores[:, i] = scores[:, i] * pow(self.regulation_factor, cur_len - self.regulation_start)
|
||||||
)
|
|
||||||
return scores
|
return scores
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -575,11 +575,13 @@ class GenerationMixin:
|
|||||||
self,
|
self,
|
||||||
inputs: torch.Tensor,
|
inputs: torch.Tensor,
|
||||||
pad_token_id: Optional[int],
|
pad_token_id: Optional[int],
|
||||||
eos_token_id: Optional[int],
|
eos_token_id: Optional[Union[int, List[int]]],
|
||||||
) -> torch.LongTensor:
|
) -> torch.LongTensor:
|
||||||
is_input_ids = len(inputs.shape) == 2 and inputs.dtype in [torch.int, torch.long]
|
is_input_ids = len(inputs.shape) == 2 and inputs.dtype in [torch.int, torch.long]
|
||||||
is_pad_token_in_inputs = (pad_token_id is not None) and (pad_token_id in inputs)
|
is_pad_token_in_inputs = (pad_token_id is not None) and (pad_token_id in inputs)
|
||||||
is_pad_token_not_equal_to_eos_token_id = (eos_token_id is None) or (pad_token_id != eos_token_id)
|
if isinstance(eos_token_id, int):
|
||||||
|
eos_token_id = [eos_token_id]
|
||||||
|
is_pad_token_not_equal_to_eos_token_id = (eos_token_id is None) or (pad_token_id not in eos_token_id)
|
||||||
|
|
||||||
# Check if input is input_ids and padded -> only then is attention_mask defined
|
# Check if input is input_ids and padded -> only then is attention_mask defined
|
||||||
if is_input_ids and is_pad_token_in_inputs and is_pad_token_not_equal_to_eos_token_id:
|
if is_input_ids and is_pad_token_in_inputs and is_pad_token_not_equal_to_eos_token_id:
|
||||||
@@ -902,7 +904,7 @@ class GenerationMixin:
|
|||||||
sequences: torch.Tensor,
|
sequences: torch.Tensor,
|
||||||
scores: Tuple[torch.Tensor],
|
scores: Tuple[torch.Tensor],
|
||||||
beam_indices: torch.Tensor,
|
beam_indices: torch.Tensor,
|
||||||
eos_token_id: int = None,
|
eos_token_id: Union[int, List[int]] = None,
|
||||||
):
|
):
|
||||||
"""compute the transition probabilities of sequences given generation
|
"""compute the transition probabilities of sequences given generation
|
||||||
scores and beam indices"""
|
scores and beam indices"""
|
||||||
@@ -1165,10 +1167,11 @@ class GenerationMixin:
|
|||||||
"The attention mask and the pad token id were not set. As a consequence, you may observe "
|
"The attention mask and the pad token id were not set. As a consequence, you may observe "
|
||||||
"unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results."
|
"unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results."
|
||||||
)
|
)
|
||||||
logger.warning(
|
eos_token_id = generation_config.eos_token_id
|
||||||
f"Setting `pad_token_id` to `eos_token_id`:{generation_config.eos_token_id} for open-end generation."
|
if isinstance(eos_token_id, list):
|
||||||
)
|
eos_token_id = eos_token_id[0]
|
||||||
generation_config.pad_token_id = generation_config.eos_token_id
|
logger.warning(f"Setting `pad_token_id` to `eos_token_id`:{eos_token_id} for open-end generation.")
|
||||||
|
generation_config.pad_token_id = eos_token_id
|
||||||
|
|
||||||
# 3. Define model inputs
|
# 3. Define model inputs
|
||||||
# inputs_tensor has to be defined
|
# inputs_tensor has to be defined
|
||||||
@@ -1624,7 +1627,7 @@ class GenerationMixin:
|
|||||||
logits_warper: Optional[LogitsProcessorList] = None,
|
logits_warper: Optional[LogitsProcessorList] = None,
|
||||||
stopping_criteria: Optional[StoppingCriteriaList] = None,
|
stopping_criteria: Optional[StoppingCriteriaList] = None,
|
||||||
pad_token_id: Optional[int] = None,
|
pad_token_id: Optional[int] = None,
|
||||||
eos_token_id: Optional[int] = None,
|
eos_token_id: Optional[Union[int, List[int]]] = None,
|
||||||
output_attentions: Optional[bool] = None,
|
output_attentions: Optional[bool] = None,
|
||||||
output_hidden_states: Optional[bool] = None,
|
output_hidden_states: Optional[bool] = None,
|
||||||
output_scores: Optional[bool] = None,
|
output_scores: Optional[bool] = None,
|
||||||
@@ -1708,6 +1711,8 @@ class GenerationMixin:
|
|||||||
stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList()
|
stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList()
|
||||||
pad_token_id = pad_token_id if pad_token_id is not None else self.generation_config.pad_token_id
|
pad_token_id = pad_token_id if pad_token_id is not None else self.generation_config.pad_token_id
|
||||||
eos_token_id = eos_token_id if eos_token_id is not None else self.generation_config.eos_token_id
|
eos_token_id = eos_token_id if eos_token_id is not None else self.generation_config.eos_token_id
|
||||||
|
if isinstance(eos_token_id, int):
|
||||||
|
eos_token_id = [eos_token_id]
|
||||||
output_scores = output_scores if output_scores is not None else self.generation_config.output_scores
|
output_scores = output_scores if output_scores is not None else self.generation_config.output_scores
|
||||||
output_attentions = (
|
output_attentions = (
|
||||||
output_attentions if output_attentions is not None else self.generation_config.output_attentions
|
output_attentions if output_attentions is not None else self.generation_config.output_attentions
|
||||||
@@ -1930,7 +1935,7 @@ class GenerationMixin:
|
|||||||
|
|
||||||
# if eos_token was found in one sentence, set sentence to finished
|
# if eos_token was found in one sentence, set sentence to finished
|
||||||
if eos_token_id is not None:
|
if eos_token_id is not None:
|
||||||
unfinished_sequences = unfinished_sequences.mul((next_tokens != eos_token_id).long())
|
unfinished_sequences = unfinished_sequences.mul((sum(next_tokens != i for i in eos_token_id)).long())
|
||||||
|
|
||||||
# stop when each sentence is finished, or if we exceed the maximum length
|
# stop when each sentence is finished, or if we exceed the maximum length
|
||||||
if unfinished_sequences.max() == 0 or stopping_criteria(input_ids, scores):
|
if unfinished_sequences.max() == 0 or stopping_criteria(input_ids, scores):
|
||||||
@@ -1967,7 +1972,7 @@ class GenerationMixin:
|
|||||||
stopping_criteria: Optional[StoppingCriteriaList] = None,
|
stopping_criteria: Optional[StoppingCriteriaList] = None,
|
||||||
max_length: Optional[int] = None,
|
max_length: Optional[int] = None,
|
||||||
pad_token_id: Optional[int] = None,
|
pad_token_id: Optional[int] = None,
|
||||||
eos_token_id: Optional[int] = None,
|
eos_token_id: Optional[Union[int, List[int]]] = None,
|
||||||
output_attentions: Optional[bool] = None,
|
output_attentions: Optional[bool] = None,
|
||||||
output_hidden_states: Optional[bool] = None,
|
output_hidden_states: Optional[bool] = None,
|
||||||
output_scores: Optional[bool] = None,
|
output_scores: Optional[bool] = None,
|
||||||
@@ -2067,6 +2072,8 @@ class GenerationMixin:
|
|||||||
stopping_criteria = validate_stopping_criteria(stopping_criteria, max_length)
|
stopping_criteria = validate_stopping_criteria(stopping_criteria, max_length)
|
||||||
pad_token_id = pad_token_id if pad_token_id is not None else self.generation_config.pad_token_id
|
pad_token_id = pad_token_id if pad_token_id is not None else self.generation_config.pad_token_id
|
||||||
eos_token_id = eos_token_id if eos_token_id is not None else self.generation_config.eos_token_id
|
eos_token_id = eos_token_id if eos_token_id is not None else self.generation_config.eos_token_id
|
||||||
|
if isinstance(eos_token_id, int):
|
||||||
|
eos_token_id = [eos_token_id]
|
||||||
output_scores = output_scores if output_scores is not None else self.generation_config.output_scores
|
output_scores = output_scores if output_scores is not None else self.generation_config.output_scores
|
||||||
output_attentions = (
|
output_attentions = (
|
||||||
output_attentions if output_attentions is not None else self.generation_config.output_attentions
|
output_attentions if output_attentions is not None else self.generation_config.output_attentions
|
||||||
@@ -2162,7 +2169,7 @@ class GenerationMixin:
|
|||||||
|
|
||||||
# if eos_token was found in one sentence, set sentence to finished
|
# if eos_token was found in one sentence, set sentence to finished
|
||||||
if eos_token_id is not None:
|
if eos_token_id is not None:
|
||||||
unfinished_sequences = unfinished_sequences.mul((next_tokens != eos_token_id).long())
|
unfinished_sequences = unfinished_sequences.mul((sum(next_tokens != i for i in eos_token_id)).long())
|
||||||
|
|
||||||
# stop when each sentence is finished, or if we exceed the maximum length
|
# stop when each sentence is finished, or if we exceed the maximum length
|
||||||
if unfinished_sequences.max() == 0 or stopping_criteria(input_ids, scores):
|
if unfinished_sequences.max() == 0 or stopping_criteria(input_ids, scores):
|
||||||
@@ -2200,7 +2207,7 @@ class GenerationMixin:
|
|||||||
logits_warper: Optional[LogitsProcessorList] = None,
|
logits_warper: Optional[LogitsProcessorList] = None,
|
||||||
max_length: Optional[int] = None,
|
max_length: Optional[int] = None,
|
||||||
pad_token_id: Optional[int] = None,
|
pad_token_id: Optional[int] = None,
|
||||||
eos_token_id: Optional[int] = None,
|
eos_token_id: Optional[Union[int, List[int]]] = None,
|
||||||
output_attentions: Optional[bool] = None,
|
output_attentions: Optional[bool] = None,
|
||||||
output_hidden_states: Optional[bool] = None,
|
output_hidden_states: Optional[bool] = None,
|
||||||
output_scores: Optional[bool] = None,
|
output_scores: Optional[bool] = None,
|
||||||
@@ -2320,6 +2327,8 @@ class GenerationMixin:
|
|||||||
logits_warper = logits_warper if logits_warper is not None else LogitsProcessorList()
|
logits_warper = logits_warper if logits_warper is not None else LogitsProcessorList()
|
||||||
pad_token_id = pad_token_id if pad_token_id is not None else self.generation_config.pad_token_id
|
pad_token_id = pad_token_id if pad_token_id is not None else self.generation_config.pad_token_id
|
||||||
eos_token_id = eos_token_id if eos_token_id is not None else self.generation_config.eos_token_id
|
eos_token_id = eos_token_id if eos_token_id is not None else self.generation_config.eos_token_id
|
||||||
|
if isinstance(eos_token_id, int):
|
||||||
|
eos_token_id = [eos_token_id]
|
||||||
output_scores = output_scores if output_scores is not None else self.generation_config.output_scores
|
output_scores = output_scores if output_scores is not None else self.generation_config.output_scores
|
||||||
output_attentions = (
|
output_attentions = (
|
||||||
output_attentions if output_attentions is not None else self.generation_config.output_attentions
|
output_attentions if output_attentions is not None else self.generation_config.output_attentions
|
||||||
@@ -2418,7 +2427,7 @@ class GenerationMixin:
|
|||||||
|
|
||||||
# if eos_token was found in one sentence, set sentence to finished
|
# if eos_token was found in one sentence, set sentence to finished
|
||||||
if eos_token_id is not None:
|
if eos_token_id is not None:
|
||||||
unfinished_sequences = unfinished_sequences.mul((next_tokens != eos_token_id).long())
|
unfinished_sequences = unfinished_sequences.mul((sum(next_tokens != i for i in eos_token_id)).long())
|
||||||
|
|
||||||
# stop when each sentence is finished, or if we exceed the maximum length
|
# stop when each sentence is finished, or if we exceed the maximum length
|
||||||
if unfinished_sequences.max() == 0 or stopping_criteria(input_ids, scores):
|
if unfinished_sequences.max() == 0 or stopping_criteria(input_ids, scores):
|
||||||
@@ -2456,7 +2465,7 @@ class GenerationMixin:
|
|||||||
stopping_criteria: Optional[StoppingCriteriaList] = None,
|
stopping_criteria: Optional[StoppingCriteriaList] = None,
|
||||||
max_length: Optional[int] = None,
|
max_length: Optional[int] = None,
|
||||||
pad_token_id: Optional[int] = None,
|
pad_token_id: Optional[int] = None,
|
||||||
eos_token_id: Optional[int] = None,
|
eos_token_id: Optional[Union[int, List[int]]] = None,
|
||||||
output_attentions: Optional[bool] = None,
|
output_attentions: Optional[bool] = None,
|
||||||
output_hidden_states: Optional[bool] = None,
|
output_hidden_states: Optional[bool] = None,
|
||||||
output_scores: Optional[bool] = None,
|
output_scores: Optional[bool] = None,
|
||||||
@@ -2576,6 +2585,8 @@ class GenerationMixin:
|
|||||||
warnings.warn("You don't have defined any stopping_criteria, this will likely loop forever", UserWarning)
|
warnings.warn("You don't have defined any stopping_criteria, this will likely loop forever", UserWarning)
|
||||||
pad_token_id = pad_token_id if pad_token_id is not None else self.generation_config.pad_token_id
|
pad_token_id = pad_token_id if pad_token_id is not None else self.generation_config.pad_token_id
|
||||||
eos_token_id = eos_token_id if eos_token_id is not None else self.generation_config.eos_token_id
|
eos_token_id = eos_token_id if eos_token_id is not None else self.generation_config.eos_token_id
|
||||||
|
if isinstance(eos_token_id, int):
|
||||||
|
eos_token_id = [eos_token_id]
|
||||||
output_scores = output_scores if output_scores is not None else self.generation_config.output_scores
|
output_scores = output_scores if output_scores is not None else self.generation_config.output_scores
|
||||||
output_attentions = (
|
output_attentions = (
|
||||||
output_attentions if output_attentions is not None else self.generation_config.output_attentions
|
output_attentions if output_attentions is not None else self.generation_config.output_attentions
|
||||||
@@ -2770,7 +2781,7 @@ class GenerationMixin:
|
|||||||
logits_warper: Optional[LogitsProcessorList] = None,
|
logits_warper: Optional[LogitsProcessorList] = None,
|
||||||
max_length: Optional[int] = None,
|
max_length: Optional[int] = None,
|
||||||
pad_token_id: Optional[int] = None,
|
pad_token_id: Optional[int] = None,
|
||||||
eos_token_id: Optional[int] = None,
|
eos_token_id: Optional[Union[int, List[int]]] = None,
|
||||||
output_attentions: Optional[bool] = None,
|
output_attentions: Optional[bool] = None,
|
||||||
output_hidden_states: Optional[bool] = None,
|
output_hidden_states: Optional[bool] = None,
|
||||||
output_scores: Optional[bool] = None,
|
output_scores: Optional[bool] = None,
|
||||||
@@ -2900,6 +2911,8 @@ class GenerationMixin:
|
|||||||
stopping_criteria = validate_stopping_criteria(stopping_criteria, max_length)
|
stopping_criteria = validate_stopping_criteria(stopping_criteria, max_length)
|
||||||
pad_token_id = pad_token_id if pad_token_id is not None else self.generation_config.pad_token_id
|
pad_token_id = pad_token_id if pad_token_id is not None else self.generation_config.pad_token_id
|
||||||
eos_token_id = eos_token_id if eos_token_id is not None else self.generation_config.eos_token_id
|
eos_token_id = eos_token_id if eos_token_id is not None else self.generation_config.eos_token_id
|
||||||
|
if isinstance(eos_token_id, int):
|
||||||
|
eos_token_id = [eos_token_id]
|
||||||
output_scores = output_scores if output_scores is not None else self.generation_config.output_scores
|
output_scores = output_scores if output_scores is not None else self.generation_config.output_scores
|
||||||
output_attentions = (
|
output_attentions = (
|
||||||
output_attentions if output_attentions is not None else self.generation_config.output_attentions
|
output_attentions if output_attentions is not None else self.generation_config.output_attentions
|
||||||
@@ -3089,7 +3102,7 @@ class GenerationMixin:
|
|||||||
stopping_criteria: Optional[StoppingCriteriaList] = None,
|
stopping_criteria: Optional[StoppingCriteriaList] = None,
|
||||||
max_length: Optional[int] = None,
|
max_length: Optional[int] = None,
|
||||||
pad_token_id: Optional[int] = None,
|
pad_token_id: Optional[int] = None,
|
||||||
eos_token_id: Optional[int] = None,
|
eos_token_id: Optional[Union[int, List[int]]] = None,
|
||||||
output_attentions: Optional[bool] = None,
|
output_attentions: Optional[bool] = None,
|
||||||
output_hidden_states: Optional[bool] = None,
|
output_hidden_states: Optional[bool] = None,
|
||||||
output_scores: Optional[bool] = None,
|
output_scores: Optional[bool] = None,
|
||||||
@@ -3213,6 +3226,8 @@ class GenerationMixin:
|
|||||||
stopping_criteria = validate_stopping_criteria(stopping_criteria, max_length)
|
stopping_criteria = validate_stopping_criteria(stopping_criteria, max_length)
|
||||||
pad_token_id = pad_token_id if pad_token_id is not None else self.generation_config.pad_token_id
|
pad_token_id = pad_token_id if pad_token_id is not None else self.generation_config.pad_token_id
|
||||||
eos_token_id = eos_token_id if eos_token_id is not None else self.generation_config.eos_token_id
|
eos_token_id = eos_token_id if eos_token_id is not None else self.generation_config.eos_token_id
|
||||||
|
if isinstance(eos_token_id, int):
|
||||||
|
eos_token_id = [eos_token_id]
|
||||||
output_scores = output_scores if output_scores is not None else self.generation_config.output_scores
|
output_scores = output_scores if output_scores is not None else self.generation_config.output_scores
|
||||||
output_attentions = (
|
output_attentions = (
|
||||||
output_attentions if output_attentions is not None else self.generation_config.output_attentions
|
output_attentions if output_attentions is not None else self.generation_config.output_attentions
|
||||||
@@ -3455,7 +3470,7 @@ class GenerationMixin:
|
|||||||
stopping_criteria: Optional[StoppingCriteriaList] = None,
|
stopping_criteria: Optional[StoppingCriteriaList] = None,
|
||||||
max_length: Optional[int] = None,
|
max_length: Optional[int] = None,
|
||||||
pad_token_id: Optional[int] = None,
|
pad_token_id: Optional[int] = None,
|
||||||
eos_token_id: Optional[int] = None,
|
eos_token_id: Optional[Union[int, List[int]]] = None,
|
||||||
output_attentions: Optional[bool] = None,
|
output_attentions: Optional[bool] = None,
|
||||||
output_hidden_states: Optional[bool] = None,
|
output_hidden_states: Optional[bool] = None,
|
||||||
output_scores: Optional[bool] = None,
|
output_scores: Optional[bool] = None,
|
||||||
@@ -3586,6 +3601,8 @@ class GenerationMixin:
|
|||||||
warnings.warn("You don't have defined any stopping_criteria, this will likely loop forever", UserWarning)
|
warnings.warn("You don't have defined any stopping_criteria, this will likely loop forever", UserWarning)
|
||||||
pad_token_id = pad_token_id if pad_token_id is not None else self.generation_config.pad_token_id
|
pad_token_id = pad_token_id if pad_token_id is not None else self.generation_config.pad_token_id
|
||||||
eos_token_id = eos_token_id if eos_token_id is not None else self.generation_config.eos_token_id
|
eos_token_id = eos_token_id if eos_token_id is not None else self.generation_config.eos_token_id
|
||||||
|
if isinstance(eos_token_id, int):
|
||||||
|
eos_token_id = [eos_token_id]
|
||||||
output_scores = output_scores if output_scores is not None else self.generation_config.output_scores
|
output_scores = output_scores if output_scores is not None else self.generation_config.output_scores
|
||||||
output_attentions = (
|
output_attentions = (
|
||||||
output_attentions if output_attentions is not None else self.generation_config.output_attentions
|
output_attentions if output_attentions is not None else self.generation_config.output_attentions
|
||||||
|
|||||||
@@ -17,7 +17,7 @@
|
|||||||
import inspect
|
import inspect
|
||||||
import unittest
|
import unittest
|
||||||
|
|
||||||
from transformers import is_torch_available
|
from transformers import is_torch_available, pipeline
|
||||||
from transformers.testing_utils import require_torch, slow, torch_device
|
from transformers.testing_utils import require_torch, slow, torch_device
|
||||||
|
|
||||||
from ..test_modeling_common import floats_tensor, ids_tensor
|
from ..test_modeling_common import floats_tensor, ids_tensor
|
||||||
@@ -39,7 +39,6 @@ if is_torch_available():
|
|||||||
SpeechEncoderDecoderModel,
|
SpeechEncoderDecoderModel,
|
||||||
T5ForConditionalGeneration,
|
T5ForConditionalGeneration,
|
||||||
VisionEncoderDecoderModel,
|
VisionEncoderDecoderModel,
|
||||||
pipeline,
|
|
||||||
top_k_top_p_filtering,
|
top_k_top_p_filtering,
|
||||||
)
|
)
|
||||||
from transformers.generation import (
|
from transformers.generation import (
|
||||||
@@ -91,8 +90,9 @@ class GenerationTesterMixin:
|
|||||||
max_length = input_ids.shape[-1] + 3
|
max_length = input_ids.shape[-1] + 3
|
||||||
if config.eos_token_id is not None and config.pad_token_id is None:
|
if config.eos_token_id is not None and config.pad_token_id is None:
|
||||||
# hack to allow generate for models such as GPT2 as is done in `generate()`
|
# hack to allow generate for models such as GPT2 as is done in `generate()`
|
||||||
config.pad_token_id = config.eos_token_id
|
if isinstance(config.eos_token_id, int):
|
||||||
|
config.eos_token_id = [config.eos_token_id]
|
||||||
|
config.pad_token_id = config.eos_token_id[0]
|
||||||
# TransfoXL has no attention mask
|
# TransfoXL has no attention mask
|
||||||
if "transfoxl" in config.__class__.__name__.lower():
|
if "transfoxl" in config.__class__.__name__.lower():
|
||||||
attention_mask = None
|
attention_mask = None
|
||||||
@@ -3025,3 +3025,100 @@ class GenerationIntegrationTests(unittest.TestCase):
|
|||||||
# However, valid model_kwargs are accepted
|
# However, valid model_kwargs are accepted
|
||||||
valid_model_kwargs = {"attention_mask": torch.zeros_like(input_ids)}
|
valid_model_kwargs = {"attention_mask": torch.zeros_like(input_ids)}
|
||||||
model.generate(input_ids, **valid_model_kwargs)
|
model.generate(input_ids, **valid_model_kwargs)
|
||||||
|
|
||||||
|
def test_eos_token_id_int_and_list_greedy_search(self):
    """Greedy search must stop identically for `eos_token_id` given as an int or a one-element list.

    Generates twice from the same prompt with the tiny random GPT-2 checkpoint, once with
    `eos_token_id=873` and once with `eos_token_id=[873]`, and checks that the first returned
    sequence has the expected length in both cases.
    """
    generation_kwargs = {
        "do_sample": False,
        "num_beams": 1,
    }
    expectation = 13

    tokenizer = GPT2Tokenizer.from_pretrained("hf-internal-testing/tiny-random-gpt2")
    text = """Hello, my dog is cute and"""
    # The inputs have to live on the same device as the model, otherwise generation breaks off-CPU.
    tokens = tokenizer(text, return_tensors="pt").to(torch_device)

    model = GPT2LMHeadModel.from_pretrained("hf-internal-testing/tiny-random-gpt2").to(torch_device)

    # NOTE: only a single-element list is exercised; lists holding several ids are not covered here.
    for eos_token_id in (873, [873]):
        with self.subTest(eos_token_id=eos_token_id):
            torch.manual_seed(0)
            generated_tokens = model.generate(**tokens, eos_token_id=eos_token_id, **generation_kwargs)
            # assertEqual reports both values on failure, unlike assertTrue(a == b).
            self.assertEqual(expectation, len(generated_tokens[0]))
|
||||||
|
|
||||||
|
def test_eos_token_id_int_and_list_contrastive_search(self):
    """Contrastive search must stop identically for `eos_token_id` given as an int or a one-element list.

    Generates twice from the same prompt with `penalty_alpha=0.6` and `top_k=4`, once with
    `eos_token_id=225` and once with `eos_token_id=[225]`, and checks that the first returned
    sequence has the expected length in both cases.
    """
    generation_kwargs = {
        "do_sample": False,
        "num_beams": 1,
        "penalty_alpha": 0.6,
        "top_k": 4,
    }
    expectation = 17

    tokenizer = GPT2Tokenizer.from_pretrained("hf-internal-testing/tiny-random-gpt2")
    text = """Hello, my dog is cute and"""
    # The inputs have to live on the same device as the model, otherwise generation breaks off-CPU.
    tokens = tokenizer(text, return_tensors="pt").to(torch_device)

    model = GPT2LMHeadModel.from_pretrained("hf-internal-testing/tiny-random-gpt2").to(torch_device)

    # NOTE: only a single-element list is exercised; lists holding several ids are not covered here.
    for eos_token_id in (225, [225]):
        with self.subTest(eos_token_id=eos_token_id):
            torch.manual_seed(0)
            generated_tokens = model.generate(**tokens, eos_token_id=eos_token_id, **generation_kwargs)
            # assertEqual reports both values on failure, unlike assertTrue(a == b).
            self.assertEqual(expectation, len(generated_tokens[0]))
|
||||||
|
|
||||||
|
def test_eos_token_id_int_and_list_top_k_top_sampling(self):
    """Top-k / top-p sampling must stop identically for `eos_token_id` given as an int or a one-element list.

    Samples twice from the same prompt with `top_p=0.7`, `top_k=10` and `temperature=0.7`, once with
    `eos_token_id=846` and once with `eos_token_id=[846]`, and checks that the first returned
    sequence has the expected length in both cases.
    """
    generation_kwargs = {
        "do_sample": True,
        "num_beams": 1,
        "top_p": 0.7,
        "top_k": 10,
        "temperature": 0.7,
    }
    expectation = 15

    tokenizer = GPT2Tokenizer.from_pretrained("hf-internal-testing/tiny-random-gpt2")
    text = """Hello, my dog is cute and"""
    # The inputs have to live on the same device as the model, otherwise generation breaks off-CPU.
    tokens = tokenizer(text, return_tensors="pt").to(torch_device)

    model = GPT2LMHeadModel.from_pretrained("hf-internal-testing/tiny-random-gpt2").to(torch_device)

    # NOTE: only a single-element list is exercised; lists holding several ids are not covered here.
    for eos_token_id in (846, [846]):
        with self.subTest(eos_token_id=eos_token_id):
            # Reseed before every call so both sampled runs start from the same RNG state.
            torch.manual_seed(0)
            generated_tokens = model.generate(**tokens, eos_token_id=eos_token_id, **generation_kwargs)
            # assertEqual reports both values on failure, unlike assertTrue(a == b).
            self.assertEqual(expectation, len(generated_tokens[0]))
|
||||||
|
|
||||||
|
def test_eos_token_id_int_and_list_beam_search(self):
    """Beam search must stop identically for `eos_token_id` given as an int or a one-element list.

    Generates twice from the same prompt with `num_beams=3`, once with `eos_token_id=873` and once
    with `eos_token_id=[873]`, and checks that the first returned sequence has the expected length
    in both cases.
    """
    generation_kwargs = {
        "do_sample": False,
        "num_beams": 3,
    }
    expectation = 13

    tokenizer = GPT2Tokenizer.from_pretrained("hf-internal-testing/tiny-random-gpt2")
    text = """Hello, my dog is cute and"""
    # The inputs have to live on the same device as the model, otherwise generation breaks off-CPU.
    tokens = tokenizer(text, return_tensors="pt").to(torch_device)

    model = GPT2LMHeadModel.from_pretrained("hf-internal-testing/tiny-random-gpt2").to(torch_device)

    # NOTE: only a single-element list is exercised; lists holding several ids are not covered here.
    for eos_token_id in (873, [873]):
        with self.subTest(eos_token_id=eos_token_id):
            torch.manual_seed(0)
            generated_tokens = model.generate(**tokens, eos_token_id=eos_token_id, **generation_kwargs)
            # assertEqual reports both values on failure, unlike assertTrue(a == b).
            self.assertEqual(expectation, len(generated_tokens[0]))
|
||||||
|
|||||||
Reference in New Issue
Block a user