Pass device in Logits Processor's init (#29804)

* add device in logits processor

* remove device when not needed

* codestyle

* tests

* forgot `melody` version

* Update src/transformers/models/whisper/generation_whisper.py

Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com>

* codestyle

* updates

---------

Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com>
This commit is contained in:
Raushan Turganbay
2024-06-04 10:19:19 +05:00
committed by GitHub
parent c73ee1333d
commit 83238eeebc
7 changed files with 119 additions and 52 deletions

View File

@@ -110,6 +110,8 @@ class MinLengthLogitsProcessor(LogitsProcessor):
The minimum length below which the score of `eos_token_id` is set to `-float("Inf")`. The minimum length below which the score of `eos_token_id` is set to `-float("Inf")`.
eos_token_id (`Union[int, List[int], torch.Tensor]`): eos_token_id (`Union[int, List[int], torch.Tensor]`):
The id(s) of the *end-of-sequence* token. The id(s) of the *end-of-sequence* token.
device (`str`, *optional*, defaults to `"cpu"`):
The device to allocate the tensors.
Examples: Examples:
@@ -137,14 +139,14 @@ class MinLengthLogitsProcessor(LogitsProcessor):
``` ```
""" """
def __init__(self, min_length: int, eos_token_id: Union[int, List[int], torch.Tensor]): def __init__(self, min_length: int, eos_token_id: Union[int, List[int], torch.Tensor], device: str = "cpu"):
if not isinstance(min_length, int) or min_length < 0: if not isinstance(min_length, int) or min_length < 0:
raise ValueError(f"`min_length` has to be a non-negative integer, but is {min_length}") raise ValueError(f"`min_length` has to be a non-negative integer, but is {min_length}")
if not isinstance(eos_token_id, torch.Tensor): if not isinstance(eos_token_id, torch.Tensor):
if isinstance(eos_token_id, int): if isinstance(eos_token_id, int):
eos_token_id = [eos_token_id] eos_token_id = [eos_token_id]
eos_token_id = torch.tensor(eos_token_id) eos_token_id = torch.tensor(eos_token_id, device=device)
self.min_length = min_length self.min_length = min_length
self.eos_token_id = eos_token_id self.eos_token_id = eos_token_id
@@ -152,7 +154,6 @@ class MinLengthLogitsProcessor(LogitsProcessor):
@add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING) @add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING)
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor: def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
vocab_tensor = torch.arange(scores.shape[-1], device=scores.device) vocab_tensor = torch.arange(scores.shape[-1], device=scores.device)
self.eos_token_id = self.eos_token_id.to(scores.device)
eos_token_mask = torch.isin(vocab_tensor, self.eos_token_id) eos_token_mask = torch.isin(vocab_tensor, self.eos_token_id)
scores_processed = scores.clone() scores_processed = scores.clone()
if input_ids.shape[-1] < self.min_length: if input_ids.shape[-1] < self.min_length:
@@ -173,6 +174,8 @@ class MinNewTokensLengthLogitsProcessor(LogitsProcessor):
The minimum *new* tokens length below which the score of `eos_token_id` is set to `-float("Inf")`. The minimum *new* tokens length below which the score of `eos_token_id` is set to `-float("Inf")`.
eos_token_id (`Union[int, List[int], torch.Tensor]`): eos_token_id (`Union[int, List[int], torch.Tensor]`):
The id(s) of the *end-of-sequence* token. The id(s) of the *end-of-sequence* token.
device (`str`, *optional*, defaults to `"cpu"`):
The device to allocate the tensors.
Examples: Examples:
@@ -196,7 +199,11 @@ class MinNewTokensLengthLogitsProcessor(LogitsProcessor):
""" """
def __init__( def __init__(
self, prompt_length_to_skip: int, min_new_tokens: int, eos_token_id: Union[int, List[int], torch.Tensor] self,
prompt_length_to_skip: int,
min_new_tokens: int,
eos_token_id: Union[int, List[int], torch.Tensor],
device: str = "cpu",
): ):
for arg_name, arg_value in [ for arg_name, arg_value in [
("prompt_length_to_skip", prompt_length_to_skip), ("prompt_length_to_skip", prompt_length_to_skip),
@@ -208,7 +215,7 @@ class MinNewTokensLengthLogitsProcessor(LogitsProcessor):
if not isinstance(eos_token_id, torch.Tensor): if not isinstance(eos_token_id, torch.Tensor):
if isinstance(eos_token_id, int): if isinstance(eos_token_id, int):
eos_token_id = [eos_token_id] eos_token_id = [eos_token_id]
eos_token_id = torch.tensor(eos_token_id) eos_token_id = torch.tensor(eos_token_id, device=device)
self.prompt_length_to_skip = prompt_length_to_skip self.prompt_length_to_skip = prompt_length_to_skip
self.min_new_tokens = min_new_tokens self.min_new_tokens = min_new_tokens
@@ -219,7 +226,6 @@ class MinNewTokensLengthLogitsProcessor(LogitsProcessor):
new_tokens_length = input_ids.shape[-1] - self.prompt_length_to_skip new_tokens_length = input_ids.shape[-1] - self.prompt_length_to_skip
scores_processed = scores.clone() scores_processed = scores.clone()
vocab_tensor = torch.arange(scores.shape[-1], device=scores.device) vocab_tensor = torch.arange(scores.shape[-1], device=scores.device)
self.eos_token_id = self.eos_token_id.to(scores.device)
eos_token_mask = torch.isin(vocab_tensor, self.eos_token_id) eos_token_mask = torch.isin(vocab_tensor, self.eos_token_id)
if new_tokens_length < self.min_new_tokens: if new_tokens_length < self.min_new_tokens:
scores_processed = torch.where(eos_token_mask, -math.inf, scores) scores_processed = torch.where(eos_token_mask, -math.inf, scores)
@@ -779,6 +785,8 @@ class EtaLogitsWarper(LogitsWarper):
Specifies the minimum number of tokens that must be kept for generation, regardless of their probabilities. Specifies the minimum number of tokens that must be kept for generation, regardless of their probabilities.
For example, if `min_tokens_to_keep` is set to 1, at least one token will always be kept for generation, For example, if `min_tokens_to_keep` is set to 1, at least one token will always be kept for generation,
even if all tokens have probabilities below the cutoff `eta`. even if all tokens have probabilities below the cutoff `eta`.
device (`str`, *optional*, defaults to `"cpu"`):
The device to allocate the tensors.
Examples: Examples:
```python ```python
@@ -806,7 +814,9 @@ class EtaLogitsWarper(LogitsWarper):
``` ```
""" """
def __init__(self, epsilon: float, filter_value: float = -float("Inf"), min_tokens_to_keep: int = 1): def __init__(
self, epsilon: float, filter_value: float = -float("Inf"), min_tokens_to_keep: int = 1, device: str = "cpu"
):
epsilon = float(epsilon) epsilon = float(epsilon)
if epsilon <= 0 or epsilon >= 1: if epsilon <= 0 or epsilon >= 1:
raise ValueError(f"`eta_cutoff` has to be a float > 0 and < 1, but is {epsilon}") raise ValueError(f"`eta_cutoff` has to be a float > 0 and < 1, but is {epsilon}")
@@ -817,13 +827,12 @@ class EtaLogitsWarper(LogitsWarper):
f"`min_tokens_to_keep` has to be a strictly positive integer, but is {min_tokens_to_keep}" f"`min_tokens_to_keep` has to be a strictly positive integer, but is {min_tokens_to_keep}"
) )
self.epsilon = torch.tensor(epsilon) self.epsilon = torch.tensor(epsilon, device=device)
self.filter_value = filter_value self.filter_value = filter_value
self.min_tokens_to_keep = min_tokens_to_keep self.min_tokens_to_keep = min_tokens_to_keep
@add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING) @add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING)
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor: def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
# Calculate the adaptive cutoff
probabilities = scores.softmax(dim=-1) probabilities = scores.softmax(dim=-1)
entropy = torch.distributions.Categorical(logits=scores).entropy() entropy = torch.distributions.Categorical(logits=scores).entropy()
eta = torch.min(self.epsilon, torch.sqrt(self.epsilon) * torch.exp(-entropy))[..., None] eta = torch.min(self.epsilon, torch.sqrt(self.epsilon) * torch.exp(-entropy))[..., None]
@@ -1530,6 +1539,8 @@ class ForcedEOSTokenLogitsProcessor(LogitsProcessor):
The maximum length of the sequence to be generated. The maximum length of the sequence to be generated.
eos_token_id (`Union[int, List[int], torch.Tensor]`): eos_token_id (`Union[int, List[int], torch.Tensor]`):
The id(s) of the *end-of-sequence* token. The id(s) of the *end-of-sequence* token.
device (`str`, *optional*, defaults to `"cpu"`):
The device to allocate the tensors.
Examples: Examples:
@@ -1553,13 +1564,13 @@ class ForcedEOSTokenLogitsProcessor(LogitsProcessor):
``` ```
""" """
def __init__(self, max_length: int, eos_token_id: Union[int, List[int], torch.Tensor]): def __init__(self, max_length: int, eos_token_id: Union[int, List[int], torch.Tensor], device: str = "cpu"):
self.max_length = max_length self.max_length = max_length
if not isinstance(eos_token_id, torch.Tensor): if not isinstance(eos_token_id, torch.Tensor):
if isinstance(eos_token_id, int): if isinstance(eos_token_id, int):
eos_token_id = [eos_token_id] eos_token_id = [eos_token_id]
eos_token_id = torch.tensor(eos_token_id) eos_token_id = torch.tensor(eos_token_id, device=device)
self.eos_token_id = eos_token_id self.eos_token_id = eos_token_id
if torch.is_floating_point(eos_token_id) or (eos_token_id < 0).any(): if torch.is_floating_point(eos_token_id) or (eos_token_id < 0).any():
@@ -1568,7 +1579,6 @@ class ForcedEOSTokenLogitsProcessor(LogitsProcessor):
@add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING) @add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING)
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor: def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
cur_len = input_ids.shape[-1] cur_len = input_ids.shape[-1]
self.eos_token_id = self.eos_token_id.to(scores.device)
scores_processed = scores scores_processed = scores
if cur_len == self.max_length - 1: if cur_len == self.max_length - 1:
scores_processed = torch.full_like(scores, -math.inf) scores_processed = torch.full_like(scores, -math.inf)
@@ -1770,8 +1780,8 @@ class SuppressTokensAtBeginLogitsProcessor(LogitsProcessor):
``` ```
""" """
def __init__(self, begin_suppress_tokens, begin_index): def __init__(self, begin_suppress_tokens, begin_index, device: str = "cpu"):
self.begin_suppress_tokens = torch.tensor(list(begin_suppress_tokens)) self.begin_suppress_tokens = torch.tensor(list(begin_suppress_tokens), device=device)
self.begin_index = begin_index self.begin_index = begin_index
def set_begin_index(self, begin_index): def set_begin_index(self, begin_index):
@@ -1780,7 +1790,6 @@ class SuppressTokensAtBeginLogitsProcessor(LogitsProcessor):
@add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING) @add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING)
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor: def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
vocab_tensor = torch.arange(scores.shape[-1], device=scores.device) vocab_tensor = torch.arange(scores.shape[-1], device=scores.device)
self.begin_suppress_tokens = self.begin_suppress_tokens.to(scores.device)
suppress_token_mask = torch.isin(vocab_tensor, self.begin_suppress_tokens) suppress_token_mask = torch.isin(vocab_tensor, self.begin_suppress_tokens)
scores_processed = scores scores_processed = scores
if input_ids.shape[-1] == self.begin_index: if input_ids.shape[-1] == self.begin_index:
@@ -1818,13 +1827,12 @@ class SuppressTokensLogitsProcessor(LogitsProcessor):
``` ```
""" """
def __init__(self, suppress_tokens): def __init__(self, suppress_tokens, device: str = "cpu"):
self.suppress_tokens = torch.tensor(list(suppress_tokens)) self.suppress_tokens = torch.tensor(list(suppress_tokens), device=device)
@add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING) @add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING)
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor: def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
vocab_tensor = torch.arange(scores.shape[-1], device=scores.device) vocab_tensor = torch.arange(scores.shape[-1], device=scores.device)
self.suppress_tokens = self.suppress_tokens.to(scores.device)
suppress_token_mask = torch.isin(vocab_tensor, self.suppress_tokens) suppress_token_mask = torch.isin(vocab_tensor, self.suppress_tokens)
scores = torch.where(suppress_token_mask, -float("inf"), scores) scores = torch.where(suppress_token_mask, -float("inf"), scores)
return scores return scores
@@ -1915,7 +1923,10 @@ class WhisperTimeStampLogitsProcessor(LogitsProcessor):
""" """
def __init__( def __init__(
self, generate_config, begin_index: Optional[int] = None, _detect_timestamp_from_logprob: Optional[bool] = None self,
generate_config,
begin_index: Optional[int] = None,
_detect_timestamp_from_logprob: Optional[bool] = None,
): # support for the kwargs ): # support for the kwargs
self.no_timestamps_token_id = generate_config.no_timestamps_token_id self.no_timestamps_token_id = generate_config.no_timestamps_token_id
self.timestamp_begin = generate_config.no_timestamps_token_id + 1 self.timestamp_begin = generate_config.no_timestamps_token_id + 1
@@ -2292,11 +2303,11 @@ class BarkEosPrioritizerLogitsProcessor(LogitsProcessor):
Minimum end of speech threshold. Minimum end of speech threshold.
""" """
def __init__(self, eos_token_id: Union[int, List[int], torch.Tensor], min_eos_p: float): def __init__(self, eos_token_id: Union[int, List[int], torch.Tensor], min_eos_p: float, device: str = "cpu"):
if not isinstance(eos_token_id, torch.Tensor): if not isinstance(eos_token_id, torch.Tensor):
if isinstance(eos_token_id, int): if isinstance(eos_token_id, int):
eos_token_id = [eos_token_id] eos_token_id = [eos_token_id]
eos_token_id = torch.tensor(eos_token_id) eos_token_id = torch.tensor(eos_token_id, device=device)
self.eos_token_id = eos_token_id self.eos_token_id = eos_token_id
if torch.is_floating_point(eos_token_id) or (eos_token_id < 0).any(): if torch.is_floating_point(eos_token_id) or (eos_token_id < 0).any():
@@ -2309,7 +2320,6 @@ class BarkEosPrioritizerLogitsProcessor(LogitsProcessor):
@add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING) @add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING)
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor: def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
scores_processed = scores scores_processed = scores
self.eos_token_id = self.eos_token_id.to(scores.device)
if self.min_eos_p: if self.min_eos_p:
probs = torch.nn.functional.softmax(scores.float(), dim=-1) probs = torch.nn.functional.softmax(scores.float(), dim=-1)
# create scores full of -inf except for the eos_token_id # create scores full of -inf except for the eos_token_id

