Generate: consistently handle special tokens as tensors (#30624)
* tmp commit
* [test_all] mvp
* missing not
* [test_all] final test fixes
* fix musicgen_melody and rag
* [test_all] empty commit
* PR comments
* Update src/transformers/generation/utils.py

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

---------

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
This commit is contained in:
@@ -218,8 +218,8 @@ class BeamSearchScorer(BeamScorer):
|
|||||||
next_scores: torch.FloatTensor,
|
next_scores: torch.FloatTensor,
|
||||||
next_tokens: torch.LongTensor,
|
next_tokens: torch.LongTensor,
|
||||||
next_indices: torch.LongTensor,
|
next_indices: torch.LongTensor,
|
||||||
pad_token_id: Optional[int] = None,
|
pad_token_id: Optional[Union[int, torch.Tensor]] = None,
|
||||||
eos_token_id: Optional[Union[int, List[int]]] = None,
|
eos_token_id: Optional[Union[int, List[int], torch.Tensor]] = None,
|
||||||
beam_indices: Optional[torch.LongTensor] = None,
|
beam_indices: Optional[torch.LongTensor] = None,
|
||||||
group_index: Optional[int] = 0,
|
group_index: Optional[int] = 0,
|
||||||
decoder_prompt_len: Optional[int] = 0,
|
decoder_prompt_len: Optional[int] = 0,
|
||||||
@@ -245,8 +245,10 @@ class BeamSearchScorer(BeamScorer):
|
|||||||
next_beam_tokens = torch.zeros((batch_size, self.group_size), dtype=next_tokens.dtype, device=device)
|
next_beam_tokens = torch.zeros((batch_size, self.group_size), dtype=next_tokens.dtype, device=device)
|
||||||
next_beam_indices = torch.zeros((batch_size, self.group_size), dtype=next_indices.dtype, device=device)
|
next_beam_indices = torch.zeros((batch_size, self.group_size), dtype=next_indices.dtype, device=device)
|
||||||
|
|
||||||
if isinstance(eos_token_id, int):
|
if eos_token_id is not None and not isinstance(eos_token_id, torch.Tensor):
|
||||||
eos_token_id = [eos_token_id]
|
if isinstance(eos_token_id, int):
|
||||||
|
eos_token_id = [eos_token_id]
|
||||||
|
eos_token_id = torch.tensor(eos_token_id)
|
||||||
|
|
||||||
for batch_idx in range(batch_size):
|
for batch_idx in range(batch_size):
|
||||||
batch_group_idx = batch_idx * self.num_beam_groups + group_index
|
batch_group_idx = batch_idx * self.num_beam_groups + group_index
|
||||||
@@ -322,15 +324,17 @@ class BeamSearchScorer(BeamScorer):
|
|||||||
final_beam_tokens: torch.LongTensor,
|
final_beam_tokens: torch.LongTensor,
|
||||||
final_beam_indices: torch.LongTensor,
|
final_beam_indices: torch.LongTensor,
|
||||||
max_length: int,
|
max_length: int,
|
||||||
pad_token_id: Optional[int] = None,
|
pad_token_id: Optional[Union[int, torch.Tensor]] = None,
|
||||||
eos_token_id: Optional[Union[int, List[int]]] = None,
|
eos_token_id: Optional[Union[int, List[int], torch.Tensor]] = None,
|
||||||
beam_indices: Optional[torch.LongTensor] = None,
|
beam_indices: Optional[torch.LongTensor] = None,
|
||||||
decoder_prompt_len: Optional[int] = 0,
|
decoder_prompt_len: Optional[int] = 0,
|
||||||
) -> Tuple[torch.LongTensor]:
|
) -> Tuple[torch.LongTensor]:
|
||||||
batch_size = len(self._beam_hyps) // self.num_beam_groups
|
batch_size = len(self._beam_hyps) // self.num_beam_groups
|
||||||
|
|
||||||
if isinstance(eos_token_id, int):
|
if eos_token_id is not None and not isinstance(eos_token_id, torch.Tensor):
|
||||||
eos_token_id = [eos_token_id]
|
if isinstance(eos_token_id, int):
|
||||||
|
eos_token_id = [eos_token_id]
|
||||||
|
eos_token_id = torch.tensor(eos_token_id)
|
||||||
|
|
||||||
# finalize all open beam hypotheses and add to generated hypotheses
|
# finalize all open beam hypotheses and add to generated hypotheses
|
||||||
for batch_group_idx, beam_hyp in enumerate(self._beam_hyps):
|
for batch_group_idx, beam_hyp in enumerate(self._beam_hyps):
|
||||||
@@ -513,8 +517,8 @@ class ConstrainedBeamSearchScorer(BeamScorer):
|
|||||||
next_tokens: torch.LongTensor,
|
next_tokens: torch.LongTensor,
|
||||||
next_indices: torch.LongTensor,
|
next_indices: torch.LongTensor,
|
||||||
scores_for_all_vocab: torch.FloatTensor,
|
scores_for_all_vocab: torch.FloatTensor,
|
||||||
pad_token_id: Optional[int] = None,
|
pad_token_id: Optional[Union[int, torch.Tensor]] = None,
|
||||||
eos_token_id: Optional[Union[int, List[int]]] = None,
|
eos_token_id: Optional[Union[int, List[int], torch.Tensor]] = None,
|
||||||
beam_indices: Optional[torch.LongTensor] = None,
|
beam_indices: Optional[torch.LongTensor] = None,
|
||||||
decoder_prompt_len: Optional[int] = 0,
|
decoder_prompt_len: Optional[int] = 0,
|
||||||
) -> Tuple[torch.Tensor]:
|
) -> Tuple[torch.Tensor]:
|
||||||
@@ -578,8 +582,10 @@ class ConstrainedBeamSearchScorer(BeamScorer):
|
|||||||
next_beam_tokens = torch.zeros((batch_size, self.group_size), dtype=next_tokens.dtype, device=device)
|
next_beam_tokens = torch.zeros((batch_size, self.group_size), dtype=next_tokens.dtype, device=device)
|
||||||
next_beam_indices = torch.zeros((batch_size, self.group_size), dtype=next_indices.dtype, device=device)
|
next_beam_indices = torch.zeros((batch_size, self.group_size), dtype=next_indices.dtype, device=device)
|
||||||
|
|
||||||
if isinstance(eos_token_id, int):
|
if eos_token_id is not None and not isinstance(eos_token_id, torch.Tensor):
|
||||||
eos_token_id = [eos_token_id]
|
if isinstance(eos_token_id, int):
|
||||||
|
eos_token_id = [eos_token_id]
|
||||||
|
eos_token_id = torch.tensor(eos_token_id)
|
||||||
|
|
||||||
for batch_idx, beam_hyp in enumerate(self._beam_hyps):
|
for batch_idx, beam_hyp in enumerate(self._beam_hyps):
|
||||||
if self._done[batch_idx]:
|
if self._done[batch_idx]:
|
||||||
@@ -811,15 +817,17 @@ class ConstrainedBeamSearchScorer(BeamScorer):
|
|||||||
final_beam_tokens: torch.LongTensor,
|
final_beam_tokens: torch.LongTensor,
|
||||||
final_beam_indices: torch.LongTensor,
|
final_beam_indices: torch.LongTensor,
|
||||||
max_length: int,
|
max_length: int,
|
||||||
pad_token_id: Optional[int] = None,
|
pad_token_id: Optional[Union[int, torch.Tensor]] = None,
|
||||||
eos_token_id: Optional[Union[int, List[int]]] = None,
|
eos_token_id: Optional[Union[int, List[int], torch.Tensor]] = None,
|
||||||
beam_indices: Optional[torch.LongTensor] = None,
|
beam_indices: Optional[torch.LongTensor] = None,
|
||||||
decoder_prompt_len: Optional[int] = 0,
|
decoder_prompt_len: Optional[int] = 0,
|
||||||
) -> Tuple[torch.LongTensor]:
|
) -> Tuple[torch.LongTensor]:
|
||||||
batch_size = len(self._beam_hyps)
|
batch_size = len(self._beam_hyps)
|
||||||
|
|
||||||
if isinstance(eos_token_id, int):
|
if eos_token_id is not None and not isinstance(eos_token_id, torch.Tensor):
|
||||||
eos_token_id = [eos_token_id]
|
if isinstance(eos_token_id, int):
|
||||||
|
eos_token_id = [eos_token_id]
|
||||||
|
eos_token_id = torch.tensor(eos_token_id)
|
||||||
|
|
||||||
# finalize all open beam hypotheses and add to generated hypotheses
|
# finalize all open beam hypotheses and add to generated hypotheses
|
||||||
for batch_idx, beam_hyp in enumerate(self._beam_hyps):
|
for batch_idx, beam_hyp in enumerate(self._beam_hyps):
|
||||||
|
|||||||
@@ -108,8 +108,8 @@ class MinLengthLogitsProcessor(LogitsProcessor):
|
|||||||
Args:
|
Args:
|
||||||
min_length (`int`):
|
min_length (`int`):
|
||||||
The minimum length below which the score of `eos_token_id` is set to `-float("Inf")`.
|
The minimum length below which the score of `eos_token_id` is set to `-float("Inf")`.
|
||||||
eos_token_id (`Union[int, List[int]]`):
|
eos_token_id (`Union[int, List[int], torch.Tensor]`):
|
||||||
The id of the *end-of-sequence* token. Optionally, use a list to set multiple *end-of-sequence* tokens.
|
The id(s) of the *end-of-sequence* token.
|
||||||
|
|
||||||
Examples:
|
Examples:
|
||||||
|
|
||||||
@@ -137,14 +137,14 @@ class MinLengthLogitsProcessor(LogitsProcessor):
|
|||||||
```
|
```
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, min_length: int, eos_token_id: Union[int, List[int]]):
|
def __init__(self, min_length: int, eos_token_id: Union[int, List[int], torch.Tensor]):
|
||||||
if not isinstance(min_length, int) or min_length < 0:
|
if not isinstance(min_length, int) or min_length < 0:
|
||||||
raise ValueError(f"`min_length` has to be a non-negative integer, but is {min_length}")
|
raise ValueError(f"`min_length` has to be a non-negative integer, but is {min_length}")
|
||||||
|
|
||||||
if isinstance(eos_token_id, int):
|
if not isinstance(eos_token_id, torch.Tensor):
|
||||||
eos_token_id = [eos_token_id]
|
if isinstance(eos_token_id, int):
|
||||||
if not all(isinstance(i, int) for i in eos_token_id) or any(i < 0 for i in eos_token_id):
|
eos_token_id = [eos_token_id]
|
||||||
logger.warning(f"`eos_token_id` has to be a list of positive integers, but is {eos_token_id}")
|
eos_token_id = torch.tensor(eos_token_id)
|
||||||
|
|
||||||
self.min_length = min_length
|
self.min_length = min_length
|
||||||
self.eos_token_id = eos_token_id
|
self.eos_token_id = eos_token_id
|
||||||
@@ -152,8 +152,8 @@ class MinLengthLogitsProcessor(LogitsProcessor):
|
|||||||
@add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING)
|
@add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING)
|
||||||
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
|
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
|
||||||
vocab_tensor = torch.arange(scores.shape[-1], device=scores.device)
|
vocab_tensor = torch.arange(scores.shape[-1], device=scores.device)
|
||||||
eos_token_id = torch.tensor(self.eos_token_id, device=scores.device)
|
self.eos_token_id = self.eos_token_id.to(scores.device)
|
||||||
eos_token_mask = torch.isin(vocab_tensor, eos_token_id)
|
eos_token_mask = torch.isin(vocab_tensor, self.eos_token_id)
|
||||||
scores_processed = scores.clone()
|
scores_processed = scores.clone()
|
||||||
if input_ids.shape[-1] < self.min_length:
|
if input_ids.shape[-1] < self.min_length:
|
||||||
scores_processed = torch.where(eos_token_mask, -math.inf, scores)
|
scores_processed = torch.where(eos_token_mask, -math.inf, scores)
|
||||||
@@ -171,8 +171,8 @@ class MinNewTokensLengthLogitsProcessor(LogitsProcessor):
|
|||||||
input length.
|
input length.
|
||||||
min_new_tokens (`int`):
|
min_new_tokens (`int`):
|
||||||
The minimum *new* tokens length below which the score of `eos_token_id` is set to `-float("Inf")`.
|
The minimum *new* tokens length below which the score of `eos_token_id` is set to `-float("Inf")`.
|
||||||
eos_token_id (`Union[int, List[int]]`):
|
eos_token_id (`Union[int, List[int], torch.Tensor]`):
|
||||||
The id of the *end-of-sequence* token. Optionally, use a list to set multiple *end-of-sequence* tokens.
|
The id(s) of the *end-of-sequence* token.
|
||||||
|
|
||||||
Examples:
|
Examples:
|
||||||
|
|
||||||
@@ -195,7 +195,9 @@ class MinNewTokensLengthLogitsProcessor(LogitsProcessor):
|
|||||||
```
|
```
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, prompt_length_to_skip: int, min_new_tokens: int, eos_token_id: Union[int, List[int]]):
|
def __init__(
|
||||||
|
self, prompt_length_to_skip: int, min_new_tokens: int, eos_token_id: Union[int, List[int], torch.Tensor]
|
||||||
|
):
|
||||||
for arg_name, arg_value in [
|
for arg_name, arg_value in [
|
||||||
("prompt_length_to_skip", prompt_length_to_skip),
|
("prompt_length_to_skip", prompt_length_to_skip),
|
||||||
("min_new_tokens", min_new_tokens),
|
("min_new_tokens", min_new_tokens),
|
||||||
@@ -203,10 +205,10 @@ class MinNewTokensLengthLogitsProcessor(LogitsProcessor):
|
|||||||
if not isinstance(arg_value, int) or arg_value < 0:
|
if not isinstance(arg_value, int) or arg_value < 0:
|
||||||
raise ValueError(f"`{arg_name}` has to be a positive integer, but is {arg_value}")
|
raise ValueError(f"`{arg_name}` has to be a positive integer, but is {arg_value}")
|
||||||
|
|
||||||
if isinstance(eos_token_id, int):
|
if not isinstance(eos_token_id, torch.Tensor):
|
||||||
eos_token_id = [eos_token_id]
|
if isinstance(eos_token_id, int):
|
||||||
if not all(isinstance(i, int) for i in eos_token_id) or any(i < 0 for i in eos_token_id):
|
eos_token_id = [eos_token_id]
|
||||||
logger.warning(f"`eos_token_id` has to be a list of positive integers, but is {eos_token_id}")
|
eos_token_id = torch.tensor(eos_token_id)
|
||||||
|
|
||||||
self.prompt_length_to_skip = prompt_length_to_skip
|
self.prompt_length_to_skip = prompt_length_to_skip
|
||||||
self.min_new_tokens = min_new_tokens
|
self.min_new_tokens = min_new_tokens
|
||||||
@@ -217,8 +219,8 @@ class MinNewTokensLengthLogitsProcessor(LogitsProcessor):
|
|||||||
new_tokens_length = input_ids.shape[-1] - self.prompt_length_to_skip
|
new_tokens_length = input_ids.shape[-1] - self.prompt_length_to_skip
|
||||||
scores_processed = scores.clone()
|
scores_processed = scores.clone()
|
||||||
vocab_tensor = torch.arange(scores.shape[-1], device=scores.device)
|
vocab_tensor = torch.arange(scores.shape[-1], device=scores.device)
|
||||||
eos_token_id = torch.tensor(self.eos_token_id, device=scores.device)
|
self.eos_token_id = self.eos_token_id.to(scores.device)
|
||||||
eos_token_mask = torch.isin(vocab_tensor, eos_token_id)
|
eos_token_mask = torch.isin(vocab_tensor, self.eos_token_id)
|
||||||
if new_tokens_length < self.min_new_tokens:
|
if new_tokens_length < self.min_new_tokens:
|
||||||
scores_processed = torch.where(eos_token_mask, -math.inf, scores)
|
scores_processed = torch.where(eos_token_mask, -math.inf, scores)
|
||||||
|
|
||||||
@@ -1195,8 +1197,8 @@ class NoBadWordsLogitsProcessor(SequenceBiasLogitsProcessor):
|
|||||||
Args:
|
Args:
|
||||||
bad_words_ids (`List[List[int]]`):
|
bad_words_ids (`List[List[int]]`):
|
||||||
List of list of token ids that are not allowed to be generated.
|
List of list of token ids that are not allowed to be generated.
|
||||||
eos_token_id (`Union[int, List[int]]`):
|
eos_token_id (`Union[int, List[int], torch.Tensor]`, *optional*):
|
||||||
The id of the *end-of-sequence* token. Optionally, use a list to set multiple *end-of-sequence* tokens.
|
The id(s) of the *end-of-sequence* token.
|
||||||
|
|
||||||
Examples:
|
Examples:
|
||||||
|
|
||||||
@@ -1233,18 +1235,22 @@ class NoBadWordsLogitsProcessor(SequenceBiasLogitsProcessor):
|
|||||||
```
|
```
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, bad_words_ids: List[List[int]], eos_token_id: Union[int, List[int]]):
|
def __init__(
|
||||||
|
self, bad_words_ids: List[List[int]], eos_token_id: Optional[Union[int, List[int], torch.Tensor]] = None
|
||||||
|
):
|
||||||
self.bad_word_ids = bad_words_ids
|
self.bad_word_ids = bad_words_ids
|
||||||
self._validate_arguments()
|
self._validate_arguments()
|
||||||
|
|
||||||
# Filter EOS token from bad_words_ids
|
# Filter EOS token from bad_words_ids
|
||||||
if eos_token_id is None:
|
if eos_token_id is not None:
|
||||||
eos_token_id = []
|
if not isinstance(eos_token_id, torch.Tensor):
|
||||||
if isinstance(eos_token_id, int):
|
if isinstance(eos_token_id, int):
|
||||||
eos_token_id = [eos_token_id]
|
eos_token_id = [eos_token_id]
|
||||||
bad_words_ids = list(
|
eos_token_id = torch.tensor(eos_token_id)
|
||||||
filter(lambda bad_token_seq: all(bad_token_seq != [i] for i in eos_token_id), bad_words_ids)
|
|
||||||
)
|
bad_words_ids = list(
|
||||||
|
filter(lambda bad_token_seq: all(bad_token_seq != [i] for i in eos_token_id), bad_words_ids)
|
||||||
|
)
|
||||||
|
|
||||||
# Forbidding a sequence is equivalent to setting its bias to -inf
|
# Forbidding a sequence is equivalent to setting its bias to -inf
|
||||||
sequence_bias = {tuple(sequence): float("-inf") for sequence in bad_words_ids}
|
sequence_bias = {tuple(sequence): float("-inf") for sequence in bad_words_ids}
|
||||||
@@ -1522,9 +1528,8 @@ class ForcedEOSTokenLogitsProcessor(LogitsProcessor):
|
|||||||
Args:
|
Args:
|
||||||
max_length (`int`):
|
max_length (`int`):
|
||||||
The maximum length of the sequence to be generated.
|
The maximum length of the sequence to be generated.
|
||||||
eos_token_id (`Union[int, List[int]]`):
|
eos_token_id (`Union[int, List[int], torch.Tensor]`):
|
||||||
The id of the token to force as the last generated token when `max_length` is reached. Optionally, use a
|
The id(s) of the *end-of-sequence* token.
|
||||||
list to set multiple *end-of-sequence* tokens.
|
|
||||||
|
|
||||||
Examples:
|
Examples:
|
||||||
|
|
||||||
@@ -1548,15 +1553,22 @@ class ForcedEOSTokenLogitsProcessor(LogitsProcessor):
|
|||||||
```
|
```
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, max_length: int, eos_token_id: Union[int, List[int]]):
|
def __init__(self, max_length: int, eos_token_id: Union[int, List[int], torch.Tensor]):
|
||||||
self.max_length = max_length
|
self.max_length = max_length
|
||||||
if isinstance(eos_token_id, int):
|
|
||||||
eos_token_id = [eos_token_id]
|
if not isinstance(eos_token_id, torch.Tensor):
|
||||||
|
if isinstance(eos_token_id, int):
|
||||||
|
eos_token_id = [eos_token_id]
|
||||||
|
eos_token_id = torch.tensor(eos_token_id)
|
||||||
self.eos_token_id = eos_token_id
|
self.eos_token_id = eos_token_id
|
||||||
|
|
||||||
|
if torch.is_floating_point(eos_token_id) or (eos_token_id < 0).any():
|
||||||
|
raise ValueError(f"`eos_token_id` has to be a list of positive integers, but is {eos_token_id}")
|
||||||
|
|
||||||
@add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING)
|
@add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING)
|
||||||
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
|
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
|
||||||
cur_len = input_ids.shape[-1]
|
cur_len = input_ids.shape[-1]
|
||||||
|
self.eos_token_id = self.eos_token_id.to(scores.device)
|
||||||
scores_processed = scores
|
scores_processed = scores
|
||||||
if cur_len == self.max_length - 1:
|
if cur_len == self.max_length - 1:
|
||||||
scores_processed = torch.full_like(scores, -math.inf)
|
scores_processed = torch.full_like(scores, -math.inf)
|
||||||
@@ -1595,8 +1607,8 @@ class ExponentialDecayLengthPenalty(LogitsProcessor):
|
|||||||
exponential_decay_length_penalty (`tuple(int, float)`):
|
exponential_decay_length_penalty (`tuple(int, float)`):
|
||||||
This tuple shall consist of: `(start_index, decay_factor)` where `start_index` indicates where penalty
|
This tuple shall consist of: `(start_index, decay_factor)` where `start_index` indicates where penalty
|
||||||
starts and `decay_factor` represents the factor of exponential decay
|
starts and `decay_factor` represents the factor of exponential decay
|
||||||
eos_token_id (`Union[int, List[int]]`):
|
eos_token_id (`Union[int, List[int], torch.Tensor]`):
|
||||||
The id of the *end-of-sequence* token. Optionally, use a list to set multiple *end-of-sequence* tokens.
|
The id(s) of the *end-of-sequence* token.
|
||||||
input_ids_seq_length (`int`):
|
input_ids_seq_length (`int`):
|
||||||
The length of the input sequence.
|
The length of the input sequence.
|
||||||
|
|
||||||
@@ -1656,27 +1668,33 @@ class ExponentialDecayLengthPenalty(LogitsProcessor):
|
|||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
exponential_decay_length_penalty: Tuple[int, float],
|
exponential_decay_length_penalty: Tuple[int, float],
|
||||||
eos_token_id: Union[int, List[int]],
|
eos_token_id: Union[int, List[int], torch.Tensor],
|
||||||
input_ids_seq_length: int,
|
input_ids_seq_length: int,
|
||||||
):
|
):
|
||||||
self.regulation_start = exponential_decay_length_penalty[0] + input_ids_seq_length
|
self.regulation_start = exponential_decay_length_penalty[0] + input_ids_seq_length
|
||||||
self.regulation_factor = exponential_decay_length_penalty[1]
|
self.regulation_factor = exponential_decay_length_penalty[1]
|
||||||
if isinstance(eos_token_id, int):
|
|
||||||
eos_token_id = [eos_token_id]
|
if not isinstance(eos_token_id, torch.Tensor):
|
||||||
|
if isinstance(eos_token_id, int):
|
||||||
|
eos_token_id = [eos_token_id]
|
||||||
|
eos_token_id = torch.tensor(eos_token_id)
|
||||||
self.eos_token_id = eos_token_id
|
self.eos_token_id = eos_token_id
|
||||||
|
|
||||||
|
if torch.is_floating_point(eos_token_id) or (eos_token_id < 0).any():
|
||||||
|
raise ValueError(f"`eos_token_id` has to be a list of positive integers, but is {eos_token_id}")
|
||||||
|
|
||||||
@add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING)
|
@add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING)
|
||||||
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
|
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
|
||||||
cur_len = input_ids.shape[-1]
|
cur_len = input_ids.shape[-1]
|
||||||
|
self.eos_token_id = self.eos_token_id.to(scores.device)
|
||||||
penalties = torch.zeros_like(scores)
|
penalties = torch.zeros_like(scores)
|
||||||
scores_processed = scores
|
scores_processed = scores
|
||||||
if cur_len > self.regulation_start:
|
if cur_len > self.regulation_start:
|
||||||
for i in self.eos_token_id:
|
penalty_idx = cur_len - self.regulation_start
|
||||||
penalty_idx = cur_len - self.regulation_start
|
# To support negative logits we compute the penalty of the absolute value and add to the original logit
|
||||||
# To support negative logits we compute the penalty of the absolute value and add to the original logit
|
penalty = torch.abs(scores[:, self.eos_token_id]) * (pow(self.regulation_factor, penalty_idx) - 1)
|
||||||
penalty = torch.abs(scores[:, i]) * (pow(self.regulation_factor, penalty_idx) - 1)
|
penalties[:, self.eos_token_id] = penalty
|
||||||
penalties[:, i] = penalty
|
scores_processed = scores + penalties
|
||||||
scores_processed = scores + penalties
|
|
||||||
return scores_processed
|
return scores_processed
|
||||||
|
|
||||||
|
|
||||||
@@ -1753,7 +1771,7 @@ class SuppressTokensAtBeginLogitsProcessor(LogitsProcessor):
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, begin_suppress_tokens, begin_index):
|
def __init__(self, begin_suppress_tokens, begin_index):
|
||||||
self.begin_suppress_tokens = list(begin_suppress_tokens)
|
self.begin_suppress_tokens = torch.tensor(list(begin_suppress_tokens))
|
||||||
self.begin_index = begin_index
|
self.begin_index = begin_index
|
||||||
|
|
||||||
def set_begin_index(self, begin_index):
|
def set_begin_index(self, begin_index):
|
||||||
@@ -1762,8 +1780,8 @@ class SuppressTokensAtBeginLogitsProcessor(LogitsProcessor):
|
|||||||
@add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING)
|
@add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING)
|
||||||
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
|
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
|
||||||
vocab_tensor = torch.arange(scores.shape[-1], device=scores.device)
|
vocab_tensor = torch.arange(scores.shape[-1], device=scores.device)
|
||||||
begin_suppress_tokens = torch.tensor(self.begin_suppress_tokens, device=scores.device)
|
self.begin_suppress_tokens = self.begin_suppress_tokens.to(scores.device)
|
||||||
suppress_token_mask = torch.isin(vocab_tensor, begin_suppress_tokens)
|
suppress_token_mask = torch.isin(vocab_tensor, self.begin_suppress_tokens)
|
||||||
scores_processed = scores
|
scores_processed = scores
|
||||||
if input_ids.shape[-1] == self.begin_index:
|
if input_ids.shape[-1] == self.begin_index:
|
||||||
scores_processed = torch.where(suppress_token_mask, -float("inf"), scores)
|
scores_processed = torch.where(suppress_token_mask, -float("inf"), scores)
|
||||||
@@ -1801,13 +1819,13 @@ class SuppressTokensLogitsProcessor(LogitsProcessor):
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, suppress_tokens):
|
def __init__(self, suppress_tokens):
|
||||||
self.suppress_tokens = list(suppress_tokens)
|
self.suppress_tokens = torch.tensor(list(suppress_tokens))
|
||||||
|
|
||||||
@add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING)
|
@add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING)
|
||||||
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
|
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
|
||||||
vocab_tensor = torch.arange(scores.shape[-1], device=scores.device)
|
vocab_tensor = torch.arange(scores.shape[-1], device=scores.device)
|
||||||
suppress_tokens = torch.tensor(self.suppress_tokens, device=scores.device)
|
self.suppress_tokens = self.suppress_tokens.to(scores.device)
|
||||||
suppress_token_mask = torch.isin(vocab_tensor, suppress_tokens)
|
suppress_token_mask = torch.isin(vocab_tensor, self.suppress_tokens)
|
||||||
scores = torch.where(suppress_token_mask, -float("inf"), scores)
|
scores = torch.where(suppress_token_mask, -float("inf"), scores)
|
||||||
return scores
|
return scores
|
||||||
|
|
||||||
@@ -2268,16 +2286,22 @@ class BarkEosPrioritizerLogitsProcessor(LogitsProcessor):
|
|||||||
</Tip>
|
</Tip>
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
eos_token_id (`Union[int, List[int]]`):
|
eos_token_id (`Union[int, List[int], torch.Tensor]`):
|
||||||
The id of the *end-of-sequence* token. Optionally, use a list to set multiple *end-of-sequence* tokens.
|
The id(s) of the *end-of-sequence* token.
|
||||||
min_eos_p (`float`, *optional*):
|
min_eos_p (`float`, *optional*):
|
||||||
Minimum end of speech threshold.
|
Minimum end of speech threshold.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, eos_token_id: Union[int, List[int]], min_eos_p: float):
|
def __init__(self, eos_token_id: Union[int, List[int], torch.Tensor], min_eos_p: float):
|
||||||
if isinstance(eos_token_id, int):
|
if not isinstance(eos_token_id, torch.Tensor):
|
||||||
eos_token_id = [eos_token_id]
|
if isinstance(eos_token_id, int):
|
||||||
|
eos_token_id = [eos_token_id]
|
||||||
|
eos_token_id = torch.tensor(eos_token_id)
|
||||||
self.eos_token_id = eos_token_id
|
self.eos_token_id = eos_token_id
|
||||||
|
|
||||||
|
if torch.is_floating_point(eos_token_id) or (eos_token_id < 0).any():
|
||||||
|
raise ValueError(f"`eos_token_id` has to be a list of positive integers, but is {eos_token_id}")
|
||||||
|
|
||||||
if min_eos_p is not None and min_eos_p <= 0:
|
if min_eos_p is not None and min_eos_p <= 0:
|
||||||
raise ValueError(f"`min_eos_p` has to be a positive float, but is {min_eos_p}")
|
raise ValueError(f"`min_eos_p` has to be a positive float, but is {min_eos_p}")
|
||||||
self.min_eos_p = min_eos_p
|
self.min_eos_p = min_eos_p
|
||||||
@@ -2285,6 +2309,7 @@ class BarkEosPrioritizerLogitsProcessor(LogitsProcessor):
|
|||||||
@add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING)
|
@add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING)
|
||||||
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
|
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
|
||||||
scores_processed = scores
|
scores_processed = scores
|
||||||
|
self.eos_token_id = self.eos_token_id.to(scores.device)
|
||||||
if self.min_eos_p:
|
if self.min_eos_p:
|
||||||
probs = torch.nn.functional.softmax(scores.float(), dim=-1)
|
probs = torch.nn.functional.softmax(scores.float(), dim=-1)
|
||||||
# create scores full of -inf except for the eos_token_id
|
# create scores full of -inf except for the eos_token_id
|
||||||
|
|||||||
@@ -470,29 +470,32 @@ class EosTokenCriteria(StoppingCriteria):
|
|||||||
By default, it uses the `model.generation_config.eos_token_id`.
|
By default, it uses the `model.generation_config.eos_token_id`.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
eos_token_id (`Union[int, List[int]]`):
|
eos_token_id (`Union[int, List[int], torch.Tensor]`):
|
||||||
The id of the *end-of-sequence* token. Optionally, use a list to set multiple *end-of-sequence* tokens.
|
The id(s) of the *end-of-sequence* token.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, eos_token_id: Union[int, List[int]]):
|
def __init__(self, eos_token_id: Union[int, List[int], torch.Tensor]):
|
||||||
if isinstance(eos_token_id, int):
|
if not isinstance(eos_token_id, torch.Tensor):
|
||||||
eos_token_id = [eos_token_id]
|
if isinstance(eos_token_id, int):
|
||||||
self.eos_token_id = torch.tensor(eos_token_id)
|
eos_token_id = [eos_token_id]
|
||||||
|
eos_token_id = torch.tensor(eos_token_id)
|
||||||
|
self.eos_token_id = eos_token_id
|
||||||
|
|
||||||
@add_start_docstrings(STOPPING_CRITERIA_INPUTS_DOCSTRING)
|
@add_start_docstrings(STOPPING_CRITERIA_INPUTS_DOCSTRING)
|
||||||
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> torch.BoolTensor:
|
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> torch.BoolTensor:
|
||||||
|
self.eos_token_id = self.eos_token_id.to(input_ids.device)
|
||||||
if input_ids.device.type == "mps":
|
if input_ids.device.type == "mps":
|
||||||
# https://github.com/pytorch/pytorch/issues/77764#issuecomment-2067838075
|
# https://github.com/pytorch/pytorch/issues/77764#issuecomment-2067838075
|
||||||
is_done = (
|
is_done = (
|
||||||
input_ids[:, -1]
|
input_ids[:, -1]
|
||||||
.tile(self.eos_token_id.shape[0], 1)
|
.tile(self.eos_token_id.shape[0], 1)
|
||||||
.eq(self.eos_token_id.unsqueeze(1).to(input_ids.device))
|
.eq(self.eos_token_id.unsqueeze(1))
|
||||||
.sum(dim=0)
|
.sum(dim=0)
|
||||||
.bool()
|
.bool()
|
||||||
.squeeze()
|
.squeeze()
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
is_done = torch.isin(input_ids[:, -1], self.eos_token_id.to(input_ids.device))
|
is_done = torch.isin(input_ids[:, -1], self.eos_token_id)
|
||||||
return is_done
|
return is_done
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -353,7 +353,7 @@ class GenerationMixin:
|
|||||||
def _prepare_model_inputs(
|
def _prepare_model_inputs(
|
||||||
self,
|
self,
|
||||||
inputs: Optional[torch.Tensor] = None,
|
inputs: Optional[torch.Tensor] = None,
|
||||||
bos_token_id: Optional[int] = None,
|
bos_token_id: Optional[torch.Tensor] = None,
|
||||||
model_kwargs: Optional[Dict[str, torch.Tensor]] = None,
|
model_kwargs: Optional[Dict[str, torch.Tensor]] = None,
|
||||||
) -> Tuple[torch.Tensor, Optional[str], Dict[str, torch.Tensor]]:
|
) -> Tuple[torch.Tensor, Optional[str], Dict[str, torch.Tensor]]:
|
||||||
"""
|
"""
|
||||||
@@ -417,7 +417,7 @@ class GenerationMixin:
|
|||||||
def _maybe_initialize_input_ids_for_generation(
|
def _maybe_initialize_input_ids_for_generation(
|
||||||
self,
|
self,
|
||||||
inputs: Optional[torch.Tensor] = None,
|
inputs: Optional[torch.Tensor] = None,
|
||||||
bos_token_id: Optional[int] = None,
|
bos_token_id: Optional[torch.Tensor] = None,
|
||||||
model_kwargs: Optional[Dict[str, torch.Tensor]] = None,
|
model_kwargs: Optional[Dict[str, torch.Tensor]] = None,
|
||||||
) -> torch.LongTensor:
|
) -> torch.LongTensor:
|
||||||
"""Initializes input ids for generation, if necessary."""
|
"""Initializes input ids for generation, if necessary."""
|
||||||
@@ -449,20 +449,37 @@ class GenerationMixin:
|
|||||||
def _prepare_attention_mask_for_generation(
|
def _prepare_attention_mask_for_generation(
|
||||||
self,
|
self,
|
||||||
inputs: torch.Tensor,
|
inputs: torch.Tensor,
|
||||||
pad_token_id: Optional[int],
|
pad_token_id: Optional[torch.Tensor],
|
||||||
eos_token_id: Optional[Union[int, List[int]]],
|
eos_token_id: Optional[torch.Tensor],
|
||||||
) -> torch.LongTensor:
|
) -> torch.LongTensor:
|
||||||
is_input_ids = len(inputs.shape) == 2 and inputs.dtype in [torch.int, torch.long]
|
# No information for attention mask inference -> return default attention mask
|
||||||
is_pad_token_in_inputs = (pad_token_id is not None) and (pad_token_id in inputs)
|
default_attention_mask = torch.ones(inputs.shape[:2], dtype=torch.long, device=inputs.device)
|
||||||
if isinstance(eos_token_id, int):
|
if pad_token_id is None:
|
||||||
eos_token_id = [eos_token_id]
|
return default_attention_mask
|
||||||
is_pad_token_not_equal_to_eos_token_id = (eos_token_id is None) or (pad_token_id not in eos_token_id)
|
|
||||||
|
|
||||||
# Check if input is input_ids and padded -> only then is attention_mask defined
|
is_input_ids = len(inputs.shape) == 2 and inputs.dtype in [torch.int, torch.long]
|
||||||
if is_input_ids and is_pad_token_in_inputs and is_pad_token_not_equal_to_eos_token_id:
|
if not is_input_ids:
|
||||||
return inputs.ne(pad_token_id).long()
|
return default_attention_mask
|
||||||
else:
|
|
||||||
return torch.ones(inputs.shape[:2], dtype=torch.long, device=inputs.device)
|
# Otherwise we may have information -> try to infer the attention mask
|
||||||
|
if inputs.device.type == "mps":
|
||||||
|
# mps does not support torch.isin (https://github.com/pytorch/pytorch/issues/77764)
|
||||||
|
raise ValueError(
|
||||||
|
"Can't infer missing attention mask on `mps` device. Please provide an `attention_mask` or use a different device."
|
||||||
|
)
|
||||||
|
|
||||||
|
is_pad_token_in_inputs = (pad_token_id is not None) and (
|
||||||
|
torch.isin(elements=inputs, test_elements=pad_token_id).any()
|
||||||
|
)
|
||||||
|
is_pad_token_not_equal_to_eos_token_id = (eos_token_id is None) or ~(
|
||||||
|
torch.isin(elements=eos_token_id, test_elements=pad_token_id).any()
|
||||||
|
)
|
||||||
|
can_infer_attention_mask = is_pad_token_in_inputs * is_pad_token_not_equal_to_eos_token_id
|
||||||
|
attention_mask_from_padding = inputs.ne(pad_token_id).long()
|
||||||
|
attention_mask = (
|
||||||
|
attention_mask_from_padding * can_infer_attention_mask + default_attention_mask * ~can_infer_attention_mask
|
||||||
|
)
|
||||||
|
return attention_mask
|
||||||
|
|
||||||
def _prepare_encoder_decoder_kwargs_for_generation(
|
def _prepare_encoder_decoder_kwargs_for_generation(
|
||||||
self,
|
self,
|
||||||
@@ -510,8 +527,7 @@ class GenerationMixin:
|
|||||||
batch_size: int,
|
batch_size: int,
|
||||||
model_input_name: str,
|
model_input_name: str,
|
||||||
model_kwargs: Dict[str, torch.Tensor],
|
model_kwargs: Dict[str, torch.Tensor],
|
||||||
decoder_start_token_id: Union[int, List[int]] = None,
|
decoder_start_token_id: torch.Tensor,
|
||||||
bos_token_id: int = None,
|
|
||||||
device: torch.device = None,
|
device: torch.device = None,
|
||||||
) -> Tuple[torch.LongTensor, Dict[str, torch.Tensor]]:
|
) -> Tuple[torch.LongTensor, Dict[str, torch.Tensor]]:
|
||||||
"""Prepares `decoder_input_ids` for generation with encoder-decoder models"""
|
"""Prepares `decoder_input_ids` for generation with encoder-decoder models"""
|
||||||
@@ -524,25 +540,24 @@ class GenerationMixin:
|
|||||||
else:
|
else:
|
||||||
decoder_input_ids = None
|
decoder_input_ids = None
|
||||||
|
|
||||||
# 2. Encoder-decoder models expect the `decoder_input_ids` to start with a special token. Let's ensure that.
|
# 2. `decoder_start_token_id` must have shape (batch_size, 1)
|
||||||
decoder_start_token_id = self._get_decoder_start_token_id(decoder_start_token_id, bos_token_id)
|
|
||||||
if device is None:
|
if device is None:
|
||||||
device = self.device
|
device = self.device
|
||||||
if isinstance(decoder_start_token_id, list):
|
if decoder_start_token_id.ndim == 1:
|
||||||
if len(decoder_start_token_id) != batch_size:
|
if decoder_start_token_id.shape[0] != batch_size:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"`decoder_start_token_id` expcted to have length {batch_size} but got {len(decoder_start_token_id)}"
|
f"`decoder_start_token_id` expected to have length {batch_size} but got {decoder_start_token_id.shape[0]}"
|
||||||
)
|
)
|
||||||
decoder_input_ids_start = torch.tensor(decoder_start_token_id, dtype=torch.long, device=device)
|
decoder_start_token_id = decoder_start_token_id.view(-1, 1)
|
||||||
decoder_input_ids_start = decoder_input_ids_start.view(-1, 1)
|
|
||||||
else:
|
else:
|
||||||
decoder_input_ids_start = (
|
decoder_start_token_id = (
|
||||||
torch.ones((batch_size, 1), dtype=torch.long, device=device) * decoder_start_token_id
|
torch.ones((batch_size, 1), dtype=torch.long, device=device) * decoder_start_token_id
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# 3. Encoder-decoder models expect the `decoder_input_ids` to start with a special token. Let's ensure that.
|
||||||
# no user input -> use decoder_start_token_id as decoder_input_ids
|
# no user input -> use decoder_start_token_id as decoder_input_ids
|
||||||
if decoder_input_ids is None:
|
if decoder_input_ids is None:
|
||||||
decoder_input_ids = decoder_input_ids_start
|
decoder_input_ids = decoder_start_token_id
|
||||||
# exception: Donut checkpoints have task-specific decoder starts and don't expect a BOS token
|
# exception: Donut checkpoints have task-specific decoder starts and don't expect a BOS token
|
||||||
elif self.config.model_type == "vision-encoder-decoder" and "donut" in self.name_or_path.lower():
|
elif self.config.model_type == "vision-encoder-decoder" and "donut" in self.name_or_path.lower():
|
||||||
pass
|
pass
|
||||||
@@ -550,14 +565,8 @@ class GenerationMixin:
|
|||||||
pass
|
pass
|
||||||
# user input but doesn't start with decoder_start_token_id -> prepend decoder_start_token_id (and adjust
|
# user input but doesn't start with decoder_start_token_id -> prepend decoder_start_token_id (and adjust
|
||||||
# decoder_attention_mask if provided)
|
# decoder_attention_mask if provided)
|
||||||
elif (
|
elif (decoder_input_ids[:, 0] != decoder_start_token_id[:, 0]).all().item():
|
||||||
isinstance(decoder_start_token_id, int)
|
decoder_input_ids = torch.cat([decoder_start_token_id, decoder_input_ids], dim=-1)
|
||||||
and (decoder_input_ids[:, 0] != decoder_start_token_id).all().item()
|
|
||||||
) or (
|
|
||||||
isinstance(decoder_start_token_id, torch.Tensor)
|
|
||||||
and (decoder_input_ids[:, 0] != decoder_start_token_id[:, 0]).all().item()
|
|
||||||
):
|
|
||||||
decoder_input_ids = torch.cat([decoder_input_ids_start, decoder_input_ids], dim=-1)
|
|
||||||
if "decoder_attention_mask" in model_kwargs:
|
if "decoder_attention_mask" in model_kwargs:
|
||||||
decoder_attention_mask = model_kwargs["decoder_attention_mask"]
|
decoder_attention_mask = model_kwargs["decoder_attention_mask"]
|
||||||
decoder_attention_mask = torch.cat(
|
decoder_attention_mask = torch.cat(
|
||||||
@@ -568,24 +577,6 @@ class GenerationMixin:
|
|||||||
|
|
||||||
return decoder_input_ids, model_kwargs
|
return decoder_input_ids, model_kwargs
|
||||||
|
|
||||||
def _get_decoder_start_token_id(
|
|
||||||
self, decoder_start_token_id: Union[int, List[int]] = None, bos_token_id: int = None
|
|
||||||
) -> int:
|
|
||||||
decoder_start_token_id = (
|
|
||||||
decoder_start_token_id
|
|
||||||
if decoder_start_token_id is not None
|
|
||||||
else self.generation_config.decoder_start_token_id
|
|
||||||
)
|
|
||||||
bos_token_id = bos_token_id if bos_token_id is not None else self.generation_config.bos_token_id
|
|
||||||
|
|
||||||
if decoder_start_token_id is not None:
|
|
||||||
return decoder_start_token_id
|
|
||||||
elif bos_token_id is not None:
|
|
||||||
return bos_token_id
|
|
||||||
raise ValueError(
|
|
||||||
"`decoder_start_token_id` or `bos_token_id` has to be defined for encoder-decoder generation."
|
|
||||||
)
|
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _expand_inputs_for_generation(
|
def _expand_inputs_for_generation(
|
||||||
expand_size: int = 1,
|
expand_size: int = 1,
|
||||||
@@ -729,6 +720,8 @@ class GenerationMixin:
|
|||||||
if generation_config.num_beams > 1:
|
if generation_config.num_beams > 1:
|
||||||
if isinstance(generation_config.eos_token_id, list):
|
if isinstance(generation_config.eos_token_id, list):
|
||||||
min_tokens_to_keep = len(generation_config.eos_token_id) + 1
|
min_tokens_to_keep = len(generation_config.eos_token_id) + 1
|
||||||
|
elif isinstance(generation_config.eos_token_id, torch.Tensor):
|
||||||
|
min_tokens_to_keep = generation_config.eos_token_id.shape[0] + 1
|
||||||
else:
|
else:
|
||||||
min_tokens_to_keep = 2
|
min_tokens_to_keep = 2
|
||||||
else:
|
else:
|
||||||
@@ -1346,6 +1339,61 @@ class GenerationMixin:
|
|||||||
self._static_cache.reset() # reset the cache for a new generation
|
self._static_cache.reset() # reset the cache for a new generation
|
||||||
return self._static_cache
|
return self._static_cache
|
||||||
|
|
||||||
|
def _prepare_special_tokens(
|
||||||
|
self, generation_config: GenerationConfig, kwargs_has_attention_mask: Optional[bool] = None
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Prepares the special tokens for generation, overwriting the generation config with their processed versions
|
||||||
|
converted to tensor.
|
||||||
|
|
||||||
|
Note that `generation_config` is changed in place and stops being serializable after this method is called.
|
||||||
|
That is no problem if called within `generate` (`generation_config` is a local copy that doesn't leave the
|
||||||
|
function). However, if called outside `generate`, consider creating a copy of `generation_config` first.
|
||||||
|
"""
|
||||||
|
|
||||||
|
# Convert special tokens to tensors (if they exist)
|
||||||
|
def _tensor_or_none(token):
|
||||||
|
if token is None or isinstance(token, torch.Tensor):
|
||||||
|
return token
|
||||||
|
return torch.tensor(token, device=self.device, dtype=torch.long)
|
||||||
|
|
||||||
|
bos_token_id = _tensor_or_none(generation_config.bos_token_id)
|
||||||
|
eos_token_id = _tensor_or_none(generation_config.eos_token_id)
|
||||||
|
pad_token_id = _tensor_or_none(generation_config.pad_token_id)
|
||||||
|
decoder_start_token_id = _tensor_or_none(generation_config.decoder_start_token_id)
|
||||||
|
decoder_start_token_id = decoder_start_token_id if decoder_start_token_id is not None else bos_token_id
|
||||||
|
|
||||||
|
# We can have more than one eos token. Always treat it as a 1D tensor (when it exists).
|
||||||
|
if eos_token_id is not None and eos_token_id.ndim == 0:
|
||||||
|
eos_token_id = eos_token_id.unsqueeze(0)
|
||||||
|
|
||||||
|
# Set pad token if unset (and there are conditions to do so)
|
||||||
|
if pad_token_id is None and eos_token_id is not None:
|
||||||
|
if kwargs_has_attention_mask is not None and not kwargs_has_attention_mask:
|
||||||
|
logger.warning(
|
||||||
|
"The attention mask and the pad token id were not set. As a consequence, you may observe "
|
||||||
|
"unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results."
|
||||||
|
)
|
||||||
|
pad_token_id = eos_token_id[0]
|
||||||
|
logger.warning(f"Setting `pad_token_id` to `eos_token_id`:{pad_token_id} for open-end generation.")
|
||||||
|
|
||||||
|
# Sanity checks/warnings
|
||||||
|
if self.config.is_encoder_decoder and decoder_start_token_id is None:
|
||||||
|
raise ValueError(
|
||||||
|
"`decoder_start_token_id` or `bos_token_id` has to be defined for encoder-decoder generation."
|
||||||
|
)
|
||||||
|
if eos_token_id is not None and (torch.is_floating_point(eos_token_id) or (eos_token_id < 0).any()):
|
||||||
|
logger.warning(
|
||||||
|
f"`eos_token_id` should consist of positive integers, but is {eos_token_id}. Your generation will not "
|
||||||
|
"stop until the maximum length is reached. Depending on other flags, it may even crash."
|
||||||
|
)
|
||||||
|
|
||||||
|
# Update generation config with the updated special tokens tensors
|
||||||
|
generation_config.bos_token_id = bos_token_id
|
||||||
|
generation_config.eos_token_id = eos_token_id
|
||||||
|
generation_config.pad_token_id = pad_token_id
|
||||||
|
generation_config.decoder_start_token_id = decoder_start_token_id
|
||||||
|
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
def generate(
|
def generate(
|
||||||
self,
|
self,
|
||||||
@@ -1460,28 +1508,32 @@ class GenerationMixin:
|
|||||||
logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList()
|
logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList()
|
||||||
stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList()
|
stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList()
|
||||||
|
|
||||||
if generation_config.pad_token_id is None and generation_config.eos_token_id is not None:
|
accepts_attention_mask = "attention_mask" in set(inspect.signature(self.forward).parameters.keys())
|
||||||
if model_kwargs.get("attention_mask", None) is None:
|
requires_attention_mask = "encoder_outputs" not in model_kwargs
|
||||||
logger.warning(
|
kwargs_has_attention_mask = model_kwargs.get("attention_mask", None) is not None
|
||||||
"The attention mask and the pad token id were not set. As a consequence, you may observe "
|
self._prepare_special_tokens(generation_config, kwargs_has_attention_mask)
|
||||||
"unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results."
|
|
||||||
)
|
|
||||||
eos_token_id = generation_config.eos_token_id
|
|
||||||
if isinstance(eos_token_id, list):
|
|
||||||
eos_token_id = eos_token_id[0]
|
|
||||||
logger.warning(f"Setting `pad_token_id` to `eos_token_id`:{eos_token_id} for open-end generation.")
|
|
||||||
generation_config.pad_token_id = eos_token_id
|
|
||||||
|
|
||||||
# 3. Define model inputs
|
# 3. Define model inputs
|
||||||
# inputs_tensor has to be defined
|
|
||||||
# model_input_name is defined if model-specific keyword input is passed
|
|
||||||
# otherwise model_input_name is None
|
|
||||||
# all model-specific keyword inputs are removed from `model_kwargs`
|
|
||||||
inputs_tensor, model_input_name, model_kwargs = self._prepare_model_inputs(
|
inputs_tensor, model_input_name, model_kwargs = self._prepare_model_inputs(
|
||||||
inputs, generation_config.bos_token_id, model_kwargs
|
inputs, generation_config.bos_token_id, model_kwargs
|
||||||
)
|
)
|
||||||
batch_size = inputs_tensor.shape[0]
|
batch_size = inputs_tensor.shape[0]
|
||||||
|
|
||||||
|
# decoder-only models must use left-padding for batched generation.
|
||||||
|
if not self.config.is_encoder_decoder and not is_torchdynamo_compiling():
|
||||||
|
# If `input_ids` was given, check if the last id in any sequence is `pad_token_id`
|
||||||
|
# Note: If using `inputs_embeds`, this check does not work, because we want to be more hands-off.
|
||||||
|
if (
|
||||||
|
generation_config.pad_token_id is not None
|
||||||
|
and batch_size > 1
|
||||||
|
and len(inputs_tensor.shape) == 2
|
||||||
|
and torch.sum(inputs_tensor[:, -1] == generation_config.pad_token_id) > 0
|
||||||
|
):
|
||||||
|
logger.warning(
|
||||||
|
"A decoder-only architecture is being used, but right-padding was detected! For correct "
|
||||||
|
"generation results, please set `padding_side='left'` when initializing the tokenizer."
|
||||||
|
)
|
||||||
|
|
||||||
# 4. Define other model kwargs
|
# 4. Define other model kwargs
|
||||||
# decoder-only models with inputs_embeds forwarding must use caching (otherwise we can't detect whether we are
|
# decoder-only models with inputs_embeds forwarding must use caching (otherwise we can't detect whether we are
|
||||||
# generating the first new token or not, and we only want to use the embeddings for the first new token)
|
# generating the first new token or not, and we only want to use the embeddings for the first new token)
|
||||||
@@ -1490,31 +1542,13 @@ class GenerationMixin:
|
|||||||
else:
|
else:
|
||||||
model_kwargs["use_cache"] = generation_config.use_cache
|
model_kwargs["use_cache"] = generation_config.use_cache
|
||||||
|
|
||||||
accepts_attention_mask = "attention_mask" in set(inspect.signature(self.forward).parameters.keys())
|
if not kwargs_has_attention_mask and requires_attention_mask and accepts_attention_mask:
|
||||||
requires_attention_mask = "encoder_outputs" not in model_kwargs
|
|
||||||
|
|
||||||
if model_kwargs.get("attention_mask", None) is None and requires_attention_mask and accepts_attention_mask:
|
|
||||||
model_kwargs["attention_mask"] = self._prepare_attention_mask_for_generation(
|
model_kwargs["attention_mask"] = self._prepare_attention_mask_for_generation(
|
||||||
inputs_tensor, generation_config.pad_token_id, generation_config.eos_token_id
|
inputs_tensor, generation_config.pad_token_id, generation_config.eos_token_id
|
||||||
)
|
)
|
||||||
|
|
||||||
# decoder-only models should use left-padding for generation
|
|
||||||
if not self.config.is_encoder_decoder:
|
|
||||||
# If `input_ids` was given, check if the last id in any sequence is `pad_token_id`
|
|
||||||
# Note: If using, `inputs_embeds` this check does not work, because we want to be more hands-off.
|
|
||||||
if (
|
|
||||||
generation_config.pad_token_id is not None
|
|
||||||
and len(inputs_tensor.shape) == 2
|
|
||||||
and torch.sum(inputs_tensor[:, -1] == generation_config.pad_token_id) > 0
|
|
||||||
):
|
|
||||||
logger.warning(
|
|
||||||
"A decoder-only architecture is being used, but right-padding was detected! For correct "
|
|
||||||
"generation results, please set `padding_side='left'` when initializing the tokenizer."
|
|
||||||
)
|
|
||||||
|
|
||||||
if self.config.is_encoder_decoder and "encoder_outputs" not in model_kwargs:
|
if self.config.is_encoder_decoder and "encoder_outputs" not in model_kwargs:
|
||||||
# if model is encoder decoder encoder_outputs are created
|
# if model is encoder decoder encoder_outputs are created and added to `model_kwargs`
|
||||||
# and added to `model_kwargs`
|
|
||||||
model_kwargs = self._prepare_encoder_decoder_kwargs_for_generation(
|
model_kwargs = self._prepare_encoder_decoder_kwargs_for_generation(
|
||||||
inputs_tensor, model_kwargs, model_input_name, generation_config
|
inputs_tensor, model_kwargs, model_input_name, generation_config
|
||||||
)
|
)
|
||||||
@@ -1526,7 +1560,6 @@ class GenerationMixin:
|
|||||||
model_input_name=model_input_name,
|
model_input_name=model_input_name,
|
||||||
model_kwargs=model_kwargs,
|
model_kwargs=model_kwargs,
|
||||||
decoder_start_token_id=generation_config.decoder_start_token_id,
|
decoder_start_token_id=generation_config.decoder_start_token_id,
|
||||||
bos_token_id=generation_config.bos_token_id,
|
|
||||||
device=inputs_tensor.device,
|
device=inputs_tensor.device,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
@@ -2631,9 +2664,6 @@ class GenerationMixin:
|
|||||||
return_dict_in_generate = generation_config.return_dict_in_generate
|
return_dict_in_generate = generation_config.return_dict_in_generate
|
||||||
sequential = generation_config.low_memory
|
sequential = generation_config.low_memory
|
||||||
|
|
||||||
if isinstance(eos_token_id, int):
|
|
||||||
eos_token_id = [eos_token_id]
|
|
||||||
|
|
||||||
batch_size = len(beam_scorer._beam_hyps)
|
batch_size = len(beam_scorer._beam_hyps)
|
||||||
num_beams = beam_scorer.num_beams
|
num_beams = beam_scorer.num_beams
|
||||||
|
|
||||||
@@ -2757,7 +2787,7 @@ class GenerationMixin:
|
|||||||
next_token_scores = next_token_scores.view(batch_size, num_beams * vocab_size)
|
next_token_scores = next_token_scores.view(batch_size, num_beams * vocab_size)
|
||||||
|
|
||||||
# Sample 1 + len(eos_token_id) next tokens for each beam so we have at least 1 non eos token per beam.
|
# Sample 1 + len(eos_token_id) next tokens for each beam so we have at least 1 non eos token per beam.
|
||||||
n_eos_tokens = len(eos_token_id) if eos_token_id else 0
|
n_eos_tokens = eos_token_id.shape[0] if eos_token_id is not None else 0
|
||||||
next_token_scores, next_tokens = torch.topk(
|
next_token_scores, next_tokens = torch.topk(
|
||||||
next_token_scores, max(2, 1 + n_eos_tokens) * num_beams, dim=1, largest=True, sorted=True
|
next_token_scores, max(2, 1 + n_eos_tokens) * num_beams, dim=1, largest=True, sorted=True
|
||||||
)
|
)
|
||||||
@@ -2901,9 +2931,6 @@ class GenerationMixin:
|
|||||||
output_logits = generation_config.output_logits
|
output_logits = generation_config.output_logits
|
||||||
return_dict_in_generate = generation_config.return_dict_in_generate
|
return_dict_in_generate = generation_config.return_dict_in_generate
|
||||||
|
|
||||||
if isinstance(eos_token_id, int):
|
|
||||||
eos_token_id = [eos_token_id]
|
|
||||||
|
|
||||||
batch_size = len(beam_scorer._beam_hyps)
|
batch_size = len(beam_scorer._beam_hyps)
|
||||||
num_beams = beam_scorer.num_beams
|
num_beams = beam_scorer.num_beams
|
||||||
|
|
||||||
@@ -3124,9 +3151,6 @@ class GenerationMixin:
|
|||||||
output_logits = generation_config.output_logits
|
output_logits = generation_config.output_logits
|
||||||
return_dict_in_generate = generation_config.return_dict_in_generate
|
return_dict_in_generate = generation_config.return_dict_in_generate
|
||||||
|
|
||||||
if isinstance(eos_token_id, int):
|
|
||||||
eos_token_id = [eos_token_id]
|
|
||||||
|
|
||||||
num_beams = beam_scorer.num_beams
|
num_beams = beam_scorer.num_beams
|
||||||
num_beam_groups = beam_scorer.num_beam_groups
|
num_beam_groups = beam_scorer.num_beam_groups
|
||||||
num_sub_beams = num_beams // num_beam_groups
|
num_sub_beams = num_beams // num_beam_groups
|
||||||
@@ -3229,7 +3253,7 @@ class GenerationMixin:
|
|||||||
next_token_scores = next_token_scores.view(batch_size, group_size * vocab_size)
|
next_token_scores = next_token_scores.view(batch_size, group_size * vocab_size)
|
||||||
|
|
||||||
# Sample 1 + len(eos_token_id) next tokens for each beam so we have at least 1 non eos token per beam.
|
# Sample 1 + len(eos_token_id) next tokens for each beam so we have at least 1 non eos token per beam.
|
||||||
n_eos_tokens = len(eos_token_id) if eos_token_id else 0
|
n_eos_tokens = eos_token_id.shape[0] if eos_token_id is not None else 0
|
||||||
next_token_scores, next_tokens = torch.topk(
|
next_token_scores, next_tokens = torch.topk(
|
||||||
next_token_scores, max(2, 1 + n_eos_tokens) * group_size, dim=1, largest=True, sorted=True
|
next_token_scores, max(2, 1 + n_eos_tokens) * group_size, dim=1, largest=True, sorted=True
|
||||||
)
|
)
|
||||||
@@ -3409,9 +3433,6 @@ class GenerationMixin:
|
|||||||
output_logits = generation_config.output_logits
|
output_logits = generation_config.output_logits
|
||||||
return_dict_in_generate = generation_config.return_dict_in_generate
|
return_dict_in_generate = generation_config.return_dict_in_generate
|
||||||
|
|
||||||
if isinstance(eos_token_id, int):
|
|
||||||
eos_token_id = [eos_token_id]
|
|
||||||
|
|
||||||
batch_size = len(constrained_beam_scorer._beam_hyps)
|
batch_size = len(constrained_beam_scorer._beam_hyps)
|
||||||
num_beams = constrained_beam_scorer.num_beams
|
num_beams = constrained_beam_scorer.num_beams
|
||||||
|
|
||||||
@@ -3501,7 +3522,7 @@ class GenerationMixin:
|
|||||||
next_token_scores = next_token_scores.view(batch_size, num_beams * vocab_size)
|
next_token_scores = next_token_scores.view(batch_size, num_beams * vocab_size)
|
||||||
|
|
||||||
# Sample 1 + len(eos_token_id) next tokens for each beam so we have at least 1 non eos token per beam.
|
# Sample 1 + len(eos_token_id) next tokens for each beam so we have at least 1 non eos token per beam.
|
||||||
n_eos_tokens = len(eos_token_id) if eos_token_id else 0
|
n_eos_tokens = eos_token_id.shape[0] if eos_token_id is not None else 0
|
||||||
next_token_scores, next_tokens = torch.topk(
|
next_token_scores, next_tokens = torch.topk(
|
||||||
next_token_scores, max(2, 1 + n_eos_tokens) * num_beams, dim=1, largest=True, sorted=True
|
next_token_scores, max(2, 1 + n_eos_tokens) * num_beams, dim=1, largest=True, sorted=True
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -18,7 +18,7 @@ import inspect
|
|||||||
import math
|
import math
|
||||||
import random
|
import random
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple, Union
|
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
@@ -2587,6 +2587,24 @@ class MusicgenForConditionalGeneration(PreTrainedModel):
|
|||||||
break
|
break
|
||||||
return torch.ones((batch_size, 1), dtype=torch.long, device=self.device) * bos_token_id
|
return torch.ones((batch_size, 1), dtype=torch.long, device=self.device) * bos_token_id
|
||||||
|
|
||||||
|
def _get_decoder_start_token_id(
|
||||||
|
self, decoder_start_token_id: Union[int, List[int]] = None, bos_token_id: int = None
|
||||||
|
) -> int:
|
||||||
|
decoder_start_token_id = (
|
||||||
|
decoder_start_token_id
|
||||||
|
if decoder_start_token_id is not None
|
||||||
|
else self.generation_config.decoder_start_token_id
|
||||||
|
)
|
||||||
|
bos_token_id = bos_token_id if bos_token_id is not None else self.generation_config.bos_token_id
|
||||||
|
|
||||||
|
if decoder_start_token_id is not None:
|
||||||
|
return decoder_start_token_id
|
||||||
|
elif bos_token_id is not None:
|
||||||
|
return bos_token_id
|
||||||
|
raise ValueError(
|
||||||
|
"`decoder_start_token_id` or `bos_token_id` has to be defined for encoder-decoder generation."
|
||||||
|
)
|
||||||
|
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
def generate(
|
def generate(
|
||||||
self,
|
self,
|
||||||
|
|||||||
@@ -18,7 +18,7 @@ import inspect
|
|||||||
import math
|
import math
|
||||||
import random
|
import random
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple, Union
|
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
@@ -2452,6 +2452,25 @@ class MusicgenMelodyForConditionalGeneration(PreTrainedModel):
|
|||||||
param.requires_grad = False
|
param.requires_grad = False
|
||||||
self.text_encoder._requires_grad = False
|
self.text_encoder._requires_grad = False
|
||||||
|
|
||||||
|
# Copied from transformers.models.musicgen.modeling_musicgen.MusicgenForConditionalGeneration._get_decoder_start_token_id
|
||||||
|
def _get_decoder_start_token_id(
|
||||||
|
self, decoder_start_token_id: Union[int, List[int]] = None, bos_token_id: int = None
|
||||||
|
) -> int:
|
||||||
|
decoder_start_token_id = (
|
||||||
|
decoder_start_token_id
|
||||||
|
if decoder_start_token_id is not None
|
||||||
|
else self.generation_config.decoder_start_token_id
|
||||||
|
)
|
||||||
|
bos_token_id = bos_token_id if bos_token_id is not None else self.generation_config.bos_token_id
|
||||||
|
|
||||||
|
if decoder_start_token_id is not None:
|
||||||
|
return decoder_start_token_id
|
||||||
|
elif bos_token_id is not None:
|
||||||
|
return bos_token_id
|
||||||
|
raise ValueError(
|
||||||
|
"`decoder_start_token_id` or `bos_token_id` has to be defined for encoder-decoder generation."
|
||||||
|
)
|
||||||
|
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
def generate(
|
def generate(
|
||||||
self,
|
self,
|
||||||
|
|||||||
@@ -1458,6 +1458,9 @@ class RagTokenForGeneration(RagPreTrainedModel):
|
|||||||
generation_config = copy.deepcopy(generation_config)
|
generation_config = copy.deepcopy(generation_config)
|
||||||
model_kwargs = generation_config.update(**kwargs) # All unused kwargs must be model kwargs
|
model_kwargs = generation_config.update(**kwargs) # All unused kwargs must be model kwargs
|
||||||
|
|
||||||
|
kwargs_has_attention_mask = model_kwargs.get("attention_mask", None) is not None
|
||||||
|
self._prepare_special_tokens(generation_config, kwargs_has_attention_mask)
|
||||||
|
|
||||||
# set default parameters
|
# set default parameters
|
||||||
n_docs = n_docs if n_docs is not None else self.config.n_docs
|
n_docs = n_docs if n_docs is not None else self.config.n_docs
|
||||||
|
|
||||||
|
|||||||
@@ -14,6 +14,7 @@
|
|||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
|
|
||||||
|
import copy
|
||||||
import inspect
|
import inspect
|
||||||
import tempfile
|
import tempfile
|
||||||
import unittest
|
import unittest
|
||||||
@@ -168,7 +169,9 @@ class GenerationTesterMixin:
|
|||||||
encoder_outputs["last_hidden_state"] = encoder_outputs.last_hidden_state.repeat_interleave(
|
encoder_outputs["last_hidden_state"] = encoder_outputs.last_hidden_state.repeat_interleave(
|
||||||
num_interleave, dim=0
|
num_interleave, dim=0
|
||||||
)
|
)
|
||||||
input_ids = torch.zeros_like(input_ids[:, :1]) + model._get_decoder_start_token_id()
|
generation_config = copy.deepcopy(model.generation_config)
|
||||||
|
model._prepare_special_tokens(generation_config)
|
||||||
|
input_ids = torch.zeros_like(input_ids[:, :1]) + generation_config.decoder_start_token_id
|
||||||
attention_mask = None
|
attention_mask = None
|
||||||
return encoder_outputs, input_ids, attention_mask
|
return encoder_outputs, input_ids, attention_mask
|
||||||
|
|
||||||
|
|||||||
@@ -414,9 +414,11 @@ class SeamlessM4TModelWithSpeechInputTest(ModelTesterMixin, unittest.TestCase):
|
|||||||
encoder_outputs["last_hidden_state"] = encoder_outputs.last_hidden_state.repeat_interleave(
|
encoder_outputs["last_hidden_state"] = encoder_outputs.last_hidden_state.repeat_interleave(
|
||||||
num_interleave, dim=0
|
num_interleave, dim=0
|
||||||
)
|
)
|
||||||
|
generation_config = copy.deepcopy(model.generation_config)
|
||||||
|
model._prepare_special_tokens(generation_config)
|
||||||
input_ids = (
|
input_ids = (
|
||||||
torch.zeros(input_ids.shape[:2], dtype=torch.int64, layout=input_ids.layout, device=input_ids.device)
|
torch.zeros(input_ids.shape[:2], dtype=torch.int64, layout=input_ids.layout, device=input_ids.device)
|
||||||
+ model._get_decoder_start_token_id()
|
+ generation_config.decoder_start_token_id
|
||||||
)
|
)
|
||||||
attention_mask = None
|
attention_mask = None
|
||||||
return encoder_outputs, input_ids, attention_mask
|
return encoder_outputs, input_ids, attention_mask
|
||||||
|
|||||||
@@ -430,9 +430,11 @@ class SeamlessM4Tv2ModelWithSpeechInputTest(ModelTesterMixin, unittest.TestCase)
|
|||||||
encoder_outputs["last_hidden_state"] = encoder_outputs.last_hidden_state.repeat_interleave(
|
encoder_outputs["last_hidden_state"] = encoder_outputs.last_hidden_state.repeat_interleave(
|
||||||
num_interleave, dim=0
|
num_interleave, dim=0
|
||||||
)
|
)
|
||||||
|
generation_config = copy.deepcopy(model.generation_config)
|
||||||
|
model._prepare_special_tokens(generation_config)
|
||||||
input_ids = (
|
input_ids = (
|
||||||
torch.zeros(input_ids.shape[:2], dtype=torch.int64, layout=input_ids.layout, device=input_ids.device)
|
torch.zeros(input_ids.shape[:2], dtype=torch.int64, layout=input_ids.layout, device=input_ids.device)
|
||||||
+ model._get_decoder_start_token_id()
|
+ generation_config.decoder_start_token_id
|
||||||
)
|
)
|
||||||
attention_mask = None
|
attention_mask = None
|
||||||
return encoder_outputs, input_ids, attention_mask
|
return encoder_outputs, input_ids, attention_mask
|
||||||
|
|||||||
@@ -645,7 +645,9 @@ class Speech2TextModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTest
|
|||||||
num_interleave, dim=0
|
num_interleave, dim=0
|
||||||
)
|
)
|
||||||
input_ids = input_ids[:, :, 0]
|
input_ids = input_ids[:, :, 0]
|
||||||
input_ids = torch.zeros_like(input_ids[:, :1], dtype=torch.long) + model._get_decoder_start_token_id()
|
generation_config = copy.deepcopy(model.generation_config)
|
||||||
|
model._prepare_special_tokens(generation_config)
|
||||||
|
input_ids = torch.zeros_like(input_ids[:, :1]) + generation_config.decoder_start_token_id
|
||||||
attention_mask = None
|
attention_mask = None
|
||||||
return encoder_outputs, input_ids, attention_mask
|
return encoder_outputs, input_ids, attention_mask
|
||||||
|
|
||||||
|
|||||||
@@ -833,10 +833,10 @@ class WhisperModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMi
|
|||||||
encoder_outputs["last_hidden_state"] = encoder_outputs.last_hidden_state.repeat_interleave(
|
encoder_outputs["last_hidden_state"] = encoder_outputs.last_hidden_state.repeat_interleave(
|
||||||
num_interleave, dim=0
|
num_interleave, dim=0
|
||||||
)
|
)
|
||||||
|
generation_config = copy.deepcopy(model.generation_config)
|
||||||
|
model._prepare_special_tokens(generation_config)
|
||||||
input_ids = input_ids[:, :, 0]
|
input_ids = input_ids[:, :, 0]
|
||||||
input_ids = torch.zeros_like(input_ids[:, :1], dtype=torch.long) + torch.tensor(
|
input_ids = torch.zeros_like(input_ids[:, :1], dtype=torch.long) + generation_config.decoder_start_token_id
|
||||||
[model._get_decoder_start_token_id()], device=input_ids.device
|
|
||||||
)
|
|
||||||
attention_mask = None
|
attention_mask = None
|
||||||
return encoder_outputs, input_ids, attention_mask
|
return encoder_outputs, input_ids, attention_mask
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user