Pass device in Logits Processor's init (#29804)
* add device in logits processor * remove device when not needed * codestyle * tests * forgot `melody` version * Update src/transformers/models/whisper/generation_whisper.py Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com> * codestyle * updates --------- Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com>
This commit is contained in:
committed by
GitHub
parent
c73ee1333d
commit
83238eeebc
@@ -110,6 +110,8 @@ class MinLengthLogitsProcessor(LogitsProcessor):
|
||||
The minimum length below which the score of `eos_token_id` is set to `-float("Inf")`.
|
||||
eos_token_id (`Union[int, List[int], torch.Tensor]`):
|
||||
The id(s) of the *end-of-sequence* token.
|
||||
device (`str`, *optional*, defaults to `"cpu"`):
|
||||
The device to allocate the tensors.
|
||||
|
||||
Examples:
|
||||
|
||||
@@ -137,14 +139,14 @@ class MinLengthLogitsProcessor(LogitsProcessor):
|
||||
```
|
||||
"""
|
||||
|
||||
def __init__(self, min_length: int, eos_token_id: Union[int, List[int], torch.Tensor]):
|
||||
def __init__(self, min_length: int, eos_token_id: Union[int, List[int], torch.Tensor], device: str = "cpu"):
|
||||
if not isinstance(min_length, int) or min_length < 0:
|
||||
raise ValueError(f"`min_length` has to be a non-negative integer, but is {min_length}")
|
||||
|
||||
if not isinstance(eos_token_id, torch.Tensor):
|
||||
if isinstance(eos_token_id, int):
|
||||
eos_token_id = [eos_token_id]
|
||||
eos_token_id = torch.tensor(eos_token_id)
|
||||
eos_token_id = torch.tensor(eos_token_id, device=device)
|
||||
|
||||
self.min_length = min_length
|
||||
self.eos_token_id = eos_token_id
|
||||
@@ -152,7 +154,6 @@ class MinLengthLogitsProcessor(LogitsProcessor):
|
||||
@add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING)
|
||||
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
|
||||
vocab_tensor = torch.arange(scores.shape[-1], device=scores.device)
|
||||
self.eos_token_id = self.eos_token_id.to(scores.device)
|
||||
eos_token_mask = torch.isin(vocab_tensor, self.eos_token_id)
|
||||
scores_processed = scores.clone()
|
||||
if input_ids.shape[-1] < self.min_length:
|
||||
@@ -173,6 +174,8 @@ class MinNewTokensLengthLogitsProcessor(LogitsProcessor):
|
||||
The minimum *new* tokens length below which the score of `eos_token_id` is set to `-float("Inf")`.
|
||||
eos_token_id (`Union[int, List[int], torch.Tensor]`):
|
||||
The id(s) of the *end-of-sequence* token.
|
||||
device (`str`, *optional*, defaults to `"cpu"`):
|
||||
The device to allocate the tensors.
|
||||
|
||||
Examples:
|
||||
|
||||
@@ -196,7 +199,11 @@ class MinNewTokensLengthLogitsProcessor(LogitsProcessor):
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self, prompt_length_to_skip: int, min_new_tokens: int, eos_token_id: Union[int, List[int], torch.Tensor]
|
||||
self,
|
||||
prompt_length_to_skip: int,
|
||||
min_new_tokens: int,
|
||||
eos_token_id: Union[int, List[int], torch.Tensor],
|
||||
device: str = "cpu",
|
||||
):
|
||||
for arg_name, arg_value in [
|
||||
("prompt_length_to_skip", prompt_length_to_skip),
|
||||
@@ -208,7 +215,7 @@ class MinNewTokensLengthLogitsProcessor(LogitsProcessor):
|
||||
if not isinstance(eos_token_id, torch.Tensor):
|
||||
if isinstance(eos_token_id, int):
|
||||
eos_token_id = [eos_token_id]
|
||||
eos_token_id = torch.tensor(eos_token_id)
|
||||
eos_token_id = torch.tensor(eos_token_id, device=device)
|
||||
|
||||
self.prompt_length_to_skip = prompt_length_to_skip
|
||||
self.min_new_tokens = min_new_tokens
|
||||
@@ -219,7 +226,6 @@ class MinNewTokensLengthLogitsProcessor(LogitsProcessor):
|
||||
new_tokens_length = input_ids.shape[-1] - self.prompt_length_to_skip
|
||||
scores_processed = scores.clone()
|
||||
vocab_tensor = torch.arange(scores.shape[-1], device=scores.device)
|
||||
self.eos_token_id = self.eos_token_id.to(scores.device)
|
||||
eos_token_mask = torch.isin(vocab_tensor, self.eos_token_id)
|
||||
if new_tokens_length < self.min_new_tokens:
|
||||
scores_processed = torch.where(eos_token_mask, -math.inf, scores)
|
||||
@@ -779,6 +785,8 @@ class EtaLogitsWarper(LogitsWarper):
|
||||
Specifies the minimum number of tokens that must be kept for generation, regardless of their probabilities.
|
||||
For example, if `min_tokens_to_keep` is set to 1, at least one token will always be kept for generation,
|
||||
even if all tokens have probabilities below the cutoff `eta`.
|
||||
device (`str`, *optional*, defaults to `"cpu"`):
|
||||
The device to allocate the tensors.
|
||||
|
||||
Examples:
|
||||
```python
|
||||
@@ -806,7 +814,9 @@ class EtaLogitsWarper(LogitsWarper):
|
||||
```
|
||||
"""
|
||||
|
||||
def __init__(self, epsilon: float, filter_value: float = -float("Inf"), min_tokens_to_keep: int = 1):
|
||||
def __init__(
|
||||
self, epsilon: float, filter_value: float = -float("Inf"), min_tokens_to_keep: int = 1, device: str = "cpu"
|
||||
):
|
||||
epsilon = float(epsilon)
|
||||
if epsilon <= 0 or epsilon >= 1:
|
||||
raise ValueError(f"`eta_cutoff` has to be a float > 0 and < 1, but is {epsilon}")
|
||||
@@ -817,13 +827,12 @@ class EtaLogitsWarper(LogitsWarper):
|
||||
f"`min_tokens_to_keep` has to be a strictly positive integer, but is {min_tokens_to_keep}"
|
||||
)
|
||||
|
||||
self.epsilon = torch.tensor(epsilon)
|
||||
self.epsilon = torch.tensor(epsilon, device=device)
|
||||
self.filter_value = filter_value
|
||||
self.min_tokens_to_keep = min_tokens_to_keep
|
||||
|
||||
@add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING)
|
||||
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
|
||||
# Calculate the adaptive cutoff
|
||||
probabilities = scores.softmax(dim=-1)
|
||||
entropy = torch.distributions.Categorical(logits=scores).entropy()
|
||||
eta = torch.min(self.epsilon, torch.sqrt(self.epsilon) * torch.exp(-entropy))[..., None]
|
||||
@@ -1530,6 +1539,8 @@ class ForcedEOSTokenLogitsProcessor(LogitsProcessor):
|
||||
The maximum length of the sequence to be generated.
|
||||
eos_token_id (`Union[int, List[int], torch.Tensor]`):
|
||||
The id(s) of the *end-of-sequence* token.
|
||||
device (`str`, *optional*, defaults to `"cpu"`):
|
||||
The device to allocate the tensors.
|
||||
|
||||
Examples:
|
||||
|
||||
@@ -1553,13 +1564,13 @@ class ForcedEOSTokenLogitsProcessor(LogitsProcessor):
|
||||
```
|
||||
"""
|
||||
|
||||
def __init__(self, max_length: int, eos_token_id: Union[int, List[int], torch.Tensor]):
|
||||
def __init__(self, max_length: int, eos_token_id: Union[int, List[int], torch.Tensor], device: str = "cpu"):
|
||||
self.max_length = max_length
|
||||
|
||||
if not isinstance(eos_token_id, torch.Tensor):
|
||||
if isinstance(eos_token_id, int):
|
||||
eos_token_id = [eos_token_id]
|
||||
eos_token_id = torch.tensor(eos_token_id)
|
||||
eos_token_id = torch.tensor(eos_token_id, device=device)
|
||||
self.eos_token_id = eos_token_id
|
||||
|
||||
if torch.is_floating_point(eos_token_id) or (eos_token_id < 0).any():
|
||||
@@ -1568,7 +1579,6 @@ class ForcedEOSTokenLogitsProcessor(LogitsProcessor):
|
||||
@add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING)
|
||||
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
|
||||
cur_len = input_ids.shape[-1]
|
||||
self.eos_token_id = self.eos_token_id.to(scores.device)
|
||||
scores_processed = scores
|
||||
if cur_len == self.max_length - 1:
|
||||
scores_processed = torch.full_like(scores, -math.inf)
|
||||
@@ -1770,8 +1780,8 @@ class SuppressTokensAtBeginLogitsProcessor(LogitsProcessor):
|
||||
```
|
||||
"""
|
||||
|
||||
def __init__(self, begin_suppress_tokens, begin_index):
|
||||
self.begin_suppress_tokens = torch.tensor(list(begin_suppress_tokens))
|
||||
def __init__(self, begin_suppress_tokens, begin_index, device: str = "cpu"):
|
||||
self.begin_suppress_tokens = torch.tensor(list(begin_suppress_tokens), device=device)
|
||||
self.begin_index = begin_index
|
||||
|
||||
def set_begin_index(self, begin_index):
|
||||
@@ -1780,7 +1790,6 @@ class SuppressTokensAtBeginLogitsProcessor(LogitsProcessor):
|
||||
@add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING)
|
||||
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
|
||||
vocab_tensor = torch.arange(scores.shape[-1], device=scores.device)
|
||||
self.begin_suppress_tokens = self.begin_suppress_tokens.to(scores.device)
|
||||
suppress_token_mask = torch.isin(vocab_tensor, self.begin_suppress_tokens)
|
||||
scores_processed = scores
|
||||
if input_ids.shape[-1] == self.begin_index:
|
||||
@@ -1818,13 +1827,12 @@ class SuppressTokensLogitsProcessor(LogitsProcessor):
|
||||
```
|
||||
"""
|
||||
|
||||
def __init__(self, suppress_tokens):
|
||||
self.suppress_tokens = torch.tensor(list(suppress_tokens))
|
||||
def __init__(self, suppress_tokens, device: str = "cpu"):
|
||||
self.suppress_tokens = torch.tensor(list(suppress_tokens), device=device)
|
||||
|
||||
@add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING)
|
||||
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
|
||||
vocab_tensor = torch.arange(scores.shape[-1], device=scores.device)
|
||||
self.suppress_tokens = self.suppress_tokens.to(scores.device)
|
||||
suppress_token_mask = torch.isin(vocab_tensor, self.suppress_tokens)
|
||||
scores = torch.where(suppress_token_mask, -float("inf"), scores)
|
||||
return scores
|
||||
@@ -1915,7 +1923,10 @@ class WhisperTimeStampLogitsProcessor(LogitsProcessor):
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self, generate_config, begin_index: Optional[int] = None, _detect_timestamp_from_logprob: Optional[bool] = None
|
||||
self,
|
||||
generate_config,
|
||||
begin_index: Optional[int] = None,
|
||||
_detect_timestamp_from_logprob: Optional[bool] = None,
|
||||
): # support for the kwargs
|
||||
self.no_timestamps_token_id = generate_config.no_timestamps_token_id
|
||||
self.timestamp_begin = generate_config.no_timestamps_token_id + 1
|
||||
@@ -2292,11 +2303,11 @@ class BarkEosPrioritizerLogitsProcessor(LogitsProcessor):
|
||||
Minimum end of speech threshold.
|
||||
"""
|
||||
|
||||
def __init__(self, eos_token_id: Union[int, List[int], torch.Tensor], min_eos_p: float):
|
||||
def __init__(self, eos_token_id: Union[int, List[int], torch.Tensor], min_eos_p: float, device: str = "cpu"):
|
||||
if not isinstance(eos_token_id, torch.Tensor):
|
||||
if isinstance(eos_token_id, int):
|
||||
eos_token_id = [eos_token_id]
|
||||
eos_token_id = torch.tensor(eos_token_id)
|
||||
eos_token_id = torch.tensor(eos_token_id, device=device)
|
||||
self.eos_token_id = eos_token_id
|
||||
|
||||
if torch.is_floating_point(eos_token_id) or (eos_token_id < 0).any():
|
||||
@@ -2309,7 +2320,6 @@ class BarkEosPrioritizerLogitsProcessor(LogitsProcessor):
|
||||
@add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING)
|
||||
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
|
||||
scores_processed = scores
|
||||
self.eos_token_id = self.eos_token_id.to(scores.device)
|
||||
if self.min_eos_p:
|
||||
probs = torch.nn.functional.softmax(scores.float(), dim=-1)
|
||||
# create scores full of -inf except for the eos_token_id
|
||||
|
||||
@@ -723,6 +723,7 @@ class GenerationMixin:
|
||||
def _get_logits_warper(
|
||||
self,
|
||||
generation_config: GenerationConfig,
|
||||
device: str,
|
||||
) -> LogitsProcessorList:
|
||||
"""
|
||||
This class returns a [`LogitsProcessorList`] list object that contains all relevant [`LogitsWarper`] instances
|
||||
@@ -765,7 +766,9 @@ class GenerationMixin:
|
||||
)
|
||||
if generation_config.eta_cutoff is not None and 0.0 < generation_config.eta_cutoff < 1.0:
|
||||
warpers.append(
|
||||
EtaLogitsWarper(epsilon=generation_config.eta_cutoff, min_tokens_to_keep=min_tokens_to_keep)
|
||||
EtaLogitsWarper(
|
||||
epsilon=generation_config.eta_cutoff, min_tokens_to_keep=min_tokens_to_keep, device=device
|
||||
)
|
||||
)
|
||||
# `LogitNormalization` should always be the last logit processor, when present
|
||||
if generation_config.renormalize_logits is True:
|
||||
@@ -818,7 +821,8 @@ class GenerationMixin:
|
||||
):
|
||||
processors.append(
|
||||
EncoderRepetitionPenaltyLogitsProcessor(
|
||||
penalty=generation_config.encoder_repetition_penalty, encoder_input_ids=encoder_input_ids
|
||||
penalty=generation_config.encoder_repetition_penalty,
|
||||
encoder_input_ids=encoder_input_ids,
|
||||
)
|
||||
)
|
||||
if generation_config.repetition_penalty is not None and generation_config.repetition_penalty != 1.0:
|
||||
@@ -830,18 +834,30 @@ class GenerationMixin:
|
||||
and generation_config.encoder_no_repeat_ngram_size > 0
|
||||
):
|
||||
processors.append(
|
||||
EncoderNoRepeatNGramLogitsProcessor(generation_config.encoder_no_repeat_ngram_size, encoder_input_ids)
|
||||
EncoderNoRepeatNGramLogitsProcessor(
|
||||
generation_config.encoder_no_repeat_ngram_size,
|
||||
encoder_input_ids,
|
||||
)
|
||||
)
|
||||
if generation_config.bad_words_ids is not None:
|
||||
processors.append(
|
||||
NoBadWordsLogitsProcessor(generation_config.bad_words_ids, generation_config.eos_token_id)
|
||||
NoBadWordsLogitsProcessor(
|
||||
generation_config.bad_words_ids,
|
||||
generation_config.eos_token_id,
|
||||
)
|
||||
)
|
||||
if (
|
||||
generation_config.min_length is not None
|
||||
and generation_config.eos_token_id is not None
|
||||
and generation_config.min_length > 0
|
||||
):
|
||||
processors.append(MinLengthLogitsProcessor(generation_config.min_length, generation_config.eos_token_id))
|
||||
processors.append(
|
||||
MinLengthLogitsProcessor(
|
||||
generation_config.min_length,
|
||||
generation_config.eos_token_id,
|
||||
device=device,
|
||||
)
|
||||
)
|
||||
if (
|
||||
generation_config.min_new_tokens is not None
|
||||
and generation_config.eos_token_id is not None
|
||||
@@ -849,20 +865,32 @@ class GenerationMixin:
|
||||
):
|
||||
processors.append(
|
||||
MinNewTokensLengthLogitsProcessor(
|
||||
input_ids_seq_length, generation_config.min_new_tokens, generation_config.eos_token_id
|
||||
input_ids_seq_length,
|
||||
generation_config.min_new_tokens,
|
||||
generation_config.eos_token_id,
|
||||
device=device,
|
||||
)
|
||||
)
|
||||
if prefix_allowed_tokens_fn is not None:
|
||||
processors.append(
|
||||
PrefixConstrainedLogitsProcessor(
|
||||
prefix_allowed_tokens_fn, generation_config.num_beams // generation_config.num_beam_groups
|
||||
prefix_allowed_tokens_fn,
|
||||
generation_config.num_beams // generation_config.num_beam_groups,
|
||||
)
|
||||
)
|
||||
if generation_config.forced_bos_token_id is not None:
|
||||
processors.append(ForcedBOSTokenLogitsProcessor(generation_config.forced_bos_token_id))
|
||||
processors.append(
|
||||
ForcedBOSTokenLogitsProcessor(
|
||||
generation_config.forced_bos_token_id,
|
||||
)
|
||||
)
|
||||
if generation_config.forced_eos_token_id is not None:
|
||||
processors.append(
|
||||
ForcedEOSTokenLogitsProcessor(generation_config.max_length, generation_config.forced_eos_token_id)
|
||||
ForcedEOSTokenLogitsProcessor(
|
||||
generation_config.max_length,
|
||||
generation_config.forced_eos_token_id,
|
||||
device=device,
|
||||
)
|
||||
)
|
||||
if generation_config.remove_invalid_values is True:
|
||||
processors.append(InfNanRemoveLogitsProcessor())
|
||||
@@ -875,7 +903,12 @@ class GenerationMixin:
|
||||
)
|
||||
)
|
||||
if generation_config.suppress_tokens is not None:
|
||||
processors.append(SuppressTokensLogitsProcessor(generation_config.suppress_tokens))
|
||||
processors.append(
|
||||
SuppressTokensLogitsProcessor(
|
||||
generation_config.suppress_tokens,
|
||||
device=device,
|
||||
)
|
||||
)
|
||||
if generation_config.begin_suppress_tokens is not None:
|
||||
begin_index = input_ids_seq_length
|
||||
begin_index = (
|
||||
@@ -887,7 +920,11 @@ class GenerationMixin:
|
||||
# generation starts after the last token that is forced
|
||||
begin_index += generation_config.forced_decoder_ids[-1][0]
|
||||
processors.append(
|
||||
SuppressTokensAtBeginLogitsProcessor(generation_config.begin_suppress_tokens, begin_index)
|
||||
SuppressTokensAtBeginLogitsProcessor(
|
||||
generation_config.begin_suppress_tokens,
|
||||
begin_index,
|
||||
device=device,
|
||||
)
|
||||
)
|
||||
if generation_config.forced_decoder_ids is not None:
|
||||
# TODO(Sanchit): deprecate in v4.40 by removing this logic
|
||||
@@ -1779,7 +1816,12 @@ class GenerationMixin:
|
||||
|
||||
# 12. prepare logits warper (if `do_sample` is `True`)
|
||||
prepared_logits_warper = (
|
||||
self._get_logits_warper(generation_config) if generation_config.do_sample else None
|
||||
self._get_logits_warper(
|
||||
generation_config,
|
||||
device=input_ids.device,
|
||||
)
|
||||
if generation_config.do_sample
|
||||
else None
|
||||
)
|
||||
|
||||
# 13. run assisted generate
|
||||
@@ -1812,7 +1854,9 @@ class GenerationMixin:
|
||||
elif generation_mode in (GenerationMode.SAMPLE, GenerationMode.GREEDY_SEARCH):
|
||||
# 11. prepare logits warper
|
||||
prepared_logits_warper = (
|
||||
self._get_logits_warper(generation_config) if generation_config.do_sample else None
|
||||
self._get_logits_warper(generation_config, device=input_ids.device)
|
||||
if generation_config.do_sample
|
||||
else None
|
||||
)
|
||||
|
||||
# 12. expand input_ids with `num_return_sequences` additional sequences per batch
|
||||
@@ -1838,7 +1882,9 @@ class GenerationMixin:
|
||||
elif generation_mode in (GenerationMode.BEAM_SAMPLE, GenerationMode.BEAM_SEARCH):
|
||||
# 11. prepare logits warper
|
||||
prepared_logits_warper = (
|
||||
self._get_logits_warper(generation_config) if generation_config.do_sample else None
|
||||
self._get_logits_warper(generation_config, device=input_ids.device)
|
||||
if generation_config.do_sample
|
||||
else None
|
||||
)
|
||||
|
||||
# 12. prepare beam search scorer
|
||||
|
||||
@@ -1729,6 +1729,7 @@ class MusicgenForCausalLM(MusicgenPreTrainedModel):
|
||||
encoder_input_ids=input_ids,
|
||||
prefix_allowed_tokens_fn=None,
|
||||
logits_processor=logits_processor,
|
||||
device=input_ids.device,
|
||||
)
|
||||
|
||||
# 10. prepare stopping criteria
|
||||
@@ -1756,7 +1757,7 @@ class MusicgenForCausalLM(MusicgenPreTrainedModel):
|
||||
|
||||
elif is_sample_gen_mode:
|
||||
# 11. prepare logits warper
|
||||
logits_warper = self._get_logits_warper(generation_config)
|
||||
logits_warper = self._get_logits_warper(generation_config, device=input_ids.device)
|
||||
|
||||
# expand input_ids with `num_return_sequences` additional sequences per batch
|
||||
input_ids, model_kwargs = self._expand_inputs_for_generation(
|
||||
@@ -2822,6 +2823,7 @@ class MusicgenForConditionalGeneration(PreTrainedModel):
|
||||
encoder_input_ids=inputs_tensor,
|
||||
prefix_allowed_tokens_fn=None,
|
||||
logits_processor=logits_processor,
|
||||
device=input_ids.device,
|
||||
)
|
||||
|
||||
# 10. prepare stopping criteria
|
||||
@@ -2849,7 +2851,7 @@ class MusicgenForConditionalGeneration(PreTrainedModel):
|
||||
|
||||
elif is_sample_gen_mode:
|
||||
# 11. prepare logits warper
|
||||
logits_warper = self._get_logits_warper(generation_config)
|
||||
logits_warper = self._get_logits_warper(generation_config, device=input_ids.device)
|
||||
|
||||
# expand input_ids with `num_return_sequences` additional sequences per batch
|
||||
input_ids, model_kwargs = self._expand_inputs_for_generation(
|
||||
|
||||
@@ -1666,6 +1666,7 @@ class MusicgenMelodyForCausalLM(MusicgenMelodyPreTrainedModel):
|
||||
encoder_input_ids=input_ids,
|
||||
prefix_allowed_tokens_fn=None,
|
||||
logits_processor=logits_processor,
|
||||
device=input_ids.device,
|
||||
)
|
||||
|
||||
# 10. prepare stopping criteria
|
||||
@@ -1693,7 +1694,7 @@ class MusicgenMelodyForCausalLM(MusicgenMelodyPreTrainedModel):
|
||||
|
||||
elif is_sample_gen_mode:
|
||||
# 11. prepare logits warper
|
||||
logits_warper = self._get_logits_warper(generation_config)
|
||||
logits_warper = self._get_logits_warper(generation_config, device=input_ids.device)
|
||||
|
||||
# expand input_ids with `num_return_sequences` additional sequences per batch
|
||||
input_ids, model_kwargs = self._expand_inputs_for_generation(
|
||||
@@ -2681,6 +2682,7 @@ class MusicgenMelodyForConditionalGeneration(PreTrainedModel):
|
||||
encoder_input_ids=inputs_tensor,
|
||||
prefix_allowed_tokens_fn=None,
|
||||
logits_processor=logits_processor,
|
||||
device=input_ids.device,
|
||||
)
|
||||
|
||||
# 10. prepare stopping criteria
|
||||
@@ -2708,7 +2710,7 @@ class MusicgenMelodyForConditionalGeneration(PreTrainedModel):
|
||||
|
||||
elif is_sample_gen_mode:
|
||||
# 11. prepare logits warper
|
||||
logits_warper = self._get_logits_warper(generation_config)
|
||||
logits_warper = self._get_logits_warper(generation_config, device=input_ids.device)
|
||||
|
||||
# expand input_ids with `num_return_sequences` additional sequences per batch
|
||||
input_ids, model_kwargs = self._expand_inputs_for_generation(
|
||||
|
||||
@@ -1538,6 +1538,7 @@ class RagTokenForGeneration(RagPreTrainedModel):
|
||||
encoder_input_ids=context_input_ids,
|
||||
prefix_allowed_tokens_fn=prefix_allowed_tokens_fn,
|
||||
logits_processor=logits_processor,
|
||||
device=input_ids.device,
|
||||
)
|
||||
|
||||
prepared_stopping_criteria = self._get_stopping_criteria(
|
||||
|
||||
@@ -548,13 +548,15 @@ class WhisperGenerationMixin:
|
||||
self._check_decoder_input_ids(kwargs=kwargs)
|
||||
|
||||
# 3. Retrieve logits processors
|
||||
device = kwargs["encoder_outputs"][0].device if "encoder_outputs" in kwargs else input_features.device
|
||||
begin_index = init_tokens.shape[1]
|
||||
logits_processor = self._retrieve_logit_processors(
|
||||
generation_config=generation_config,
|
||||
logits_processor=logits_processor,
|
||||
begin_index=begin_index, # begin index is index of first generated decoder token
|
||||
is_shortform=is_shortform,
|
||||
num_beams=generation_config.num_beams,
|
||||
num_beams=kwargs.get("num_beams", 1),
|
||||
device=device,
|
||||
)
|
||||
|
||||
# 5. If we're in shortform mode, simple generate the whole input at once and return the output
|
||||
@@ -1400,7 +1402,9 @@ class WhisperGenerationMixin:
|
||||
|
||||
return max_frames, seek
|
||||
|
||||
def _retrieve_logit_processors(self, generation_config, logits_processor, begin_index, is_shortform, num_beams):
|
||||
def _retrieve_logit_processors(
|
||||
self, generation_config, logits_processor, begin_index, is_shortform, num_beams, device
|
||||
):
|
||||
if generation_config.return_timestamps is True:
|
||||
timestamp_processor = WhisperTimeStampLogitsProcessor(generation_config, begin_index=begin_index)
|
||||
logits_processor = (
|
||||
@@ -1408,7 +1412,7 @@ class WhisperGenerationMixin:
|
||||
)
|
||||
|
||||
if generation_config.suppress_tokens is not None:
|
||||
suppress_tokens_processor = SuppressTokensLogitsProcessor(generation_config.suppress_tokens)
|
||||
suppress_tokens_processor = SuppressTokensLogitsProcessor(generation_config.suppress_tokens, device=device)
|
||||
logits_processor = (
|
||||
[suppress_tokens_processor]
|
||||
if logits_processor is None
|
||||
@@ -1418,7 +1422,7 @@ class WhisperGenerationMixin:
|
||||
|
||||
if generation_config.begin_suppress_tokens is not None:
|
||||
begin_suppress_processor = SuppressTokensAtBeginLogitsProcessor(
|
||||
generation_config.begin_suppress_tokens, begin_index=begin_index
|
||||
generation_config.begin_suppress_tokens, begin_index=begin_index, device=device
|
||||
)
|
||||
logits_processor = (
|
||||
[begin_suppress_processor]
|
||||
|
||||
@@ -69,7 +69,7 @@ class LogitsProcessorTest(unittest.TestCase):
|
||||
batch_size = 4
|
||||
eos_token_id = 0
|
||||
|
||||
min_dist_processor = MinLengthLogitsProcessor(min_length=10, eos_token_id=eos_token_id)
|
||||
min_dist_processor = MinLengthLogitsProcessor(min_length=10, eos_token_id=eos_token_id, device=torch_device)
|
||||
|
||||
# check that min length is applied at length 5
|
||||
input_ids = ids_tensor((batch_size, 5), vocab_size=20)
|
||||
@@ -91,7 +91,7 @@ class LogitsProcessorTest(unittest.TestCase):
|
||||
# check that first input is skipped (min new length applying)
|
||||
input_ids = ids_tensor((batch_size, 5), vocab_size=20)
|
||||
new_min_dist_processor = MinNewTokensLengthLogitsProcessor(
|
||||
prompt_length_to_skip=input_ids.shape[-1], min_new_tokens=3, eos_token_id=eos_token_id
|
||||
prompt_length_to_skip=input_ids.shape[-1], min_new_tokens=3, eos_token_id=eos_token_id, device=torch_device
|
||||
)
|
||||
|
||||
expected_eos_scores_before_min_length = batch_size * [-float("inf")]
|
||||
@@ -450,7 +450,7 @@ class LogitsProcessorTest(unittest.TestCase):
|
||||
torch.tensor([[0.0, 0.1, 0.8, 0.1], [0.01, 0.04, 0.9, 0.05]], device=torch_device, dtype=torch.float)
|
||||
)
|
||||
|
||||
eta_warp = EtaLogitsWarper(0.0625)
|
||||
eta_warp = EtaLogitsWarper(0.0625, device=torch_device)
|
||||
filtered_dist = torch.exp(eta_warp(input_ids, dist))
|
||||
|
||||
# dist should be filtered to only keep values with proba >= min(0.0625, sqrt(0.0625) * e^-H(p))
|
||||
@@ -474,7 +474,7 @@ class LogitsProcessorTest(unittest.TestCase):
|
||||
ramp_logits[1] = ramp_logits[1] * 100.0
|
||||
|
||||
# make sure at least 2 tokens are kept
|
||||
eta_warp = EtaLogitsWarper(0.1, min_tokens_to_keep=2, filter_value=0.0)
|
||||
eta_warp = EtaLogitsWarper(0.1, min_tokens_to_keep=2, filter_value=0.0, device=torch_device)
|
||||
filtered_dist = eta_warp(input_ids, ramp_logits)
|
||||
|
||||
# first batch should keep 2 tokens, second batch would keep only 1, but due to `min_tokens_to_keep=2` keeps 2.
|
||||
@@ -640,7 +640,7 @@ class LogitsProcessorTest(unittest.TestCase):
|
||||
scores_comp = scores.clone()
|
||||
|
||||
# instantiate all dist processors
|
||||
min_dist_proc = MinLengthLogitsProcessor(min_length=10, eos_token_id=eos_token_id)
|
||||
min_dist_proc = MinLengthLogitsProcessor(min_length=10, eos_token_id=eos_token_id, device=torch_device)
|
||||
temp_dist_warp = TemperatureLogitsWarper(temperature=0.5)
|
||||
rep_penalty_proc = RepetitionPenaltyLogitsProcessor(penalty=2.0)
|
||||
top_k_warp = TopKLogitsWarper(3)
|
||||
@@ -767,7 +767,9 @@ class LogitsProcessorTest(unittest.TestCase):
|
||||
eos_token_id = 0
|
||||
max_length = 5
|
||||
|
||||
logits_processor = ForcedEOSTokenLogitsProcessor(max_length=max_length, eos_token_id=eos_token_id)
|
||||
logits_processor = ForcedEOSTokenLogitsProcessor(
|
||||
max_length=max_length, eos_token_id=eos_token_id, device=torch_device
|
||||
)
|
||||
|
||||
# check that all scores are -inf except the eos_token_id when max_length-1 is reached
|
||||
input_ids = ids_tensor((batch_size, 4), vocab_size=20)
|
||||
@@ -927,7 +929,7 @@ class LogitsProcessorTest(unittest.TestCase):
|
||||
scores = self._get_uniform_logits(2, 4)
|
||||
scores[0][eos_token_id] = -6 ## less than log(min_eos_p)
|
||||
|
||||
esp = BarkEosPrioritizerLogitsProcessor(eos_token_id=eos_token_id, min_eos_p=min_eos_p)
|
||||
esp = BarkEosPrioritizerLogitsProcessor(eos_token_id=eos_token_id, min_eos_p=min_eos_p, device=torch_device)
|
||||
actual_scores = esp(input_ids, scores)
|
||||
expected_scores_list = [
|
||||
scores[0].tolist(),
|
||||
@@ -943,7 +945,7 @@ class LogitsProcessorTest(unittest.TestCase):
|
||||
scores = self._get_uniform_logits(2, 4)
|
||||
scores[0][eos_token_id] = -6 ## less than log(min_eos_p)
|
||||
|
||||
esp = BarkEosPrioritizerLogitsProcessor(eos_token_id=eos_token_id, min_eos_p=min_eos_p)
|
||||
esp = BarkEosPrioritizerLogitsProcessor(eos_token_id=eos_token_id, min_eos_p=min_eos_p, device=torch_device)
|
||||
actual_scores = esp(input_ids, scores)
|
||||
expected_scores_list = [
|
||||
scores[0].tolist(),
|
||||
|
||||
Reference in New Issue
Block a user