View File

@@ -723,6 +723,7 @@ class GenerationMixin:
def _get_logits_warper( def _get_logits_warper(
self, self,
generation_config: GenerationConfig, generation_config: GenerationConfig,
device: str,
) -> LogitsProcessorList: ) -> LogitsProcessorList:
""" """
This class returns a [`LogitsProcessorList`] list object that contains all relevant [`LogitsWarper`] instances This class returns a [`LogitsProcessorList`] list object that contains all relevant [`LogitsWarper`] instances
@@ -765,7 +766,9 @@ class GenerationMixin:
) )
if generation_config.eta_cutoff is not None and 0.0 < generation_config.eta_cutoff < 1.0: if generation_config.eta_cutoff is not None and 0.0 < generation_config.eta_cutoff < 1.0:
warpers.append( warpers.append(
EtaLogitsWarper(epsilon=generation_config.eta_cutoff, min_tokens_to_keep=min_tokens_to_keep) EtaLogitsWarper(
epsilon=generation_config.eta_cutoff, min_tokens_to_keep=min_tokens_to_keep, device=device
)
) )
# `LogitNormalization` should always be the last logit processor, when present # `LogitNormalization` should always be the last logit processor, when present
if generation_config.renormalize_logits is True: if generation_config.renormalize_logits is True:
@@ -818,7 +821,8 @@ class GenerationMixin:
): ):
processors.append( processors.append(
EncoderRepetitionPenaltyLogitsProcessor( EncoderRepetitionPenaltyLogitsProcessor(
penalty=generation_config.encoder_repetition_penalty, encoder_input_ids=encoder_input_ids penalty=generation_config.encoder_repetition_penalty,
encoder_input_ids=encoder_input_ids,
) )
) )
if generation_config.repetition_penalty is not None and generation_config.repetition_penalty != 1.0: if generation_config.repetition_penalty is not None and generation_config.repetition_penalty != 1.0:
@@ -830,18 +834,30 @@ class GenerationMixin:
and generation_config.encoder_no_repeat_ngram_size > 0 and generation_config.encoder_no_repeat_ngram_size > 0
): ):
processors.append( processors.append(
EncoderNoRepeatNGramLogitsProcessor(generation_config.encoder_no_repeat_ngram_size, encoder_input_ids) EncoderNoRepeatNGramLogitsProcessor(
generation_config.encoder_no_repeat_ngram_size,
encoder_input_ids,
)
) )
if generation_config.bad_words_ids is not None: if generation_config.bad_words_ids is not None:
processors.append( processors.append(
NoBadWordsLogitsProcessor(generation_config.bad_words_ids, generation_config.eos_token_id) NoBadWordsLogitsProcessor(
generation_config.bad_words_ids,
generation_config.eos_token_id,
)
) )
if ( if (
generation_config.min_length is not None generation_config.min_length is not None
and generation_config.eos_token_id is not None and generation_config.eos_token_id is not None
and generation_config.min_length > 0 and generation_config.min_length > 0
): ):
processors.append(MinLengthLogitsProcessor(generation_config.min_length, generation_config.eos_token_id)) processors.append(
MinLengthLogitsProcessor(
generation_config.min_length,
generation_config.eos_token_id,
device=device,
)
)
if ( if (
generation_config.min_new_tokens is not None generation_config.min_new_tokens is not None
and generation_config.eos_token_id is not None and generation_config.eos_token_id is not None
@@ -849,20 +865,32 @@ class GenerationMixin:
): ):
processors.append( processors.append(
MinNewTokensLengthLogitsProcessor( MinNewTokensLengthLogitsProcessor(
input_ids_seq_length, generation_config.min_new_tokens, generation_config.eos_token_id input_ids_seq_length,
generation_config.min_new_tokens,
generation_config.eos_token_id,
device=device,
) )
) )
if prefix_allowed_tokens_fn is not None: if prefix_allowed_tokens_fn is not None:
processors.append( processors.append(
PrefixConstrainedLogitsProcessor( PrefixConstrainedLogitsProcessor(
prefix_allowed_tokens_fn, generation_config.num_beams // generation_config.num_beam_groups prefix_allowed_tokens_fn,
generation_config.num_beams // generation_config.num_beam_groups,
) )
) )
if generation_config.forced_bos_token_id is not None: if generation_config.forced_bos_token_id is not None:
processors.append(ForcedBOSTokenLogitsProcessor(generation_config.forced_bos_token_id)) processors.append(
ForcedBOSTokenLogitsProcessor(
generation_config.forced_bos_token_id,
)
)
if generation_config.forced_eos_token_id is not None: if generation_config.forced_eos_token_id is not None:
processors.append( processors.append(
ForcedEOSTokenLogitsProcessor(generation_config.max_length, generation_config.forced_eos_token_id) ForcedEOSTokenLogitsProcessor(
generation_config.max_length,
generation_config.forced_eos_token_id,
device=device,
)
) )
if generation_config.remove_invalid_values is True: if generation_config.remove_invalid_values is True:
processors.append(InfNanRemoveLogitsProcessor()) processors.append(InfNanRemoveLogitsProcessor())
@@ -875,7 +903,12 @@ class GenerationMixin:
) )
) )
if generation_config.suppress_tokens is not None: if generation_config.suppress_tokens is not None:
processors.append(SuppressTokensLogitsProcessor(generation_config.suppress_tokens)) processors.append(
SuppressTokensLogitsProcessor(
generation_config.suppress_tokens,
device=device,
)
)
if generation_config.begin_suppress_tokens is not None: if generation_config.begin_suppress_tokens is not None:
begin_index = input_ids_seq_length begin_index = input_ids_seq_length
begin_index = ( begin_index = (
@@ -887,7 +920,11 @@ class GenerationMixin:
# generation starts after the last token that is forced # generation starts after the last token that is forced
begin_index += generation_config.forced_decoder_ids[-1][0] begin_index += generation_config.forced_decoder_ids[-1][0]
processors.append( processors.append(
SuppressTokensAtBeginLogitsProcessor(generation_config.begin_suppress_tokens, begin_index) SuppressTokensAtBeginLogitsProcessor(
generation_config.begin_suppress_tokens,
begin_index,
device=device,
)
) )
if generation_config.forced_decoder_ids is not None: if generation_config.forced_decoder_ids is not None:
# TODO(Sanchit): deprecate in v4.40 by removing this logic # TODO(Sanchit): deprecate in v4.40 by removing this logic
@@ -1779,7 +1816,12 @@ class GenerationMixin:
# 12. prepare logits warper (if `do_sample` is `True`) # 12. prepare logits warper (if `do_sample` is `True`)
prepared_logits_warper = ( prepared_logits_warper = (
self._get_logits_warper(generation_config) if generation_config.do_sample else None self._get_logits_warper(
generation_config,
device=input_ids.device,
)
if generation_config.do_sample
else None
) )
# 13. run assisted generate # 13. run assisted generate
@@ -1812,7 +1854,9 @@ class GenerationMixin:
elif generation_mode in (GenerationMode.SAMPLE, GenerationMode.GREEDY_SEARCH): elif generation_mode in (GenerationMode.SAMPLE, GenerationMode.GREEDY_SEARCH):
# 11. prepare logits warper # 11. prepare logits warper
prepared_logits_warper = ( prepared_logits_warper = (
self._get_logits_warper(generation_config) if generation_config.do_sample else None self._get_logits_warper(generation_config, device=input_ids.device)
if generation_config.do_sample
else None
) )
# 12. expand input_ids with `num_return_sequences` additional sequences per batch # 12. expand input_ids with `num_return_sequences` additional sequences per batch
@@ -1838,7 +1882,9 @@ class GenerationMixin:
elif generation_mode in (GenerationMode.BEAM_SAMPLE, GenerationMode.BEAM_SEARCH): elif generation_mode in (GenerationMode.BEAM_SAMPLE, GenerationMode.BEAM_SEARCH):
# 11. prepare logits warper # 11. prepare logits warper
prepared_logits_warper = ( prepared_logits_warper = (
self._get_logits_warper(generation_config) if generation_config.do_sample else None self._get_logits_warper(generation_config, device=input_ids.device)
if generation_config.do_sample
else None
) )
# 12. prepare beam search scorer # 12. prepare beam search scorer

View File

@@ -1729,6 +1729,7 @@ class MusicgenForCausalLM(MusicgenPreTrainedModel):
encoder_input_ids=input_ids, encoder_input_ids=input_ids,
prefix_allowed_tokens_fn=None, prefix_allowed_tokens_fn=None,
logits_processor=logits_processor, logits_processor=logits_processor,
device=input_ids.device,
) )
# 10. prepare stopping criteria # 10. prepare stopping criteria
@@ -1756,7 +1757,7 @@ class MusicgenForCausalLM(MusicgenPreTrainedModel):
elif is_sample_gen_mode: elif is_sample_gen_mode:
# 11. prepare logits warper # 11. prepare logits warper
logits_warper = self._get_logits_warper(generation_config) logits_warper = self._get_logits_warper(generation_config, device=input_ids.device)
# expand input_ids with `num_return_sequences` additional sequences per batch # expand input_ids with `num_return_sequences` additional sequences per batch
input_ids, model_kwargs = self._expand_inputs_for_generation( input_ids, model_kwargs = self._expand_inputs_for_generation(
@@ -2822,6 +2823,7 @@ class MusicgenForConditionalGeneration(PreTrainedModel):
encoder_input_ids=inputs_tensor, encoder_input_ids=inputs_tensor,
prefix_allowed_tokens_fn=None, prefix_allowed_tokens_fn=None,
logits_processor=logits_processor, logits_processor=logits_processor,
device=input_ids.device,
) )
# 10. prepare stopping criteria # 10. prepare stopping criteria
@@ -2849,7 +2851,7 @@ class MusicgenForConditionalGeneration(PreTrainedModel):
elif is_sample_gen_mode: elif is_sample_gen_mode:
# 11. prepare logits warper # 11. prepare logits warper
logits_warper = self._get_logits_warper(generation_config) logits_warper = self._get_logits_warper(generation_config, device=input_ids.device)
# expand input_ids with `num_return_sequences` additional sequences per batch # expand input_ids with `num_return_sequences` additional sequences per batch
input_ids, model_kwargs = self._expand_inputs_for_generation( input_ids, model_kwargs = self._expand_inputs_for_generation(

View File

@@ -1666,6 +1666,7 @@ class MusicgenMelodyForCausalLM(MusicgenMelodyPreTrainedModel):
encoder_input_ids=input_ids, encoder_input_ids=input_ids,
prefix_allowed_tokens_fn=None, prefix_allowed_tokens_fn=None,
logits_processor=logits_processor, logits_processor=logits_processor,
device=input_ids.device,
) )
# 10. prepare stopping criteria # 10. prepare stopping criteria
@@ -1693,7 +1694,7 @@ class MusicgenMelodyForCausalLM(MusicgenMelodyPreTrainedModel):
elif is_sample_gen_mode: elif is_sample_gen_mode:
# 11. prepare logits warper # 11. prepare logits warper
logits_warper = self._get_logits_warper(generation_config) logits_warper = self._get_logits_warper(generation_config, device=input_ids.device)
# expand input_ids with `num_return_sequences` additional sequences per batch # expand input_ids with `num_return_sequences` additional sequences per batch
input_ids, model_kwargs = self._expand_inputs_for_generation( input_ids, model_kwargs = self._expand_inputs_for_generation(
@@ -2681,6 +2682,7 @@ class MusicgenMelodyForConditionalGeneration(PreTrainedModel):
encoder_input_ids=inputs_tensor, encoder_input_ids=inputs_tensor,
prefix_allowed_tokens_fn=None, prefix_allowed_tokens_fn=None,
logits_processor=logits_processor, logits_processor=logits_processor,
device=input_ids.device,
) )
# 10. prepare stopping criteria # 10. prepare stopping criteria
@@ -2708,7 +2710,7 @@ class MusicgenMelodyForConditionalGeneration(PreTrainedModel):
elif is_sample_gen_mode: elif is_sample_gen_mode:
# 11. prepare logits warper # 11. prepare logits warper
logits_warper = self._get_logits_warper(generation_config) logits_warper = self._get_logits_warper(generation_config, device=input_ids.device)
# expand input_ids with `num_return_sequences` additional sequences per batch # expand input_ids with `num_return_sequences` additional sequences per batch
input_ids, model_kwargs = self._expand_inputs_for_generation( input_ids, model_kwargs = self._expand_inputs_for_generation(

View File

@@ -1538,6 +1538,7 @@ class RagTokenForGeneration(RagPreTrainedModel):
encoder_input_ids=context_input_ids, encoder_input_ids=context_input_ids,
prefix_allowed_tokens_fn=prefix_allowed_tokens_fn, prefix_allowed_tokens_fn=prefix_allowed_tokens_fn,
logits_processor=logits_processor, logits_processor=logits_processor,
device=input_ids.device,
) )
prepared_stopping_criteria = self._get_stopping_criteria( prepared_stopping_criteria = self._get_stopping_criteria(

View File

@@ -548,13 +548,15 @@ class WhisperGenerationMixin:
self._check_decoder_input_ids(kwargs=kwargs) self._check_decoder_input_ids(kwargs=kwargs)
# 3. Retrieve logits processors # 3. Retrieve logits processors
device = kwargs["encoder_outputs"][0].device if "encoder_outputs" in kwargs else input_features.device
begin_index = init_tokens.shape[1] begin_index = init_tokens.shape[1]
logits_processor = self._retrieve_logit_processors( logits_processor = self._retrieve_logit_processors(
generation_config=generation_config, generation_config=generation_config,
logits_processor=logits_processor, logits_processor=logits_processor,
begin_index=begin_index, # begin index is index of first generated decoder token begin_index=begin_index, # begin index is index of first generated decoder token
is_shortform=is_shortform, is_shortform=is_shortform,
num_beams=generation_config.num_beams, num_beams=kwargs.get("num_beams", 1),
device=device,
) )
# 5. If we're in shortform mode, simple generate the whole input at once and return the output # 5. If we're in shortform mode, simple generate the whole input at once and return the output
@@ -1400,7 +1402,9 @@ class WhisperGenerationMixin:
return max_frames, seek return max_frames, seek
def _retrieve_logit_processors(self, generation_config, logits_processor, begin_index, is_shortform, num_beams): def _retrieve_logit_processors(
self, generation_config, logits_processor, begin_index, is_shortform, num_beams, device
):
if generation_config.return_timestamps is True: if generation_config.return_timestamps is True:
timestamp_processor = WhisperTimeStampLogitsProcessor(generation_config, begin_index=begin_index) timestamp_processor = WhisperTimeStampLogitsProcessor(generation_config, begin_index=begin_index)
logits_processor = ( logits_processor = (
@@ -1408,7 +1412,7 @@ class WhisperGenerationMixin:
) )
if generation_config.suppress_tokens is not None: if generation_config.suppress_tokens is not None:
suppress_tokens_processor = SuppressTokensLogitsProcessor(generation_config.suppress_tokens) suppress_tokens_processor = SuppressTokensLogitsProcessor(generation_config.suppress_tokens, device=device)
logits_processor = ( logits_processor = (
[suppress_tokens_processor] [suppress_tokens_processor]
if logits_processor is None if logits_processor is None
@@ -1418,7 +1422,7 @@ class WhisperGenerationMixin:
if generation_config.begin_suppress_tokens is not None: if generation_config.begin_suppress_tokens is not None:
begin_suppress_processor = SuppressTokensAtBeginLogitsProcessor( begin_suppress_processor = SuppressTokensAtBeginLogitsProcessor(
generation_config.begin_suppress_tokens, begin_index=begin_index generation_config.begin_suppress_tokens, begin_index=begin_index, device=device
) )
logits_processor = ( logits_processor = (
[begin_suppress_processor] [begin_suppress_processor]

View File

@@ -69,7 +69,7 @@ class LogitsProcessorTest(unittest.TestCase):
batch_size = 4 batch_size = 4
eos_token_id = 0 eos_token_id = 0
min_dist_processor = MinLengthLogitsProcessor(min_length=10, eos_token_id=eos_token_id) min_dist_processor = MinLengthLogitsProcessor(min_length=10, eos_token_id=eos_token_id, device=torch_device)
# check that min length is applied at length 5 # check that min length is applied at length 5
input_ids = ids_tensor((batch_size, 5), vocab_size=20) input_ids = ids_tensor((batch_size, 5), vocab_size=20)
@@ -91,7 +91,7 @@ class LogitsProcessorTest(unittest.TestCase):
# check that first input is skipped (min new length applying) # check that first input is skipped (min new length applying)
input_ids = ids_tensor((batch_size, 5), vocab_size=20) input_ids = ids_tensor((batch_size, 5), vocab_size=20)
new_min_dist_processor = MinNewTokensLengthLogitsProcessor( new_min_dist_processor = MinNewTokensLengthLogitsProcessor(
prompt_length_to_skip=input_ids.shape[-1], min_new_tokens=3, eos_token_id=eos_token_id prompt_length_to_skip=input_ids.shape[-1], min_new_tokens=3, eos_token_id=eos_token_id, device=torch_device
) )
expected_eos_scores_before_min_length = batch_size * [-float("inf")] expected_eos_scores_before_min_length = batch_size * [-float("inf")]
@@ -450,7 +450,7 @@ class LogitsProcessorTest(unittest.TestCase):
torch.tensor([[0.0, 0.1, 0.8, 0.1], [0.01, 0.04, 0.9, 0.05]], device=torch_device, dtype=torch.float) torch.tensor([[0.0, 0.1, 0.8, 0.1], [0.01, 0.04, 0.9, 0.05]], device=torch_device, dtype=torch.float)
) )
eta_warp = EtaLogitsWarper(0.0625) eta_warp = EtaLogitsWarper(0.0625, device=torch_device)
filtered_dist = torch.exp(eta_warp(input_ids, dist)) filtered_dist = torch.exp(eta_warp(input_ids, dist))
# dist should be filtered to only keep values with proba >= min(0.0625, sqrt(0.0625) * e^-H(p)) # dist should be filtered to only keep values with proba >= min(0.0625, sqrt(0.0625) * e^-H(p))
@@ -474,7 +474,7 @@ class LogitsProcessorTest(unittest.TestCase):
ramp_logits[1] = ramp_logits[1] * 100.0 ramp_logits[1] = ramp_logits[1] * 100.0
# make sure at least 2 tokens are kept # make sure at least 2 tokens are kept
eta_warp = EtaLogitsWarper(0.1, min_tokens_to_keep=2, filter_value=0.0) eta_warp = EtaLogitsWarper(0.1, min_tokens_to_keep=2, filter_value=0.0, device=torch_device)
filtered_dist = eta_warp(input_ids, ramp_logits) filtered_dist = eta_warp(input_ids, ramp_logits)
# first batch should keep 2 tokens, second batch would keep only 1, but due to `min_tokens_to_keep=2` keeps 2. # first batch should keep 2 tokens, second batch would keep only 1, but due to `min_tokens_to_keep=2` keeps 2.
@@ -640,7 +640,7 @@ class LogitsProcessorTest(unittest.TestCase):
scores_comp = scores.clone() scores_comp = scores.clone()
# instantiate all dist processors # instantiate all dist processors
min_dist_proc = MinLengthLogitsProcessor(min_length=10, eos_token_id=eos_token_id) min_dist_proc = MinLengthLogitsProcessor(min_length=10, eos_token_id=eos_token_id, device=torch_device)
temp_dist_warp = TemperatureLogitsWarper(temperature=0.5) temp_dist_warp = TemperatureLogitsWarper(temperature=0.5)
rep_penalty_proc = RepetitionPenaltyLogitsProcessor(penalty=2.0) rep_penalty_proc = RepetitionPenaltyLogitsProcessor(penalty=2.0)
top_k_warp = TopKLogitsWarper(3) top_k_warp = TopKLogitsWarper(3)
@@ -767,7 +767,9 @@ class LogitsProcessorTest(unittest.TestCase):
eos_token_id = 0 eos_token_id = 0
max_length = 5 max_length = 5
logits_processor = ForcedEOSTokenLogitsProcessor(max_length=max_length, eos_token_id=eos_token_id) logits_processor = ForcedEOSTokenLogitsProcessor(
max_length=max_length, eos_token_id=eos_token_id, device=torch_device
)
# check that all scores are -inf except the eos_token_id when max_length-1 is reached # check that all scores are -inf except the eos_token_id when max_length-1 is reached
input_ids = ids_tensor((batch_size, 4), vocab_size=20) input_ids = ids_tensor((batch_size, 4), vocab_size=20)
@@ -927,7 +929,7 @@ class LogitsProcessorTest(unittest.TestCase):
scores = self._get_uniform_logits(2, 4) scores = self._get_uniform_logits(2, 4)
scores[0][eos_token_id] = -6 ## less than log(min_eos_p) scores[0][eos_token_id] = -6 ## less than log(min_eos_p)
esp = BarkEosPrioritizerLogitsProcessor(eos_token_id=eos_token_id, min_eos_p=min_eos_p) esp = BarkEosPrioritizerLogitsProcessor(eos_token_id=eos_token_id, min_eos_p=min_eos_p, device=torch_device)
actual_scores = esp(input_ids, scores) actual_scores = esp(input_ids, scores)
expected_scores_list = [ expected_scores_list = [
scores[0].tolist(), scores[0].tolist(),
@@ -943,7 +945,7 @@ class LogitsProcessorTest(unittest.TestCase):
scores = self._get_uniform_logits(2, 4) scores = self._get_uniform_logits(2, 4)
scores[0][eos_token_id] = -6 ## less than log(min_eos_p) scores[0][eos_token_id] = -6 ## less than log(min_eos_p)
esp = BarkEosPrioritizerLogitsProcessor(eos_token_id=eos_token_id, min_eos_p=min_eos_p) esp = BarkEosPrioritizerLogitsProcessor(eos_token_id=eos_token_id, min_eos_p=min_eos_p, device=torch_device)
actual_scores = esp(input_ids, scores) actual_scores = esp(input_ids, scores)
expected_scores_list = [ expected_scores_list = [
scores[0].tolist(), scores[0].tolist(),