diff --git a/docs/source/en/generation_strategies.md b/docs/source/en/generation_strategies.md index 06e7e0b8ab..64ded96137 100644 --- a/docs/source/en/generation_strategies.md +++ b/docs/source/en/generation_strategies.md @@ -408,14 +408,24 @@ For the complete list of the available parameters, refer to the [API documentati ### Speculative Decoding Speculative decoding (also known as assisted decoding) is a modification of the decoding strategies above, that uses an -assistant model (ideally a much smaller one) with the same tokenizer, to generate a few candidate tokens. The main -model then validates the candidate tokens in a single forward pass, which speeds up the decoding process. If -`do_sample=True`, then the token validation with resampling introduced in the -[speculative decoding paper](https://arxiv.org/pdf/2211.17192.pdf) is used. +assistant model (ideally a much smaller one), to generate a few candidate tokens. The main model then validates the candidate +tokens in a single forward pass, which speeds up the decoding process. If `do_sample=True`, then the token validation with +resampling introduced in the [speculative decoding paper](https://arxiv.org/pdf/2211.17192.pdf) is used. +Assisted decoding assumes the main and assistant models have the same tokenizer, otherwise, see Universal Assisted Decoding below. Currently, only greedy search and sampling are supported with assisted decoding, and assisted decoding doesn't support batched inputs. To learn more about assisted decoding, check [this blog post](https://huggingface.co/blog/assisted-generation). +#### Universal Assisted Decoding + +Universal Assisted Decoding (UAD) adds support for main and assistant models with different tokenizers. +To use it, simply pass the tokenizers using the `tokenizer` and `assistant_tokenizer` arguments (see below). +Internally, the main model input tokens are re-encoded into assistant model tokens, then candidate tokens are generated in the assistant encoding, which are +in turn re-encoded into main model candidate tokens. Validation then proceeds as explained above. +The re-encoding steps involve decoding token ids into text and then encoding the text using a different tokenizer. +Since re-encoding the tokens may result in tokenization discrepancies, UAD finds the longest common subsequence between the source and target encodings, +to ensure the new tokens include the correct prompt suffix. + To enable assisted decoding, set the `assistant_model` argument with a model. ```python @@ -435,6 +445,26 @@ To enable assisted decoding, set the `assistant_model` argument with a model. ['Alice and Bob are sitting in a bar. Alice is drinking a beer and Bob is drinking a'] ``` +If the main and assistant models have different tokenizers, use Universal Assisted Decoding. + +```python +>>> from transformers import AutoModelForCausalLM, AutoTokenizer + +>>> prompt = "Alice and Bob" +>>> checkpoint = "google/gemma-2-9b" +>>> assistant_checkpoint = "double7/vicuna-68m" + +>>> assistant_tokenizer = AutoTokenizer.from_pretrained(assistant_checkpoint) +>>> tokenizer = AutoTokenizer.from_pretrained(checkpoint) +>>> inputs = tokenizer(prompt, return_tensors="pt") + +>>> model = AutoModelForCausalLM.from_pretrained(checkpoint) +>>> assistant_model = AutoModelForCausalLM.from_pretrained(assistant_checkpoint) +>>> outputs = model.generate(**inputs, assistant_model=assistant_model, tokenizer=tokenizer, assistant_tokenizer=assistant_tokenizer) +>>> tokenizer.batch_decode(outputs, skip_special_tokens=True) +['Alice and Bob are sitting in a bar. Alice is drinking a beer and Bob is drinking a'] +``` + When using assisted decoding with sampling methods, you can use the `temperature` argument to control the randomness, just like in multinomial sampling. However, in assisted decoding, reducing the temperature may help improve the latency. @@ -458,6 +488,7 @@ just like in multinomial sampling. However, in assisted decoding, reducing the t Alternatively, you can also set the `prompt_lookup_num_tokens` to trigger n-gram based assisted decoding, as opposed to model based assisted decoding. You can read more about it [here](https://twitter.com/joao_gante/status/1747322413006643259). + ### DoLa Decoding **D**ecoding by C**o**ntrasting **La**yers (DoLa) is a contrastive decoding strategy to improve the factuality and reduce the diff --git a/src/transformers/generation/candidate_generator.py b/src/transformers/generation/candidate_generator.py index fb7ed2f0b2..a4c8f79ae9 100644 --- a/src/transformers/generation/candidate_generator.py +++ b/src/transformers/generation/candidate_generator.py @@ -16,6 +16,7 @@ import copy from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple +import numpy as np import torch from ..cache_utils import DynamicCache @@ -25,6 +26,7 @@ from .logits_process import LogitsProcessorList, MinLengthLogitsProcessor if TYPE_CHECKING: from ..modeling_utils import PreTrainedModel + from ..tokenization_utils_base import PreTrainedTokenizerBase from .configuration_utils import GenerationConfig @@ -156,6 +158,7 @@ class AssistedCandidateGenerator(CandidateGenerator): # Prepare generation-related options. self.logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList() self.generation_config = copy.deepcopy(generation_config) + self.generation_config.return_dict_in_generate = True self.generation_config.output_scores = True self.generation_config.assistant_confidence_threshold = self.assistant_confidence_threshold @@ -258,6 +261,303 @@ class AssistedCandidateGenerator(CandidateGenerator): self.num_assistant_tokens = max(1.0, self.num_assistant_tokens - 1.0) +class AssistedCandidateGeneratorDifferentTokenizers(AssistedCandidateGenerator): + """ + `CandidateGenerator` class to be used for Universal Assisted Generation (UAD): assisted generation with different tokenizers + for the assistant and main models. This class generates candidates through the use of a smaller + model. + + The main model input tokens are re-encoded into assistant model tokens, then candidate tokens are generated in the assistant encoding, which are + in turn re-encoded into main model candidate tokens. Validation then proceeds as explained above. + The re-encoding steps involve decoding token ids into text and then encoding the text using a different tokenizer. + Since re-encoding the tokens may result in tokenization discrepancies, UAD finds the longest common subsequence between the source and target encodings, + to ensure the new tokens include the correct prompt suffix. + + Args: + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. [What are input IDs?](../glossary#input-ids) + assistant_model (`PreTrainedModel`): + The model to be used for generating candidates. This model should be smaller than the main model. + target_tokenizer (`PreTrainedTokenizerBase`): + The tokenizer used for the target model. + assistant_tokenizer (`PreTrainedTokenizerBase`): + The tokenizer used for the assistant model. + generation_config (`~generation.GenerationConfig`, *optional*): + The generation configuration to be used as base parametrization for the generation call. + logits_processor (`LogitsProcessorList`): + An instance of [`LogitsProcessorList`]. List of instances of class derived from [`LogitsProcessor`] + used to modify the prediction scores of the language modeling head applied at each generation step. + model_kwargs (`Dict`): + The keyword arguments that will be passed to the main model, and are used as base inputs for the assistant + model as well. + inputs_tensor (`torch.Tensor`, *optional*): + The model input tensor. In encoder-decoder models, this is the encoder input. + """ + + def __init__( + self, + input_ids: torch.LongTensor, + assistant_model: "PreTrainedModel", + target_tokenizer: "PreTrainedTokenizerBase", + assistant_tokenizer: "PreTrainedTokenizerBase", + generation_config: "GenerationConfig", + model_kwargs: Dict, + inputs_tensor: Optional[torch.Tensor] = None, + logits_processor: "LogitsProcessorList" = None, + ): + super().__init__(input_ids, assistant_model, generation_config, model_kwargs, inputs_tensor, logits_processor) + + self.target_tokenizer = target_tokenizer + self.assistant_tokenizer = assistant_tokenizer + self.prev_tokens = None + self.prev_assistant_ids = None + self.target_lookbehind = 10 + self.assistant_lookbehind = 10 + + @staticmethod + def _get_longest_diag_dict(input_matrix, nonzero_idx): + """ + Calculates the length of the longest diagonal sequence in a given matrix. + Args: + input_matrix (torch.Tensor): The input matrix. + nonzero_idx (torch.Tensor): The indices of the non-zero elements in the matrix. + Returns: + dict: A dictionary where the keys are the indices of the non-zero elements and the values are the lengths of the longest diagonal sequences starting from those indices. + """ + + visited = set() + diags = {} + for idx in nonzero_idx: + start_idx = torch.clone(idx) + tuple_start_idx = tuple(start_idx.tolist()) + + if tuple_start_idx in visited: + continue + + visited.add(tuple_start_idx) + cur_diag_len = 1 + start_idx += 1 + while start_idx[0] < input_matrix.shape[0] and start_idx[1] < input_matrix.shape[1]: + tuple_start_idx = tuple(start_idx.tolist()) + visited.add(tuple_start_idx) + + if input_matrix[start_idx[0], start_idx[1]] == 1: + cur_diag_len += 1 + start_idx += 1 + else: + break + + diags[idx] = cur_diag_len + return diags + + @staticmethod + def _get_longest_diag_index(input_matrix): + """ + Returns the start index and length of the longest diagonal in the given input. + Args: + input_matrix (numpy.ndarray): The input matrix. + Returns: + tuple: A tuple containing the start index and length of the longest diagonal. + """ + + diags = AssistedCandidateGeneratorDifferentTokenizers._get_longest_diag_dict( + input_matrix, input_matrix.nonzero() + ) + diags_values = list(diags.values()) + diags_keys = list(diags.keys()) + best_diag = np.argmax(diags_values) + diag_start_index = diags_keys[best_diag] + diag_start_length = diags_values[best_diag] + return diag_start_index, diag_start_length + + @staticmethod + def _get_tokens_diag(prompt, prompt_plus_new_tokens): + """ + Input: + prompt: 2D array of shape (batch_size, prompt_length), represents the original prompt tokens + prompt_plus_new_tokens: 2D array of shape (batch_size, prompt_length), represents the suffix of the original prompt, with additional new tokens. + Output: + discrepancy_length: int, represents the number of tokens that need to be replaced from prompt + new_tokens_only: 2D array of shape (batch_size, new_token_length), represents the new tokens that are not in prompt + discrepancy_only: 2D array of shape (batch_size, discrepancy_length), represents the new tokens that are in prompt but not in prompt_plus_new_tokens + """ + compare_mat = prompt_plus_new_tokens.T == prompt + if not torch.is_tensor(compare_mat): + compare_mat = torch.tensor(compare_mat) + + compare_mat_int = compare_mat.to(int) + + if not compare_mat_int.any().item(): + # empty intersection between prompt and prompt_plus_new_tokens + return None, None, None + + longest_location, longest_diag_length = AssistedCandidateGeneratorDifferentTokenizers._get_longest_diag_index( + compare_mat_int + ) + new_token_start_index = longest_location[0] + longest_diag_length + discrepancy_with_old = longest_location[1] + longest_diag_length + discrepancy_length = (prompt.shape[1] - discrepancy_with_old).item() + new_tokens_only = prompt_plus_new_tokens[:, new_token_start_index + discrepancy_length :] + discrepancy_only = prompt_plus_new_tokens[ + :, new_token_start_index : new_token_start_index + discrepancy_length + ] + return discrepancy_length, new_tokens_only, discrepancy_only + + def convert_source_tokens_to_target_tokens( + self, + input_ids, + source_tokenizer, + destination_tokenizer, + ): + """ + Convert token IDs from one tokenizer to another. + Args: + input_ids: The input token IDs. + source_tokenizer: The source tokenizer. + destination_tokenizer: The destination tokenizer. + Returns: + The converted token IDs. + """ + text = source_tokenizer.batch_decode(input_ids, skip_special_tokens=True, clean_up_tokenization_spaces=True) + dest_ids = destination_tokenizer(text, add_special_tokens=True, return_tensors="pt")["input_ids"] + return dest_ids.to(input_ids.device) + + def get_candidates(self, input_ids: torch.LongTensor) -> Tuple[torch.LongTensor, Optional[torch.FloatTensor]]: + """ + Fetches the candidates to be tried for the current input. + + Args: + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. [What are input IDs?](../glossary#input-ids) + + Return: + `torch.LongTensor` of shape `(batch_size, candidate_length)` containing the candidate sequences to be + assessed by the model and a `torch.FloatTensor` of shape `(batch_size, candidate_length, + vocabulary_size)` containing the logits associated to each candidate. + """ + max_new_tokens = int(self.num_assistant_tokens) + if max_new_tokens == 0: + return input_ids, None + + input_ids = input_ids.to(self.assistant_model.device) + convert_kwargs = { + "source_tokenizer": self.target_tokenizer, + "destination_tokenizer": self.assistant_tokenizer, + } + remove_from_pkv = 0 + + # Since re-encoding the tokens may result in tokenization discrepancies, we use 2 look behind values + # (one for each conversion) which mark where to start looking for the overlap between the + # source and target encodings, to ensure the new tokens include the correct prompt suffix. + if self.prev_tokens is not None and self.prev_target_ids.shape[1] > self.target_lookbehind: + # input_ids contains all target prompt input ids and some new target input ids + start_index_in_target_window = self.prev_target_ids.shape[1] - self.target_lookbehind + + new_assistant_ids = self.convert_source_tokens_to_target_tokens( + input_ids[:, start_index_in_target_window:], **convert_kwargs + ) + prompt_use_length = new_assistant_ids.shape[1] + prompt_use = self.prev_assistant_ids[:, -prompt_use_length:] + + discrepancy_length, new_tokens_only, discrepancy_only = ( + AssistedCandidateGeneratorDifferentTokenizers._get_tokens_diag(prompt_use, new_assistant_ids) + ) + assistant_input_ids = self.prev_assistant_ids + + if new_tokens_only is not None: + if discrepancy_length > 0 and discrepancy_only.shape[1] > 0: + if discrepancy_length == discrepancy_only.shape[1]: + assistant_input_ids[:, -discrepancy_length:] = discrepancy_only + + elif discrepancy_length > discrepancy_only.shape[1]: + discrepancy_length_diff = discrepancy_length - discrepancy_only.shape[1] + assistant_input_ids = assistant_input_ids[:, :-discrepancy_length_diff] + assistant_input_ids[:, -discrepancy_only.shape[1] :] = discrepancy_only + + remove_from_pkv = discrepancy_length + + if new_tokens_only.shape[1] > 0: + assistant_input_ids = torch.cat([assistant_input_ids, new_tokens_only], dim=-1) + else: + # edge case: in case of no intersection between prompt and new_assistant_ids + assistant_input_ids = torch.cat([assistant_input_ids, new_assistant_ids], dim=-1) + + else: + assistant_input_ids = self.convert_source_tokens_to_target_tokens(input_ids, **convert_kwargs) + self.prev_target_ids = input_ids + + self.prev_assistant_ids = assistant_input_ids + new_cur_len = assistant_input_ids.shape[-1] + min_new_tokens = max(min(max_new_tokens, self.main_model_min_length - new_cur_len), 0) + + # 1. If it is not the first round of candidate generation, prepare the inputs based on the input_ids length + # (which implicitly contains the number of accepted candidates from the previous round) + has_past_key_values = self.assistant_kwargs.get("past_key_values", None) is not None + if has_past_key_values: + new_cache_size = new_cur_len - 1 - remove_from_pkv + self.assistant_kwargs["past_key_values"] = _crop_past_key_values( + self.assistant_model, self.assistant_kwargs["past_key_values"], new_cache_size - 1 + ) # the assistant does not have the token after the last match, hence the -1 + + self.assistant_kwargs = _prepare_attention_mask( + self.assistant_kwargs, new_cur_len, self.assistant_model.config.is_encoder_decoder + ) + self.assistant_kwargs = _prepare_token_type_ids(self.assistant_kwargs, new_cur_len) + + # 2. Forecast next N tokens using the assistant model. + assistant_generation_kwargs = { + self.input_ids_key: assistant_input_ids, + "min_new_tokens": min_new_tokens, + "max_new_tokens": max_new_tokens, + "generation_config": self.generation_config, + "logits_processor": self.logits_processor, + } + + self.assistant_kwargs.pop("attention_mask", None) + + assistant_output = self.assistant_model.generate(**assistant_generation_kwargs, **self.assistant_kwargs) + + num_prev_assistant = self.prev_assistant_ids.shape[1] + start_assistant_look_index = num_prev_assistant - self.assistant_lookbehind + + new_target_ids_from_window = self.convert_source_tokens_to_target_tokens( + assistant_output.sequences[:, start_assistant_look_index:], + source_tokenizer=self.assistant_tokenizer, + destination_tokenizer=self.target_tokenizer, + ) + target_prompt_use_length = new_target_ids_from_window.shape[1] + + target_prompt_use = input_ids[:, -target_prompt_use_length:] + + _, target_new_tokens_only, _ = AssistedCandidateGeneratorDifferentTokenizers._get_tokens_diag( + target_prompt_use, new_target_ids_from_window + ) + + new_target_ids = input_ids + + if target_new_tokens_only is not None: + if target_new_tokens_only.shape[1] > 0: + new_target_ids = torch.cat([new_target_ids, target_new_tokens_only], dim=-1) + else: + # edge case: in case of no intersection between prompt and new_target_ids + new_target_ids = torch.cat([new_target_ids, new_target_ids_from_window], dim=-1) + + self.prev_target_ids = input_ids + + if hasattr(self.generation_config, "max_length"): + new_target_ids = new_target_ids[:, : self.generation_config.max_length] + + # 3. Update variables for the next round of candidate generation + self.assistant_kwargs["past_key_values"] = assistant_output.past_key_values + self.prev_tokens = assistant_output.sequences + + # 4. Prepare variables for output + if input_ids.shape[1] >= new_target_ids.shape[1]: + return input_ids, None + + return new_target_ids, None + + class PromptLookupCandidateGenerator(CandidateGenerator): """ `CandidateGenerator` class to be used for prompt lookup generation. This class generates candidates by looking up diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py index 35ca292d9f..b355bbeaa9 100644 --- a/src/transformers/generation/utils.py +++ b/src/transformers/generation/utils.py @@ -51,6 +51,7 @@ from .beam_constraints import DisjunctiveConstraint, PhrasalConstraint from .beam_search import BeamScorer, BeamSearchScorer, ConstrainedBeamSearchScorer from .candidate_generator import ( AssistedCandidateGenerator, + AssistedCandidateGeneratorDifferentTokenizers, CandidateGenerator, PromptLookupCandidateGenerator, _crop_past_key_values, @@ -617,7 +618,7 @@ class GenerationMixin: model_input_name = model_input_name if model_input_name is not None else self.main_input_name encoder_kwargs["return_dict"] = True encoder_kwargs[model_input_name] = inputs_tensor - model_kwargs["encoder_outputs"]: ModelOutput = encoder(**encoder_kwargs) + model_kwargs["encoder_outputs"]: ModelOutput = encoder(**encoder_kwargs) # type: ignore return model_kwargs @@ -787,11 +788,15 @@ class GenerationMixin: inputs_tensor: torch.Tensor, assistant_model: "PreTrainedModel", logits_processor: LogitsProcessorList, + target_tokenizer: "PreTrainedTokenizerBase", + assistant_tokenizer: "PreTrainedTokenizerBase", model_kwargs: Dict, ) -> CandidateGenerator: """ Returns the candidate generator to be used in `assisted_generation` """ + different_tokenizers = all(v is not None for v in (assistant_model, target_tokenizer, assistant_tokenizer)) + if generation_config.prompt_lookup_num_tokens is not None: candidate_generator = PromptLookupCandidateGenerator( eos_token_id=generation_config._eos_token_tensor, @@ -799,6 +804,17 @@ class GenerationMixin: max_matching_ngram_size=generation_config.max_matching_ngram_size, max_length=generation_config.max_length, ) + elif different_tokenizers: + candidate_generator = AssistedCandidateGeneratorDifferentTokenizers( + input_ids=input_ids, + assistant_model=assistant_model, + generation_config=generation_config, + model_kwargs=model_kwargs, + inputs_tensor=inputs_tensor, + logits_processor=logits_processor, + target_tokenizer=target_tokenizer, + assistant_tokenizer=assistant_tokenizer, + ) else: candidate_generator = AssistedCandidateGenerator( input_ids=input_ids, @@ -1250,7 +1266,7 @@ class GenerationMixin: f"names: {terminations_with_generation_support}." ) - def _validate_assistant(self, assistant_model): + def _validate_assistant(self, assistant_model, tokenizer, assistant_tokenizer): if assistant_model is None: return @@ -1266,8 +1282,19 @@ class GenerationMixin: "Ensure you load the assistant with the correct encoder-decoder class, e.g. `AutoModelForSpeechSeq2Seq` for Whisper." ) - if not self.config.get_text_config().vocab_size == assistant_model.config.get_text_config().vocab_size: - raise ValueError("Make sure the main and assistant model use the same tokenizer") + doc_reference = ( + "(see https://huggingface.co/docs/transformers/en/generation_strategies#universal-assisted-decoding)" + ) + if self.config.get_text_config().vocab_size == assistant_model.config.get_text_config().vocab_size: + if assistant_tokenizer is not None: + raise ValueError( + f"`assistant_tokenizer` is not required when the main and assistant models use the same tokenizer. Please omit `assistant_tokenizer` from `generate()` {doc_reference}." + ) + else: + if tokenizer is None or assistant_tokenizer is None: + raise ValueError( + f"The main and assistant moedels have different tokenizers. Please provide `tokenizer` and `assistant_tokenizer` to `generate()` {doc_reference}." + ) def _validate_model_kwargs(self, model_kwargs: Dict[str, Any]): """Validates model kwargs for generation. Generate argument typos will also be caught here.""" @@ -1923,12 +1950,15 @@ class GenerationMixin: - [`~generation.GenerateEncoderDecoderOutput`], - [`~generation.GenerateBeamEncoderDecoderOutput`] """ + # 1. Handle `generation_config` and kwargs that might update it, and validate the `.generate()` call self._validate_model_class() tokenizer = kwargs.pop("tokenizer", None) # Pull this out first, we only use it for stopping criteria + assistant_tokenizer = kwargs.pop("assistant_tokenizer", None) # only used for assisted generation + generation_config, model_kwargs = self._prepare_generation_config(generation_config, **kwargs) self._validate_model_kwargs(model_kwargs.copy()) - self._validate_assistant(assistant_model) + self._validate_assistant(assistant_model, tokenizer, assistant_tokenizer) # 2. Set generation parameters if not already defined if synced_gpus is None: @@ -2110,6 +2140,8 @@ class GenerationMixin: inputs_tensor=inputs_tensor, assistant_model=assistant_model, logits_processor=logits_processor, + target_tokenizer=tokenizer, + assistant_tokenizer=assistant_tokenizer, model_kwargs=model_kwargs, ) @@ -4138,7 +4170,7 @@ class GenerationMixin: # 1. Fetch candidate sequences from a `CandidateGenerator` candidate_input_ids, candidate_logits = candidate_generator.get_candidates(input_ids) - candidate_input_ids = candidate_input_ids.to(self.device) + if candidate_logits is not None: candidate_logits = candidate_logits.to(self.device) diff --git a/tests/generation/test_utils.py b/tests/generation/test_utils.py index 58259821cf..1727aed111 100644 --- a/tests/generation/test_utils.py +++ b/tests/generation/test_utils.py @@ -88,6 +88,7 @@ if is_torch_available(): WatermarkDetector, WatermarkingConfig, ) + from transformers.generation.candidate_generator import AssistedCandidateGeneratorDifferentTokenizers from transformers.generation.utils import _speculative_sampling @@ -3510,6 +3511,34 @@ class GenerationIntegrationTests(unittest.TestCase, GenerationIntegrationTestsMi self.assertTrue(test_bos_id == gen_output[0, 0]) self.assertTrue(generation_config.bos_token_id is None) + def test_speculative_decoding_equals_regular_decoding(self): + draft_name = "double7/vicuna-68m" + target_name = "Qwen/Qwen2-0.5B-Instruct" + + draft_model = AutoModelForCausalLM.from_pretrained(draft_name) + target_model = AutoModelForCausalLM.from_pretrained(target_name) + + assistant_tokenizer = AutoTokenizer.from_pretrained(draft_name) + target_tokenizer = AutoTokenizer.from_pretrained(target_name) + + prompt_size = torch.randint(low=20, high=100, size=(1,)) + max_new_tokens = torch.randint(low=10, high=50, size=(1,)) + input_ids = (torch.rand(1, prompt_size[0]) * 100).to(int) + 50 + + max_new_tokens_item = max_new_tokens[0].item() + expected_out = target_model.generate(input_ids, do_sample=False, max_new_tokens=max_new_tokens_item) + predicted_out = target_model.generate( + input_ids, + do_sample=False, + max_new_tokens=max_new_tokens_item, + assistant_model=draft_model, + target_tokenizer=target_tokenizer, + assistant_tokenizer=assistant_tokenizer, + ) + + self.assertEqual(expected_out.shape, predicted_out.shape) + self.assertTrue((expected_out == predicted_out).all().item()) + @pytest.mark.generate @require_torch_multi_gpu def test_generate_with_static_cache_multi_gpu(self): @@ -3884,3 +3913,41 @@ class TokenHealingTestCase(unittest.TestCase): # bos_token_id is required when no input ids nor inputs_embeds is passed with self.assertRaises(ValueError): model.generate(max_length=20, bos_token_id=None) + + +class TestAssistedCandidateGeneratorDifferentTokenizers(unittest.TestCase): + def test_no_intersection(self): + prompt = np.array([[1, 2, 3]]) + prompt_plus_new_tokens = np.array([[4, 5, 6]]) + result = AssistedCandidateGeneratorDifferentTokenizers._get_tokens_diag(prompt, prompt_plus_new_tokens) + self.assertEqual(result, (None, None, None)) + + def test_complete_overlap(self): + prompt = np.array([[1, 2, 3]]) + prompt_plus_new_tokens = np.array([[1, 2, 3, 4, 5]]) + discrep_length, new_tokens_only, discrep_only = AssistedCandidateGeneratorDifferentTokenizers._get_tokens_diag( + prompt, prompt_plus_new_tokens + ) + self.assertEqual(discrep_length, 0) + np.testing.assert_array_equal(new_tokens_only, np.array([[4, 5]])) + np.testing.assert_array_equal(discrep_only, np.array([[]])) + + def test_partial_overlap(self): + prompt = np.array([[1, 2, 3]]) + prompt_plus_new_tokens = np.array([[2, 3, 4, 5]]) + discrep_length, new_tokens_only, discrep_only = AssistedCandidateGeneratorDifferentTokenizers._get_tokens_diag( + prompt, prompt_plus_new_tokens + ) + self.assertEqual(discrep_length, 0) + np.testing.assert_array_equal(new_tokens_only, np.array([[4, 5]])) + np.testing.assert_array_equal(discrep_only, np.array([[]])) + + def test_no_new_tokens(self): + prompt = np.array([[1, 2, 3]]) + prompt_plus_new_tokens = np.array([[1, 2, 3]]) + discrep_length, new_tokens_only, discrep_only = AssistedCandidateGeneratorDifferentTokenizers._get_tokens_diag( + prompt, prompt_plus_new_tokens + ) + self.assertEqual(discrep_length, 0) + np.testing.assert_array_equal(new_tokens_only, np.array([[]])) + np.testing.assert_array_equal(discrep_only, np.array([[]]))