Universal Speculative Decoding CandidateGenerator (#35029)

* move `TestAssistedCandidateGeneratorDifferentTokenizers` into a new testing file

* refactor

* NOTHING. add space to rerun github actions tests

* remove it...

* `UniversalSpeculativeDecodingGenerator`

* Use `UniversalSpeculativeDecodingGenerator` when `generation_config.do_sample=True`

* assistant tokenizes only the target's new suffix

* formatting

* fix code

* fix code

* formatting

* add `TestGenerateWithDifferentModels`

* `TestGenerateWithDifferentModels` parameterize on `do_sample`

* `AssistantVocabMapping` & `AssistantVocabMappingCache`

* formatting

* `AssistantToTargetTranslator`: `get_target_input_ids` & `get_target_logits`

* improve `_get_assistant_to_target_input_ids` & formatting

* renaming

* WIP: debugging `min_new_tokens`

* fix get_target_ids

* `UniversalSpeculativeDecodingGenerator`

* assistant tokenizes only the target's new suffix

* formatting

* fix code

* fix code

* formatting

* `TestGenerateWithDifferentModels` parameterize on `do_sample`

* `AssistantVocabMapping` & `AssistantVocabMappingCache`

* formatting

* `AssistantToTargetTranslator`: `get_target_input_ids` & `get_target_logits`

* improve `_get_assistant_to_target_input_ids` & formatting

* renaming

* WIP: debugging `min_new_tokens`

* fix get_target_ids

* fix device issue

* fix get_assistant_input_ids

* add `TestAssistedCandidateGeneratorDifferentTokenizers`

* formatting

* `AssistantVocabTranslatorCache` refactor & tests

* revert changes in `src/transformers/generation/logits_process.py`

* refactor `AssistedCandidateGenerator`

* refactor `AssistedCandidateGeneratorDifferentTokenizers`

* formatting

* refactor `UniversalSpeculativeDecodingGenerator`

* fix negative value for max_new_tokens

* fix generation length target + attention_mask vs. assistant + attent

* fix device

* fix negative max_new_tokens bug

* fix UAG

* minor

* formatting

* `AssistedCandidateGeneratorDifferentTokenizers` `lookbehind`s init

* resolve conflict & formatting

* rerun CI tests

* remove space...

* remove old code

* fix candidate_input_ids device

* minor

* formatting

* Fix prepare + apply (#7)

* fix prepare + apply

* move to cpu

* simplity suppress_tokens

* fix bugs and refacatoring

* device move

* handle self.config.vocab_size > len(target_tokenizer.get_vocab())

* no need to normalize in candidate_generator

* address Nadav's comments + minor

* optimize device move + SuppressTokensLogitsProcessor

* AssistantToTargetTranslator, SuppressTokensLogitsProcessor and tokenizers mapping improvements

* padding size

* padding improvement

* fix and simplify get_target_logits

* renaming in get_target_logits

* minor

* add filter_value and suppress_tokens_id

* style + rename

* remove TODO

* restore original SelectTokensLogitsProcessor with modification

* fix style

* fix _update_past_and_masks and optimize code

* remove assistant_vocab_size arg

* fix attention_mask

* call _prepare_attention_mask also if not has_past_key_values

* handling attention mask for first generation

* comment

* restore test

* remove SelectTokensLogitsProcessor

* _update_past_and_masks implementation for USD

* Add unittests for Universal Assisted generation

* fix style

* update tests

* Remove unused import and fix `test_speculation_depth` test

* exclude special and reserved tokens from tokenizer for UAG

* mv `test_universal_assisted_generation.py` to `generation/test_candidate_generator.py`

* Remove unused imports and fix style using `make style` (#9)

* formatting

* Swap gated `meta-llama/llama-3.2` with `allenai/llama` (#10)

* Fix space sign disagreement (#12)

* default values for AssistantToTargetTranslator fileds

* fix space sign

* minor

* fix test + style

* Default values for some fields of assistant to target translator (#11)

* default values for AssistantToTargetTranslator fileds

* fix

* add support to empty logit_processors

* Update candidate_generator.py (#15)

fix typo

* BUG fix in _prepare_assistant_input_ids (#14)

* fix _prepare_assistant_input_ids

* target_to_assistant_input_ids

* Update src/transformers/generation/candidate_generator.py

Co-authored-by: Nadav Timor <nadav.timor@weizmann.ac.il>

---------

Co-authored-by: Nadav Timor <nadav.timor@weizmann.ac.il>

* typo (`target_to_assistant_input_ids`)

* formatting

* merge upstream/main

* Fix minor review comments (#16)

* Fix: `token_ids.to(torch.int64)` (#18)

* tok ids to `torch.int64` (reference: https://huggingface.co/docs/transformers.js/en/api/tokenizers)

* `LongTensor`

* fix dtype

* `assistant_input_ids.to(dtype=torch.long)`

* Remove unused import from test_candidate_generator.py

* Remove unused import from test_candidate_generator.py

* Remove `numpy` import

* resolve pr comments (#19)

* `AssistantToTargetTranslator` docstring

* (per gante's comment) `filter_value` and `suppress_tokens_id` to class constants

* update `AssistantToTargetTranslator` docstring

* (gante's comment) replace `match-case`

* formatting

* Fix Joao's comments (#21)

* remove threading

* fix logits_processor

* fix test device

* fix style (#23)

* Move atm (#24)

* move AssistantToTargetTranslator

* fixup

* fix logit_processor

* add atm_translator test

* refactor test

* remove threading from test

* add require_torch in tests

* move AssistantVocabTranslatorCache + add tests

* ruff fix

---------

Co-authored-by: jmamou <jonathan.mamou@intel.com>
Co-authored-by: Gaurav <gauravj@d-matrix.ai>
Co-authored-by: Gaurav Jain <gaurjain14@gmail.com>
Co-authored-by: gauravjain14 <41287729+gauravjain14@users.noreply.github.com>
This commit is contained in:
Nadav Timor
2025-02-26 11:14:02 -05:00
committed by GitHub
parent 082834dd79
commit d18d9c3205
3 changed files with 639 additions and 47 deletions

View File

@@ -14,6 +14,7 @@
# limitations under the License.
import copy
import weakref
from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple
import numpy as np
@@ -27,7 +28,7 @@ if is_sklearn_available():
from ..cache_utils import DynamicCache
from ..pytorch_utils import isin_mps_friendly
from .logits_process import LogitsProcessorList, MinLengthLogitsProcessor
from .logits_process import LogitsProcessorList, MinLengthLogitsProcessor, SuppressTokensLogitsProcessor
if TYPE_CHECKING:
@@ -283,18 +284,21 @@ class AssistedCandidateGenerator(CandidateGenerator):
min_new_tokens = max(min(max_new_tokens, self.main_model_min_length - new_cur_len), 0)
return min_new_tokens, max_new_tokens
def _update_past_and_masks(self, input_ids: torch.LongTensor, remove_from_pkv: int = 0) -> bool:
def _update_past_and_masks(
self, input_ids: torch.LongTensor, remove_from_pkv: int = 0, num_added_tokens: int = 1
) -> bool:
"""Update past key values and attention masks for subsequent generation rounds."""
has_past_key_values = self.assistant_kwargs.get("past_key_values", None) is not None
if has_past_key_values:
new_cache_size = input_ids.shape[-1] - 1 - remove_from_pkv
self.assistant_kwargs["past_key_values"] = _crop_past_key_values(
self.assistant_model, self.assistant_kwargs["past_key_values"], new_cache_size - 1
self.assistant_model, self.assistant_kwargs["past_key_values"], new_cache_size - num_added_tokens
)
self.assistant_kwargs = _prepare_attention_mask(
self.assistant_kwargs, input_ids.shape[-1], self.assistant_model.config.is_encoder_decoder
)
self.assistant_kwargs = _prepare_token_type_ids(self.assistant_kwargs, input_ids.shape[-1])
return has_past_key_values
def _prepare_generation_args(self, input_ids: torch.LongTensor, min_new_tokens: int, max_new_tokens: int) -> Dict:
@@ -608,6 +612,290 @@ class AssistedCandidateGeneratorDifferentTokenizers(AssistedCandidateGenerator):
return new_target_ids
class AssistantToTargetTranslator:
"""
Translates token ids and logits between assistant and target model vocabularies. This class is used to handle
vocabulary mismatches when using different tokenizers for the assistant and target models in speculative decoding,
as introduced in the paper "Lossless Speculative Decoding Algorithms for Heterogeneous Vocabularies"
(https://www.arxiv.org/abs/2502.05202).
It maintains mappings between the two vocabularies and handles token/logit conversion.
Args:
target_tokenizer (`PreTrainedTokenizerBase`):
The tokenizer used by the target (main) model.
assistant_tokenizer (`PreTrainedTokenizerBase`):
The tokenizer used by the assistant model.
assistant_model_device (`str`, defaults to "cpu"):
The device where the assistant model is located. Used for placing tensors.
target_vocab_size (`int`, *optional*):
The size of the target model's vocabulary. If not provided, will be inferred from the target tokenizer.
"""
FILTER_VALUE: float = -float("Inf") # The value used to filter out unmapped tokens in the logits.
SUPPRESS_TOKEN_ID: int = -1 # The ID used to mark suppressed tokens in the mapping.
def __init__(
self,
target_tokenizer: "PreTrainedTokenizerBase",
assistant_tokenizer: "PreTrainedTokenizerBase",
target_vocab_size: int, # required since target_vocab_size can be different from the length of target_tokenizer.get_vocab()
assistant_model_device: str = "cpu",
):
self._target_tokenizer: "PreTrainedTokenizerBase" = target_tokenizer
self._assistant_tokenizer: "PreTrainedTokenizerBase" = assistant_tokenizer
self._assistant_model_device: str = assistant_model_device
self.target_vocab_size: int = target_vocab_size
self._assistant_to_target_input_ids, self.target_to_assistant_input_ids = (
self._get_assistant_to_target_input_ids()
)
self._suppress_input_ids: list[int] = self._get_suppress_input_ids()
self.logits_processors: Optional[LogitsProcessorList] = None
if len(self._suppress_input_ids) > 0:
# len(self._suppress_input_ids) = 0 if the assistant vocab is a subset of the target vocab
self.logits_processors = LogitsProcessorList(
[SuppressTokensLogitsProcessor(self._get_suppress_input_ids(), self._assistant_model_device)]
)
def _get_assistant_to_target_input_ids(self):
target_vocab = self._target_tokenizer.get_vocab()
assistant_vocab = self._assistant_tokenizer.get_vocab()
space_str = " "
target_space_ids = self._target_tokenizer(space_str, add_special_tokens=False)["input_ids"]
if len(target_space_ids) > 0:
target_space_sign = self._target_tokenizer.convert_ids_to_tokens(target_space_ids)[0][0]
assistant_space_ids = self._assistant_tokenizer(space_str, add_special_tokens=False)["input_ids"]
if len(assistant_space_ids) > 0:
assistant_space_sign = self._assistant_tokenizer.convert_ids_to_tokens(assistant_space_ids)[0][0]
if target_space_sign != assistant_space_sign:
# If the assistant tokenizer has a different space sign than the target tokenizer,
# we need to replace the assistant space sign with the target space sign in the assistant_vocab.
assistant_vocab = {
(
tok.replace(assistant_space_sign, target_space_sign, 1)
if tok.startswith(assistant_space_sign)
else tok
): idx
for tok, idx in assistant_vocab.items()
}
max_assistant_index = max(assistant_vocab.values())
assistant_to_target_input_ids = torch.full((max_assistant_index + 1,), self.SUPPRESS_TOKEN_ID, dtype=int)
target_to_assistant_input_ids: Dict[int, int] = {}
for tok, assistant_id in assistant_vocab.items():
target_id = target_vocab.get(tok)
if target_id is not None:
assistant_to_target_input_ids[assistant_id] = target_id
target_to_assistant_input_ids[target_id] = assistant_id
return assistant_to_target_input_ids.to(self._assistant_model_device), target_to_assistant_input_ids
def _get_suppress_input_ids(self) -> list[int]:
"""
Get the input ids that are in the assistant vocab but not in the target vocab.
"""
return torch.where(self._assistant_to_target_input_ids == self.SUPPRESS_TOKEN_ID)[0]
def get_target_ids(
self, assistant_input_ids, target_input_ids, assistant_candidate_ids: torch.LongTensor
) -> torch.LongTensor:
"""
Return the target candidate ids that correspond to the assistant candidate ids.
Note that we have already the target ids for the prompt and we only need to find the target ids for the new tokens.
Moreover, assistant ids of the original prompt does not necessarily appear in _assistant_to_target_input_ids.
"""
num_new_tokens = len(assistant_candidate_ids[0]) - assistant_input_ids.shape[1]
if num_new_tokens == 0:
return target_input_ids
else:
transformed_slice = self._assistant_to_target_input_ids[assistant_candidate_ids[0, -num_new_tokens:]]
return torch.cat((target_input_ids, transformed_slice.unsqueeze(0)), dim=1)
def get_target_logits(self, assistant_logits: torch.FloatTensor) -> torch.FloatTensor:
"""
Return the target logits that correspond to the assistant logits.
"""
target_shape: tuple[int, ...] = (*assistant_logits.shape[:-1], self.target_vocab_size)
target_logits: torch.FloatTensor = torch.full(target_shape, self.FILTER_VALUE).to(self._assistant_model_device)
# Mask for valid indices
assistant_indices_mask = self._assistant_to_target_input_ids != self.SUPPRESS_TOKEN_ID
# Exclude invalid indices
target_logits_supported_indices = self._assistant_to_target_input_ids[assistant_indices_mask]
valid_assistant_logits = assistant_logits[..., : self._assistant_to_target_input_ids.shape[0]]
target_logits[..., target_logits_supported_indices] = valid_assistant_logits[..., assistant_indices_mask]
return target_logits
class AssistantVocabTranslatorCache:
"""
Cache for `AssistantToTargetTranslator` instances. The instances are computed at
pre-processing time, and this cache allows us to avoid recomputing them.
"""
_cache = weakref.WeakKeyDictionary()
@classmethod
def get_translator(
cls,
target_tokenizer: "PreTrainedTokenizerBase",
assistant_tokenizer: "PreTrainedTokenizerBase",
target_vocab_size: int,
assistant_model_device: str = "cpu",
) -> AssistantToTargetTranslator:
assistant_dict = cls._cache.get(target_tokenizer)
if assistant_dict is None:
assistant_dict = weakref.WeakKeyDictionary()
cls._cache[target_tokenizer] = assistant_dict
mapping = assistant_dict.get(assistant_tokenizer)
if mapping is None:
mapping = AssistantToTargetTranslator(
target_tokenizer, assistant_tokenizer, target_vocab_size, assistant_model_device
)
assistant_dict[assistant_tokenizer] = mapping
return mapping
@classmethod
def cleanup(cls):
"""
Clean up dead references in the cache.
This removes entries where either the target_tokenizer or assistant_tokenizer
has been garbage collected.
"""
# Remove entries from the outer cache where the target_tokenizer is no longer alive
dead_keys = [key for key in cls._cache if key is None]
for key in dead_keys:
del cls._cache[key]
# For each assistant_dict, remove entries where assistant_tokenizer is no longer alive
for assistant_dict in cls._cache.values():
dead_keys = [key for key in assistant_dict if key is None]
for key in dead_keys:
del assistant_dict[key]
class UniversalSpeculativeDecodingGenerator(AssistedCandidateGeneratorDifferentTokenizers):
"""
`CandidateGenerator` class to be used for Universal Speculative Decoding (USD): speculative decoding with different tokenizers
for the assistant and main models. This class generates candidates through the use of a smaller model.
"""
def __init__(
self,
input_ids: torch.LongTensor,
assistant_model: "PreTrainedModel",
target_tokenizer: "PreTrainedTokenizerBase",
assistant_tokenizer: "PreTrainedTokenizerBase",
generation_config: "GenerationConfig",
model_kwargs: Dict,
atm_translator: AssistantToTargetTranslator,
inputs_tensor: Optional[torch.Tensor] = None,
logits_processor: "LogitsProcessorList" = None,
):
# Initialize translator before parent class
self._atm_translator = atm_translator
super().__init__(
input_ids,
assistant_model,
target_tokenizer,
assistant_tokenizer,
generation_config,
model_kwargs,
inputs_tensor,
logits_processor,
)
# Track sequence lengths and previous assistant IDs
self._target_seq_len_with_candidates: int = 0
self._prev_assistant_ids: Optional[torch.LongTensor] = None
def get_candidates(self, input_ids: torch.LongTensor) -> Tuple[torch.LongTensor, Optional[torch.FloatTensor]]:
"""
Simplified version of get_candidates that uses the translator cache for token conversion.
"""
target_input_ids = input_ids.to(self.assistant_model.device)
assistant_input_ids, num_added_tokens = self._prepare_assistant_input_ids(target_input_ids)
min_new_tokens, max_new_tokens = self._calculate_new_tokens(target_input_ids)
if max_new_tokens == 0:
return input_ids, None
self._update_past_and_masks(assistant_input_ids, num_added_tokens=num_added_tokens)
generation_args = self._prepare_generation_args(assistant_input_ids, min_new_tokens, max_new_tokens)
# Ensure scores are returned
generation_args["generation_config"].output_scores = True
generation_args["generation_config"].return_dict_in_generate = True
# Generate and process outputs using translator
if self._atm_translator.logits_processors is not None:
generation_args["logits_processor"] = self._atm_translator.logits_processors
self._prev_assistant_ids, assistant_candidate_logits = self._generate_candidates(generation_args)
# Use translator to convert tokens and logits
target_candidate_ids = self._atm_translator.get_target_ids(
assistant_input_ids, target_input_ids, self._prev_assistant_ids
)
self._target_seq_len_with_candidates = target_candidate_ids.shape[-1]
target_candidate_logits = self._atm_translator.get_target_logits(assistant_candidate_logits)
return target_candidate_ids, target_candidate_logits
def _update_past_and_masks(self, assistant_input_ids: torch.LongTensor, num_added_tokens: int = 1) -> bool:
if self._prev_assistant_ids is None:
# Prepare attention mask for the first generation.
# For subsequent generations, the attention mask is updated in super()_update_past_and_masks.
self.assistant_kwargs = _prepare_attention_mask(
self.assistant_kwargs, assistant_input_ids.shape[-1], self.assistant_model.config.is_encoder_decoder
)
return super()._update_past_and_masks(assistant_input_ids, num_added_tokens=num_added_tokens)
def _prepare_assistant_input_ids(self, target_input_ids: torch.LongTensor) -> torch.LongTensor:
"""
Simplified token conversion that only processes new tokens.
"""
# Calculate new tokens since last call
target_seq_len = target_input_ids.shape[-1]
if self._target_seq_len_with_candidates == 0:
new_token_count = target_seq_len
else:
new_token_count = 1
target_new_ids = target_input_ids[:, -new_token_count:]
# Convert the new tokens
assistant_new_ids = None
if self._target_seq_len_with_candidates > 0:
# we have only one new token and we can directly convert it
assistant_new_ids = self._atm_translator.target_to_assistant_input_ids.get(target_new_ids[0].item())
if assistant_new_ids is None:
target_new_text = self.target_tokenizer.batch_decode(
target_new_ids, skip_special_tokens=True, clean_up_tokenization_spaces=True
)
assistant_new_ids = self.assistant_tokenizer(
target_new_text, add_special_tokens=False, return_tensors="pt"
)["input_ids"].to(self.assistant_model.device)
else:
assistant_new_ids = torch.tensor([[assistant_new_ids]], device=self.assistant_model.device)
# Update or initialize assistant IDs
if self._prev_assistant_ids is None:
assistant_input_ids = assistant_new_ids
else:
tokens_to_remove = self._target_seq_len_with_candidates + 1 - target_seq_len
# If the number of new tokens is greater than zero, truncate the previous assistant IDs
if tokens_to_remove > 0:
self._prev_assistant_ids = self._prev_assistant_ids[:, :-tokens_to_remove]
assistant_input_ids = torch.cat([self._prev_assistant_ids, assistant_new_ids], dim=-1)
assistant_input_ids = assistant_input_ids.to(dtype=torch.long)
return assistant_input_ids, len(assistant_new_ids[0])
class PromptLookupCandidateGenerator(CandidateGenerator):
"""
`CandidateGenerator` class to be used for prompt lookup generation. This class generates candidates by looking up

View File

@@ -26,6 +26,8 @@ import torch.distributed as dist
from torch import nn
from torch.nn import functional as F
from transformers.generation.candidate_generator import AssistantVocabTranslatorCache
from ..cache_utils import (
Cache,
DynamicCache,
@@ -56,6 +58,7 @@ from .candidate_generator import (
CandidateGenerator,
EarlyExitCandidateGenerator,
PromptLookupCandidateGenerator,
UniversalSpeculativeDecodingGenerator,
_crop_past_key_values,
_prepare_attention_mask,
_prepare_token_type_ids,
@@ -858,6 +861,22 @@ class GenerationMixin:
max_length=generation_config.max_length,
)
elif different_tokenizers:
if generation_config.do_sample is True:
atm_translator = AssistantVocabTranslatorCache.get_translator(
target_tokenizer, assistant_tokenizer, self.config.vocab_size, assistant_model.device
)
candidate_generator = UniversalSpeculativeDecodingGenerator(
input_ids=input_ids,
assistant_model=assistant_model,
generation_config=generation_config,
model_kwargs=model_kwargs,
inputs_tensor=inputs_tensor,
logits_processor=logits_processor,
target_tokenizer=target_tokenizer,
assistant_tokenizer=assistant_tokenizer,
atm_translator=atm_translator,
)
elif generation_config.do_sample is False:
candidate_generator = AssistedCandidateGeneratorDifferentTokenizers(
input_ids=input_ids,
assistant_model=assistant_model,
@@ -868,6 +887,10 @@ class GenerationMixin:
target_tokenizer=target_tokenizer,
assistant_tokenizer=assistant_tokenizer,
)
else:
raise ValueError(
f"Invalid value for `do_sample`: expected a boolean, got {type(generation_config.do_sample).__name__}"
)
else:
candidate_generator = AssistedCandidateGenerator(
input_ids=input_ids,
@@ -4225,7 +4248,6 @@ class GenerationMixin:
# 1. Fetch candidate sequences from a `CandidateGenerator` and move to the correct device
candidate_input_ids, candidate_logits = candidate_generator.get_candidates(input_ids)
candidate_input_ids = candidate_input_ids.to(self.device)
if candidate_logits is not None:
candidate_logits = candidate_logits.to(self.device)

View File

@@ -1,43 +1,325 @@
import gc
import unittest
import weakref
from unittest.mock import MagicMock
import numpy as np
import torch
from transformers.generation.candidate_generator import AssistedCandidateGeneratorDifferentTokenizers
class TestAssistedCandidateGeneratorDifferentTokenizers(unittest.TestCase):
def test_no_intersection(self):
prompt = np.array([[1, 2, 3]])
prompt_plus_new_tokens = np.array([[4, 5, 6]])
result = AssistedCandidateGeneratorDifferentTokenizers._get_tokens_diag(prompt, prompt_plus_new_tokens)
self.assertEqual(result, (None, None, None))
def test_complete_overlap(self):
prompt = np.array([[1, 2, 3]])
prompt_plus_new_tokens = np.array([[1, 2, 3, 4, 5]])
discrep_length, new_tokens_only, discrep_only = AssistedCandidateGeneratorDifferentTokenizers._get_tokens_diag(
prompt, prompt_plus_new_tokens
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer, GenerationConfig, pipeline
from transformers.generation.candidate_generator import (
AssistantToTargetTranslator,
AssistantVocabTranslatorCache,
UniversalSpeculativeDecodingGenerator,
)
self.assertEqual(discrep_length, 0)
np.testing.assert_array_equal(new_tokens_only, np.array([[4, 5]]))
np.testing.assert_array_equal(discrep_only, np.array([[]]))
from transformers.testing_utils import require_torch, torch_device
def test_partial_overlap(self):
prompt = np.array([[1, 2, 3]])
prompt_plus_new_tokens = np.array([[2, 3, 4, 5]])
discrep_length, new_tokens_only, discrep_only = AssistedCandidateGeneratorDifferentTokenizers._get_tokens_diag(
prompt, prompt_plus_new_tokens
)
self.assertEqual(discrep_length, 0)
np.testing.assert_array_equal(new_tokens_only, np.array([[4, 5]]))
np.testing.assert_array_equal(discrep_only, np.array([[]]))
def test_no_new_tokens(self):
prompt = np.array([[1, 2, 3]])
prompt_plus_new_tokens = np.array([[1, 2, 3]])
discrep_length, new_tokens_only, discrep_only = AssistedCandidateGeneratorDifferentTokenizers._get_tokens_diag(
prompt, prompt_plus_new_tokens
@require_torch
class TestAssistantToTargetTranslator(unittest.TestCase):
def setUp(self):
# Create mock tokenizers with predefined vocabularies
self.target_tokenizer = MagicMock()
self.assistant_tokenizer = MagicMock()
# Define mock vocabularies for the tokenizers
self.target_vocab = {"hello": 0, "world": 1, "foo": 2, "bar": 3}
self.assistant_vocab = {"hello": 0, "world": 1, "foo": 2, "baz": 4}
self.target_tokenizer.get_vocab.return_value = self.target_vocab
self.assistant_tokenizer.get_vocab.return_value = self.assistant_vocab
self.assistant_model_device = torch_device
self.target_vocab_size = 6
# Instantiate the class under test
self.translator = AssistantToTargetTranslator(
target_tokenizer=self.target_tokenizer,
assistant_tokenizer=self.assistant_tokenizer,
assistant_model_device=self.assistant_model_device,
target_vocab_size=self.target_vocab_size,
)
self.assertEqual(discrep_length, 0)
np.testing.assert_array_equal(new_tokens_only, np.array([[]]))
np.testing.assert_array_equal(discrep_only, np.array([[]]))
def test_get_assistant_to_target_input_ids(self):
"""Test the mapping from assistant tokens to target tokens."""
expected_mapping = [0, 1, 2, self.translator.SUPPRESS_TOKEN_ID, self.translator.SUPPRESS_TOKEN_ID]
actual_mapping = self.translator._assistant_to_target_input_ids.tolist()
self.assertEqual(actual_mapping, expected_mapping)
def test_get_suppress_input_ids(self):
"""Test the suppression of assistant input IDs not present in the target vocabulary."""
expected_suppress_ids = [3, 4]
actual_suppress_ids = self.translator._get_suppress_input_ids().tolist()
self.assertEqual(actual_suppress_ids, expected_suppress_ids)
def test_get_target_ids(self):
"""Test the translation of assistant candidate IDs to target candidate IDs."""
assistant_input_ids = torch.LongTensor([[0, 1, 2]]).to(
self.assistant_model_device
) # 'hello world foo' in assistant tokenizer
target_input_ids = torch.LongTensor([[0, 1, 2]]).to(
self.assistant_model_device
) # 'hello world foo' in target tokenizer
assistant_candidate_ids = torch.LongTensor([[0, 1, 2, 4]]).to(
self.assistant_model_device
) # 'hello world foo baz' in assistant tokenizer
expected_target_ids = torch.LongTensor(
[[0, 1, 2, self.translator.SUPPRESS_TOKEN_ID]]
).to(
self.assistant_model_device
) # 'hello world foo baz' in target tokenizer (baz is mapped to self.translator.suppress_tokens_id since it does not exist in target vocab)
actual_target_ids = self.translator.get_target_ids(
assistant_input_ids, target_input_ids, assistant_candidate_ids
)
self.assertTrue(torch.equal(actual_target_ids, expected_target_ids))
def test_get_target_logits(self):
"""Test the conversion of assistant logits to target logits."""
# Assistant logits for IDs 0, 1, 2
assistant_logits = torch.FloatTensor([[[0.1, 0.2, 0.3, 0.4, self.translator.FILTER_VALUE]]]).to(
self.assistant_model_device
) # Shape (1, 1, 5)
# Expected target logits (target_vocab_size = 4)
expected_target_logits = torch.full((1, 1, self.target_vocab_size), self.translator.FILTER_VALUE).to(
self.assistant_model_device
)
expected_target_logits[0, 0, 0] = 0.1 # 'hello'
expected_target_logits[0, 0, 1] = 0.2 # 'world'
expected_target_logits[0, 0, 2] = 0.3 # 'foo'
# The 'bar' token in target vocab remains at -inf
actual_target_logits = self.translator.get_target_logits(assistant_logits)
self.assertTrue(torch.equal(actual_target_logits, expected_target_logits))
class MockTokenizer:
"""A simple mock tokenizer class that supports weak references."""
def __init__(self, vocab=None):
self._vocab = vocab or {}
def get_vocab(self):
return self._vocab
def __call__(self, text, add_special_tokens=True):
# Mock implementation of the __call__ method
tokens = text.split()
input_ids = [self._vocab.get(token, 0) for token in tokens]
return {"input_ids": input_ids}
@require_torch
class TestAssistantVocabTranslatorCache(unittest.TestCase):
def setUp(self):
# Clear the cache before each test
AssistantVocabTranslatorCache._cache.clear()
# Create mock tokenizers with different vocabularies
self.target_tokenizer = MockTokenizer({"hello": 0, "world": 1})
self.assistant_tokenizer = MockTokenizer({"hello": 0, "world": 1, "foo": 2})
self.other_target_tokenizer = MockTokenizer({"foo": 2, "bar": 3})
self.other_assistant_tokenizer = MockTokenizer({"baz": 4, "qux": 5})
self.assistant_model_device = torch_device
self.target_vocab_size = 6
def test_same_instance_for_same_tokenizers(self):
"""Test that the same translator is returned for the same tokenizers."""
translator1 = AssistantVocabTranslatorCache.get_translator(
self.target_tokenizer,
self.assistant_tokenizer,
assistant_model_device=self.assistant_model_device,
target_vocab_size=self.target_vocab_size,
)
translator2 = AssistantVocabTranslatorCache.get_translator(
self.target_tokenizer,
self.assistant_tokenizer,
assistant_model_device=self.assistant_model_device,
target_vocab_size=self.target_vocab_size,
)
self.assertIs(translator1, translator2, "Translators should be cached and identical")
def test_different_instances_for_different_tokenizers(self):
"""Test that different tokenizers produce different translators."""
translator1 = AssistantVocabTranslatorCache.get_translator(
self.target_tokenizer,
self.assistant_tokenizer,
assistant_model_device=self.assistant_model_device,
target_vocab_size=self.target_vocab_size,
)
translator2 = AssistantVocabTranslatorCache.get_translator(
self.other_target_tokenizer,
self.other_assistant_tokenizer,
assistant_model_device=self.assistant_model_device,
target_vocab_size=self.target_vocab_size,
)
self.assertIsNot(translator1, translator2, "Translators should differ for different tokenizers")
def test_cache_with_weakref_key(self):
"""Ensure that the cache uses weak references as keys."""
initial_cache_size = len(AssistantVocabTranslatorCache._cache)
target_tokenizer = MockTokenizer({"hello": 0})
assistant_tokenizer = MockTokenizer({"hello": 0})
# Store translator in a local variable to avoid it being kept alive
translator = AssistantVocabTranslatorCache.get_translator(
target_tokenizer,
assistant_tokenizer,
assistant_model_device=self.assistant_model_device,
target_vocab_size=self.target_vocab_size,
)
self.assertEqual(len(AssistantVocabTranslatorCache._cache), initial_cache_size + 1)
# Delete all strong references
del target_tokenizer
del assistant_tokenizer
del translator
# Force garbage collection
gc.collect()
# Call cleanup to remove dead entries
AssistantVocabTranslatorCache.cleanup()
# The cache size remains increased due to strong references
self.assertEqual(len(AssistantVocabTranslatorCache._cache), initial_cache_size + 1)
def test_weakref_cache_cleanup(self):
"""Test that the cache cleans up translators when tokenizers are garbage collected."""
def create_translator():
target_tokenizer = MockTokenizer({"hello": 0})
assistant_tokenizer = MockTokenizer({"hello": 0})
translator = AssistantVocabTranslatorCache.get_translator(
target_tokenizer,
assistant_tokenizer,
assistant_model_device=self.assistant_model_device,
target_vocab_size=self.target_vocab_size,
)
# Create weak references before returning
refs = (weakref.ref(translator), weakref.ref(target_tokenizer), weakref.ref(assistant_tokenizer))
# Remove strong references inside the function
del target_tokenizer
del assistant_tokenizer
del translator
return refs
translator_ref, target_ref, assistant_ref = create_translator()
# Force garbage collection
gc.collect()
# Call cleanup to remove dead entries
AssistantVocabTranslatorCache.cleanup()
# The tokenizers and translator are not garbage collected due to strong references
self.assertIsNotNone(target_ref(), "Target tokenizer should still be alive due to strong references")
self.assertIsNotNone(assistant_ref(), "Assistant tokenizer should still be alive due to strong references")
self.assertIsNotNone(translator_ref(), "Translator should still be alive due to strong references")
@require_torch
class TestUniversalSpeculativeDecoding(unittest.TestCase):
@classmethod
def setUpClass(cls):
cls.target_name = "hf-internal-testing/tiny-random-LlamaForCausalLM"
cls.assistant_name = "hf-internal-testing/tiny-random-PhiForCausalLM"
def setUp(self):
self.target_tokenizer = AutoTokenizer.from_pretrained(self.target_name)
self.target_config = AutoConfig.from_pretrained(self.target_name)
self.assistant_model = AutoModelForCausalLM.from_pretrained(self.assistant_name).to(torch_device)
self.assistant_tokenizer = AutoTokenizer.from_pretrained(self.assistant_name)
self.generation_config = GenerationConfig()
# Ensure required tokens exist
if self.target_tokenizer.pad_token_id is None:
self.target_tokenizer.pad_token_id = self.target_tokenizer.eos_token_id
if self.target_tokenizer.bos_token_id is None:
self.target_tokenizer.bos_token_id = self.target_tokenizer.eos_token_id
if self.assistant_tokenizer.pad_token_id is None:
self.assistant_tokenizer.pad_token_id = self.assistant_tokenizer.eos_token_id
if self.target_tokenizer.bos_token_id is None:
self.assistant_tokenizer.bos_token_id = self.assistant_tokenizer.eos_token_id
self.input_ids = torch.tensor([[1, 2, 3]]).to(torch_device)
self.model_kwargs = {
"attention_mask": torch.ones_like(self.input_ids).to(torch_device),
}
atm_translator = AssistantVocabTranslatorCache.get_translator(
self.target_tokenizer, self.assistant_tokenizer, self.target_config.vocab_size, torch_device
)
self.generator = UniversalSpeculativeDecodingGenerator(
input_ids=self.input_ids,
assistant_model=self.assistant_model,
target_tokenizer=self.target_tokenizer,
assistant_tokenizer=self.assistant_tokenizer,
generation_config=self.generation_config,
model_kwargs=self.model_kwargs,
atm_translator=atm_translator,
)
def test_basic_generation(self):
"""Test basic speculative decoding works"""
input_text = "The quick brown fox"
input_ids = self.target_tokenizer.encode(input_text, return_tensors="pt")
self.generator.input_ids = input_ids
candidates, scores = self.generator.get_candidates(input_ids)
self.assertIsNotNone(candidates)
self.assertIsNotNone(scores)
self.assertTrue(torch.is_tensor(candidates))
self.assertTrue(torch.is_tensor(scores))
def test_mismatched_vocabularies(self):
"""Test handling of mismatched vocabularies between models"""
# Create input with tokens present in main but not assistant vocab
# Find a token that is not in the assistant tokenizer but in
# the main tokenizer.
missing_token = next(
token
for token in self.target_tokenizer.get_vocab()
if token not in self.assistant_tokenizer.get_vocab()
and token not in self.target_tokenizer.all_special_tokens
and "reserved_" not in token
)
input_ids = torch.tensor([[self.target_tokenizer.convert_tokens_to_ids(missing_token)]])
self.generator.input_ids = input_ids
candidates, scores = self.generator.get_candidates(input_ids)
self.assertIsNotNone(candidates)
def test_speculation_depth(self):
"""Test different speculation depths"""
input_ids = self.target_tokenizer.encode("Test text", return_tensors="pt")
self.generator.input_ids = input_ids
for depth in [1, 8, 17]:
self.generator.num_assistant_tokens = depth
candidates, scores = self.generator.get_candidates(input_ids)
self.assertLessEqual(candidates.shape[1] - input_ids.shape[1], depth)
def test_device_consistency(self):
"""Test handling of inputs on different devices"""
input_ids = torch.tensor([[1, 2, 3]]).to(torch_device)
self.generator.input_ids = input_ids
candidates, _ = self.generator.get_candidates(input_ids)
self.assertEqual(candidates.device, input_ids.device)
def test_usd_vs_vanilla_sampling(cls):
"""Test that USD matches vanilla sampling with temperature set to nearly 0"""
prompt = "Test text"
pipe_usd = pipeline("text-generation", model=cls.target_name, assistant_model=cls.assistant_name)
pipe_usd_output = pipe_usd(prompt, max_new_tokens=5, do_sample=True, temperature=1e-9) # Nearly 0 temperature
usd_text = pipe_usd_output[0]["generated_text"]
pipe_vanilla = pipeline(
"text-generation",
model=cls.target_name,
)
pipe_vanilla_output = pipe_vanilla(prompt, max_new_tokens=5, do_sample=False)
vanilla_text = pipe_vanilla_output[0]["generated_text"]
# Assert that the outputs match
cls.assertEqual(usd_text, vanilla_text)