From 83238eeebc77e5a533803d72066f988f4707b9d5 Mon Sep 17 00:00:00 2001 From: Raushan Turganbay Date: Tue, 4 Jun 2024 10:19:19 +0500 Subject: [PATCH] Pass device in Logits Processor's init (#29804) * add device in logits processor * remove device when not needed * codestyle * tests * forgot `melody` version * Update src/transformers/models/whisper/generation_whisper.py Co-authored-by: Joao Gante * codestyle * updates --------- Co-authored-by: Joao Gante --- src/transformers/generation/logits_process.py | 54 ++++++++------ src/transformers/generation/utils.py | 74 +++++++++++++++---- .../models/musicgen/modeling_musicgen.py | 6 +- .../modeling_musicgen_melody.py | 6 +- src/transformers/models/rag/modeling_rag.py | 1 + .../models/whisper/generation_whisper.py | 12 ++- tests/generation/test_logits_process.py | 18 +++-- 7 files changed, 119 insertions(+), 52 deletions(-) diff --git a/src/transformers/generation/logits_process.py b/src/transformers/generation/logits_process.py index d870446504..b226a059d1 100644 --- a/src/transformers/generation/logits_process.py +++ b/src/transformers/generation/logits_process.py @@ -110,6 +110,8 @@ class MinLengthLogitsProcessor(LogitsProcessor): The minimum length below which the score of `eos_token_id` is set to `-float("Inf")`. eos_token_id (`Union[int, List[int], torch.Tensor]`): The id(s) of the *end-of-sequence* token. + device (`str`, *optional*, defaults to `"cpu"`): + The device to allocate the tensors. Examples: @@ -137,14 +139,14 @@ class MinLengthLogitsProcessor(LogitsProcessor): ``` """ - def __init__(self, min_length: int, eos_token_id: Union[int, List[int], torch.Tensor]): + def __init__(self, min_length: int, eos_token_id: Union[int, List[int], torch.Tensor], device: str = "cpu"): if not isinstance(min_length, int) or min_length < 0: raise ValueError(f"`min_length` has to be a non-negative integer, but is {min_length}") if not isinstance(eos_token_id, torch.Tensor): if isinstance(eos_token_id, int): eos_token_id = [eos_token_id] - eos_token_id = torch.tensor(eos_token_id) + eos_token_id = torch.tensor(eos_token_id, device=device) self.min_length = min_length self.eos_token_id = eos_token_id @@ -152,7 +154,6 @@ class MinLengthLogitsProcessor(LogitsProcessor): @add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING) def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor: vocab_tensor = torch.arange(scores.shape[-1], device=scores.device) - self.eos_token_id = self.eos_token_id.to(scores.device) eos_token_mask = torch.isin(vocab_tensor, self.eos_token_id) scores_processed = scores.clone() if input_ids.shape[-1] < self.min_length: @@ -173,6 +174,8 @@ class MinNewTokensLengthLogitsProcessor(LogitsProcessor): The minimum *new* tokens length below which the score of `eos_token_id` is set to `-float("Inf")`. eos_token_id (`Union[int, List[int], torch.Tensor]`): The id(s) of the *end-of-sequence* token. + device (`str`, *optional*, defaults to `"cpu"`): + The device to allocate the tensors. Examples: @@ -196,7 +199,11 @@ class MinNewTokensLengthLogitsProcessor(LogitsProcessor): """ def __init__( - self, prompt_length_to_skip: int, min_new_tokens: int, eos_token_id: Union[int, List[int], torch.Tensor] + self, + prompt_length_to_skip: int, + min_new_tokens: int, + eos_token_id: Union[int, List[int], torch.Tensor], + device: str = "cpu", ): for arg_name, arg_value in [ ("prompt_length_to_skip", prompt_length_to_skip), @@ -208,7 +215,7 @@ class MinNewTokensLengthLogitsProcessor(LogitsProcessor): if not isinstance(eos_token_id, torch.Tensor): if isinstance(eos_token_id, int): eos_token_id = [eos_token_id] - eos_token_id = torch.tensor(eos_token_id) + eos_token_id = torch.tensor(eos_token_id, device=device) self.prompt_length_to_skip = prompt_length_to_skip self.min_new_tokens = min_new_tokens @@ -219,7 +226,6 @@ class MinNewTokensLengthLogitsProcessor(LogitsProcessor): new_tokens_length = input_ids.shape[-1] - self.prompt_length_to_skip scores_processed = scores.clone() vocab_tensor = torch.arange(scores.shape[-1], device=scores.device) - self.eos_token_id = self.eos_token_id.to(scores.device) eos_token_mask = torch.isin(vocab_tensor, self.eos_token_id) if new_tokens_length < self.min_new_tokens: scores_processed = torch.where(eos_token_mask, -math.inf, scores) @@ -779,6 +785,8 @@ class EtaLogitsWarper(LogitsWarper): Specifies the minimum number of tokens that must be kept for generation, regardless of their probabilities. For example, if `min_tokens_to_keep` is set to 1, at least one token will always be kept for generation, even if all tokens have probabilities below the cutoff `eta`. + device (`str`, *optional*, defaults to `"cpu"`): + The device to allocate the tensors. Examples: ```python @@ -806,7 +814,9 @@ class EtaLogitsWarper(LogitsWarper): ``` """ - def __init__(self, epsilon: float, filter_value: float = -float("Inf"), min_tokens_to_keep: int = 1): + def __init__( + self, epsilon: float, filter_value: float = -float("Inf"), min_tokens_to_keep: int = 1, device: str = "cpu" + ): epsilon = float(epsilon) if epsilon <= 0 or epsilon >= 1: raise ValueError(f"`eta_cutoff` has to be a float > 0 and < 1, but is {epsilon}") @@ -817,13 +827,12 @@ class EtaLogitsWarper(LogitsWarper): f"`min_tokens_to_keep` has to be a strictly positive integer, but is {min_tokens_to_keep}" ) - self.epsilon = torch.tensor(epsilon) + self.epsilon = torch.tensor(epsilon, device=device) self.filter_value = filter_value self.min_tokens_to_keep = min_tokens_to_keep @add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING) def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor: - # Calculate the adaptive cutoff probabilities = scores.softmax(dim=-1) entropy = torch.distributions.Categorical(logits=scores).entropy() eta = torch.min(self.epsilon, torch.sqrt(self.epsilon) * torch.exp(-entropy))[..., None] @@ -1530,6 +1539,8 @@ class ForcedEOSTokenLogitsProcessor(LogitsProcessor): The maximum length of the sequence to be generated. eos_token_id (`Union[int, List[int], torch.Tensor]`): The id(s) of the *end-of-sequence* token. + device (`str`, *optional*, defaults to `"cpu"`): + The device to allocate the tensors. Examples: @@ -1553,13 +1564,13 @@ class ForcedEOSTokenLogitsProcessor(LogitsProcessor): ``` """ - def __init__(self, max_length: int, eos_token_id: Union[int, List[int], torch.Tensor]): + def __init__(self, max_length: int, eos_token_id: Union[int, List[int], torch.Tensor], device: str = "cpu"): self.max_length = max_length if not isinstance(eos_token_id, torch.Tensor): if isinstance(eos_token_id, int): eos_token_id = [eos_token_id] - eos_token_id = torch.tensor(eos_token_id) + eos_token_id = torch.tensor(eos_token_id, device=device) self.eos_token_id = eos_token_id if torch.is_floating_point(eos_token_id) or (eos_token_id < 0).any(): @@ -1568,7 +1579,6 @@ class ForcedEOSTokenLogitsProcessor(LogitsProcessor): @add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING) def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor: cur_len = input_ids.shape[-1] - self.eos_token_id = self.eos_token_id.to(scores.device) scores_processed = scores if cur_len == self.max_length - 1: scores_processed = torch.full_like(scores, -math.inf) @@ -1770,8 +1780,8 @@ class SuppressTokensAtBeginLogitsProcessor(LogitsProcessor): ``` """ - def __init__(self, begin_suppress_tokens, begin_index): - self.begin_suppress_tokens = torch.tensor(list(begin_suppress_tokens)) + def __init__(self, begin_suppress_tokens, begin_index, device: str = "cpu"): + self.begin_suppress_tokens = torch.tensor(list(begin_suppress_tokens), device=device) self.begin_index = begin_index def set_begin_index(self, begin_index): @@ -1780,7 +1790,6 @@ class SuppressTokensAtBeginLogitsProcessor(LogitsProcessor): @add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING) def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor: vocab_tensor = torch.arange(scores.shape[-1], device=scores.device) - self.begin_suppress_tokens = self.begin_suppress_tokens.to(scores.device) suppress_token_mask = torch.isin(vocab_tensor, self.begin_suppress_tokens) scores_processed = scores if input_ids.shape[-1] == self.begin_index: @@ -1818,13 +1827,12 @@ class SuppressTokensLogitsProcessor(LogitsProcessor): ``` """ - def __init__(self, suppress_tokens): - self.suppress_tokens = torch.tensor(list(suppress_tokens)) + def __init__(self, suppress_tokens, device: str = "cpu"): + self.suppress_tokens = torch.tensor(list(suppress_tokens), device=device) @add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING) def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor: vocab_tensor = torch.arange(scores.shape[-1], device=scores.device) - self.suppress_tokens = self.suppress_tokens.to(scores.device) suppress_token_mask = torch.isin(vocab_tensor, self.suppress_tokens) scores = torch.where(suppress_token_mask, -float("inf"), scores) return scores @@ -1915,7 +1923,10 @@ class WhisperTimeStampLogitsProcessor(LogitsProcessor): """ def __init__( - self, generate_config, begin_index: Optional[int] = None, _detect_timestamp_from_logprob: Optional[bool] = None + self, + generate_config, + begin_index: Optional[int] = None, + _detect_timestamp_from_logprob: Optional[bool] = None, ): # support for the kwargs self.no_timestamps_token_id = generate_config.no_timestamps_token_id self.timestamp_begin = generate_config.no_timestamps_token_id + 1 @@ -2292,11 +2303,11 @@ class BarkEosPrioritizerLogitsProcessor(LogitsProcessor): Minimum end of speech threshold. """ - def __init__(self, eos_token_id: Union[int, List[int], torch.Tensor], min_eos_p: float): + def __init__(self, eos_token_id: Union[int, List[int], torch.Tensor], min_eos_p: float, device: str = "cpu"): if not isinstance(eos_token_id, torch.Tensor): if isinstance(eos_token_id, int): eos_token_id = [eos_token_id] - eos_token_id = torch.tensor(eos_token_id) + eos_token_id = torch.tensor(eos_token_id, device=device) self.eos_token_id = eos_token_id if torch.is_floating_point(eos_token_id) or (eos_token_id < 0).any(): @@ -2309,7 +2320,6 @@ class BarkEosPrioritizerLogitsProcessor(LogitsProcessor): @add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING) def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor: scores_processed = scores - self.eos_token_id = self.eos_token_id.to(scores.device) if self.min_eos_p: probs = torch.nn.functional.softmax(scores.float(), dim=-1) # create scores full of -inf except for the eos_token_id diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py index 970bcdd586..47ca012f22 100644 --- a/src/transformers/generation/utils.py +++ b/src/transformers/generation/utils.py @@ -723,6 +723,7 @@ class GenerationMixin: def _get_logits_warper( self, generation_config: GenerationConfig, + device: str, ) -> LogitsProcessorList: """ This class returns a [`LogitsProcessorList`] list object that contains all relevant [`LogitsWarper`] instances @@ -765,7 +766,9 @@ class GenerationMixin: ) if generation_config.eta_cutoff is not None and 0.0 < generation_config.eta_cutoff < 1.0: warpers.append( - EtaLogitsWarper(epsilon=generation_config.eta_cutoff, min_tokens_to_keep=min_tokens_to_keep) + EtaLogitsWarper( + epsilon=generation_config.eta_cutoff, min_tokens_to_keep=min_tokens_to_keep, device=device + ) ) # `LogitNormalization` should always be the last logit processor, when present if generation_config.renormalize_logits is True: @@ -818,7 +821,8 @@ class GenerationMixin: ): processors.append( EncoderRepetitionPenaltyLogitsProcessor( - penalty=generation_config.encoder_repetition_penalty, encoder_input_ids=encoder_input_ids + penalty=generation_config.encoder_repetition_penalty, + encoder_input_ids=encoder_input_ids, ) ) if generation_config.repetition_penalty is not None and generation_config.repetition_penalty != 1.0: @@ -830,18 +834,30 @@ class GenerationMixin: and generation_config.encoder_no_repeat_ngram_size > 0 ): processors.append( - EncoderNoRepeatNGramLogitsProcessor(generation_config.encoder_no_repeat_ngram_size, encoder_input_ids) + EncoderNoRepeatNGramLogitsProcessor( + generation_config.encoder_no_repeat_ngram_size, + encoder_input_ids, + ) ) if generation_config.bad_words_ids is not None: processors.append( - NoBadWordsLogitsProcessor(generation_config.bad_words_ids, generation_config.eos_token_id) + NoBadWordsLogitsProcessor( + generation_config.bad_words_ids, + generation_config.eos_token_id, + ) ) if ( generation_config.min_length is not None and generation_config.eos_token_id is not None and generation_config.min_length > 0 ): - processors.append(MinLengthLogitsProcessor(generation_config.min_length, generation_config.eos_token_id)) + processors.append( + MinLengthLogitsProcessor( + generation_config.min_length, + generation_config.eos_token_id, + device=device, + ) + ) if ( generation_config.min_new_tokens is not None and generation_config.eos_token_id is not None @@ -849,20 +865,32 @@ class GenerationMixin: ): processors.append( MinNewTokensLengthLogitsProcessor( - input_ids_seq_length, generation_config.min_new_tokens, generation_config.eos_token_id + input_ids_seq_length, + generation_config.min_new_tokens, + generation_config.eos_token_id, + device=device, ) ) if prefix_allowed_tokens_fn is not None: processors.append( PrefixConstrainedLogitsProcessor( - prefix_allowed_tokens_fn, generation_config.num_beams // generation_config.num_beam_groups + prefix_allowed_tokens_fn, + generation_config.num_beams // generation_config.num_beam_groups, ) ) if generation_config.forced_bos_token_id is not None: - processors.append(ForcedBOSTokenLogitsProcessor(generation_config.forced_bos_token_id)) + processors.append( + ForcedBOSTokenLogitsProcessor( + generation_config.forced_bos_token_id, + ) + ) if generation_config.forced_eos_token_id is not None: processors.append( - ForcedEOSTokenLogitsProcessor(generation_config.max_length, generation_config.forced_eos_token_id) + ForcedEOSTokenLogitsProcessor( + generation_config.max_length, + generation_config.forced_eos_token_id, + device=device, + ) ) if generation_config.remove_invalid_values is True: processors.append(InfNanRemoveLogitsProcessor()) @@ -875,7 +903,12 @@ class GenerationMixin: ) ) if generation_config.suppress_tokens is not None: - processors.append(SuppressTokensLogitsProcessor(generation_config.suppress_tokens)) + processors.append( + SuppressTokensLogitsProcessor( + generation_config.suppress_tokens, + device=device, + ) + ) if generation_config.begin_suppress_tokens is not None: begin_index = input_ids_seq_length begin_index = ( @@ -887,7 +920,11 @@ class GenerationMixin: # generation starts after the last token that is forced begin_index += generation_config.forced_decoder_ids[-1][0] processors.append( - SuppressTokensAtBeginLogitsProcessor(generation_config.begin_suppress_tokens, begin_index) + SuppressTokensAtBeginLogitsProcessor( + generation_config.begin_suppress_tokens, + begin_index, + device=device, + ) ) if generation_config.forced_decoder_ids is not None: # TODO(Sanchit): deprecate in v4.40 by removing this logic @@ -1779,7 +1816,12 @@ class GenerationMixin: # 12. prepare logits warper (if `do_sample` is `True`) prepared_logits_warper = ( - self._get_logits_warper(generation_config) if generation_config.do_sample else None + self._get_logits_warper( + generation_config, + device=input_ids.device, + ) + if generation_config.do_sample + else None ) # 13. run assisted generate @@ -1812,7 +1854,9 @@ class GenerationMixin: elif generation_mode in (GenerationMode.SAMPLE, GenerationMode.GREEDY_SEARCH): # 11. prepare logits warper prepared_logits_warper = ( - self._get_logits_warper(generation_config) if generation_config.do_sample else None + self._get_logits_warper(generation_config, device=input_ids.device) + if generation_config.do_sample + else None ) # 12. expand input_ids with `num_return_sequences` additional sequences per batch @@ -1838,7 +1882,9 @@ class GenerationMixin: elif generation_mode in (GenerationMode.BEAM_SAMPLE, GenerationMode.BEAM_SEARCH): # 11. prepare logits warper prepared_logits_warper = ( - self._get_logits_warper(generation_config) if generation_config.do_sample else None + self._get_logits_warper(generation_config, device=input_ids.device) + if generation_config.do_sample + else None ) # 12. prepare beam search scorer diff --git a/src/transformers/models/musicgen/modeling_musicgen.py b/src/transformers/models/musicgen/modeling_musicgen.py index 810f34f780..792cf937f3 100644 --- a/src/transformers/models/musicgen/modeling_musicgen.py +++ b/src/transformers/models/musicgen/modeling_musicgen.py @@ -1729,6 +1729,7 @@ class MusicgenForCausalLM(MusicgenPreTrainedModel): encoder_input_ids=input_ids, prefix_allowed_tokens_fn=None, logits_processor=logits_processor, + device=input_ids.device, ) # 10. prepare stopping criteria @@ -1756,7 +1757,7 @@ class MusicgenForCausalLM(MusicgenPreTrainedModel): elif is_sample_gen_mode: # 11. prepare logits warper - logits_warper = self._get_logits_warper(generation_config) + logits_warper = self._get_logits_warper(generation_config, device=input_ids.device) # expand input_ids with `num_return_sequences` additional sequences per batch input_ids, model_kwargs = self._expand_inputs_for_generation( @@ -2822,6 +2823,7 @@ class MusicgenForConditionalGeneration(PreTrainedModel): encoder_input_ids=inputs_tensor, prefix_allowed_tokens_fn=None, logits_processor=logits_processor, + device=input_ids.device, ) # 10. prepare stopping criteria @@ -2849,7 +2851,7 @@ class MusicgenForConditionalGeneration(PreTrainedModel): elif is_sample_gen_mode: # 11. prepare logits warper - logits_warper = self._get_logits_warper(generation_config) + logits_warper = self._get_logits_warper(generation_config, device=input_ids.device) # expand input_ids with `num_return_sequences` additional sequences per batch input_ids, model_kwargs = self._expand_inputs_for_generation( diff --git a/src/transformers/models/musicgen_melody/modeling_musicgen_melody.py b/src/transformers/models/musicgen_melody/modeling_musicgen_melody.py index 119628d50d..6861349edc 100644 --- a/src/transformers/models/musicgen_melody/modeling_musicgen_melody.py +++ b/src/transformers/models/musicgen_melody/modeling_musicgen_melody.py @@ -1666,6 +1666,7 @@ class MusicgenMelodyForCausalLM(MusicgenMelodyPreTrainedModel): encoder_input_ids=input_ids, prefix_allowed_tokens_fn=None, logits_processor=logits_processor, + device=input_ids.device, ) # 10. prepare stopping criteria @@ -1693,7 +1694,7 @@ class MusicgenMelodyForCausalLM(MusicgenMelodyPreTrainedModel): elif is_sample_gen_mode: # 11. prepare logits warper - logits_warper = self._get_logits_warper(generation_config) + logits_warper = self._get_logits_warper(generation_config, device=input_ids.device) # expand input_ids with `num_return_sequences` additional sequences per batch input_ids, model_kwargs = self._expand_inputs_for_generation( @@ -2681,6 +2682,7 @@ class MusicgenMelodyForConditionalGeneration(PreTrainedModel): encoder_input_ids=inputs_tensor, prefix_allowed_tokens_fn=None, logits_processor=logits_processor, + device=input_ids.device, ) # 10. prepare stopping criteria @@ -2708,7 +2710,7 @@ class MusicgenMelodyForConditionalGeneration(PreTrainedModel): elif is_sample_gen_mode: # 11. prepare logits warper - logits_warper = self._get_logits_warper(generation_config) + logits_warper = self._get_logits_warper(generation_config, device=input_ids.device) # expand input_ids with `num_return_sequences` additional sequences per batch input_ids, model_kwargs = self._expand_inputs_for_generation( diff --git a/src/transformers/models/rag/modeling_rag.py b/src/transformers/models/rag/modeling_rag.py index 3590369d5b..4f6c8dc384 100644 --- a/src/transformers/models/rag/modeling_rag.py +++ b/src/transformers/models/rag/modeling_rag.py @@ -1538,6 +1538,7 @@ class RagTokenForGeneration(RagPreTrainedModel): encoder_input_ids=context_input_ids, prefix_allowed_tokens_fn=prefix_allowed_tokens_fn, logits_processor=logits_processor, + device=input_ids.device, ) prepared_stopping_criteria = self._get_stopping_criteria( diff --git a/src/transformers/models/whisper/generation_whisper.py b/src/transformers/models/whisper/generation_whisper.py index 1d6639fa44..4d60427d8b 100644 --- a/src/transformers/models/whisper/generation_whisper.py +++ b/src/transformers/models/whisper/generation_whisper.py @@ -548,13 +548,15 @@ class WhisperGenerationMixin: self._check_decoder_input_ids(kwargs=kwargs) # 3. Retrieve logits processors + device = kwargs["encoder_outputs"][0].device if "encoder_outputs" in kwargs else input_features.device begin_index = init_tokens.shape[1] logits_processor = self._retrieve_logit_processors( generation_config=generation_config, logits_processor=logits_processor, begin_index=begin_index, # begin index is index of first generated decoder token is_shortform=is_shortform, - num_beams=generation_config.num_beams, + num_beams=kwargs.get("num_beams", 1), + device=device, ) # 5. If we're in shortform mode, simple generate the whole input at once and return the output @@ -1400,7 +1402,9 @@ class WhisperGenerationMixin: return max_frames, seek - def _retrieve_logit_processors(self, generation_config, logits_processor, begin_index, is_shortform, num_beams): + def _retrieve_logit_processors( + self, generation_config, logits_processor, begin_index, is_shortform, num_beams, device + ): if generation_config.return_timestamps is True: timestamp_processor = WhisperTimeStampLogitsProcessor(generation_config, begin_index=begin_index) logits_processor = ( @@ -1408,7 +1412,7 @@ class WhisperGenerationMixin: ) if generation_config.suppress_tokens is not None: - suppress_tokens_processor = SuppressTokensLogitsProcessor(generation_config.suppress_tokens) + suppress_tokens_processor = SuppressTokensLogitsProcessor(generation_config.suppress_tokens, device=device) logits_processor = ( [suppress_tokens_processor] if logits_processor is None @@ -1418,7 +1422,7 @@ class WhisperGenerationMixin: if generation_config.begin_suppress_tokens is not None: begin_suppress_processor = SuppressTokensAtBeginLogitsProcessor( - generation_config.begin_suppress_tokens, begin_index=begin_index + generation_config.begin_suppress_tokens, begin_index=begin_index, device=device ) logits_processor = ( [begin_suppress_processor] diff --git a/tests/generation/test_logits_process.py b/tests/generation/test_logits_process.py index 775e702a02..a5d3ab37ef 100644 --- a/tests/generation/test_logits_process.py +++ b/tests/generation/test_logits_process.py @@ -69,7 +69,7 @@ class LogitsProcessorTest(unittest.TestCase): batch_size = 4 eos_token_id = 0 - min_dist_processor = MinLengthLogitsProcessor(min_length=10, eos_token_id=eos_token_id) + min_dist_processor = MinLengthLogitsProcessor(min_length=10, eos_token_id=eos_token_id, device=torch_device) # check that min length is applied at length 5 input_ids = ids_tensor((batch_size, 5), vocab_size=20) @@ -91,7 +91,7 @@ class LogitsProcessorTest(unittest.TestCase): # check that first input is skipped (min new length applying) input_ids = ids_tensor((batch_size, 5), vocab_size=20) new_min_dist_processor = MinNewTokensLengthLogitsProcessor( - prompt_length_to_skip=input_ids.shape[-1], min_new_tokens=3, eos_token_id=eos_token_id + prompt_length_to_skip=input_ids.shape[-1], min_new_tokens=3, eos_token_id=eos_token_id, device=torch_device ) expected_eos_scores_before_min_length = batch_size * [-float("inf")] @@ -450,7 +450,7 @@ class LogitsProcessorTest(unittest.TestCase): torch.tensor([[0.0, 0.1, 0.8, 0.1], [0.01, 0.04, 0.9, 0.05]], device=torch_device, dtype=torch.float) ) - eta_warp = EtaLogitsWarper(0.0625) + eta_warp = EtaLogitsWarper(0.0625, device=torch_device) filtered_dist = torch.exp(eta_warp(input_ids, dist)) # dist should be filtered to only keep values with proba >= min(0.0625, sqrt(0.0625) * e^-H(p)) @@ -474,7 +474,7 @@ class LogitsProcessorTest(unittest.TestCase): ramp_logits[1] = ramp_logits[1] * 100.0 # make sure at least 2 tokens are kept - eta_warp = EtaLogitsWarper(0.1, min_tokens_to_keep=2, filter_value=0.0) + eta_warp = EtaLogitsWarper(0.1, min_tokens_to_keep=2, filter_value=0.0, device=torch_device) filtered_dist = eta_warp(input_ids, ramp_logits) # first batch should keep 2 tokens, second batch would keep only 1, but due to `min_tokens_to_keep=2` keeps 2. @@ -640,7 +640,7 @@ class LogitsProcessorTest(unittest.TestCase): scores_comp = scores.clone() # instantiate all dist processors - min_dist_proc = MinLengthLogitsProcessor(min_length=10, eos_token_id=eos_token_id) + min_dist_proc = MinLengthLogitsProcessor(min_length=10, eos_token_id=eos_token_id, device=torch_device) temp_dist_warp = TemperatureLogitsWarper(temperature=0.5) rep_penalty_proc = RepetitionPenaltyLogitsProcessor(penalty=2.0) top_k_warp = TopKLogitsWarper(3) @@ -767,7 +767,9 @@ class LogitsProcessorTest(unittest.TestCase): eos_token_id = 0 max_length = 5 - logits_processor = ForcedEOSTokenLogitsProcessor(max_length=max_length, eos_token_id=eos_token_id) + logits_processor = ForcedEOSTokenLogitsProcessor( + max_length=max_length, eos_token_id=eos_token_id, device=torch_device + ) # check that all scores are -inf except the eos_token_id when max_length-1 is reached input_ids = ids_tensor((batch_size, 4), vocab_size=20) @@ -927,7 +929,7 @@ class LogitsProcessorTest(unittest.TestCase): scores = self._get_uniform_logits(2, 4) scores[0][eos_token_id] = -6 ## less than log(min_eos_p) - esp = BarkEosPrioritizerLogitsProcessor(eos_token_id=eos_token_id, min_eos_p=min_eos_p) + esp = BarkEosPrioritizerLogitsProcessor(eos_token_id=eos_token_id, min_eos_p=min_eos_p, device=torch_device) actual_scores = esp(input_ids, scores) expected_scores_list = [ scores[0].tolist(), @@ -943,7 +945,7 @@ class LogitsProcessorTest(unittest.TestCase): scores = self._get_uniform_logits(2, 4) scores[0][eos_token_id] = -6 ## less than log(min_eos_p) - esp = BarkEosPrioritizerLogitsProcessor(eos_token_id=eos_token_id, min_eos_p=min_eos_p) + esp = BarkEosPrioritizerLogitsProcessor(eos_token_id=eos_token_id, min_eos_p=min_eos_p, device=torch_device) actual_scores = esp(input_ids, scores) expected_scores_list = [ scores[0].tolist(),