Universal Speculative Decoding CandidateGenerator (#35029)
* move `TestAssistedCandidateGeneratorDifferentTokenizers` into a new testing file * refactor * NOTHING. add space to rerun github actions tests * remove it... * `UniversalSpeculativeDecodingGenerator` * Use `UniversalSpeculativeDecodingGenerator` when `generation_config.do_sample=True` * assistant tokenizes only the target's new suffix * formatting * fix code * fix code * formatting * add `TestGenerateWithDifferentModels` * `TestGenerateWithDifferentModels` parameterize on `do_sample` * `AssistantVocabMapping` & `AssistantVocabMappingCache` * formatting * `AssistantToTargetTranslator`: `get_target_input_ids` & `get_target_logits` * improve `_get_assistant_to_target_input_ids` & formatting * renaming * WIP: debugging `min_new_tokens` * fix get_target_ids * `UniversalSpeculativeDecodingGenerator` * assistant tokenizes only the target's new suffix * formatting * fix code * fix code * formatting * `TestGenerateWithDifferentModels` parameterize on `do_sample` * `AssistantVocabMapping` & `AssistantVocabMappingCache` * formatting * `AssistantToTargetTranslator`: `get_target_input_ids` & `get_target_logits` * improve `_get_assistant_to_target_input_ids` & formatting * renaming * WIP: debugging `min_new_tokens` * fix get_target_ids * fix device issue * fix get_assistant_input_ids * add `TestAssistedCandidateGeneratorDifferentTokenizers` * formatting * `AssistantVocabTranslatorCache` refactor & tests * revert changes in `src/transformers/generation/logits_process.py` * refactor `AssistedCandidateGenerator` * refactor `AssistedCandidateGeneratorDifferentTokenizers` * formatting * refactor `UniversalSpeculativeDecodingGenerator` * fix negative value for max_new_tokens * fix generation length target + attention_mask vs. assistant + attent * fix device * fix negative max_new_tokens bug * fix UAG * minor * formatting * `AssistedCandidateGeneratorDifferentTokenizers` `lookbehind`s init * resolve conflict & formatting * rerun CI tests * remove space... * remove old code * fix candidate_input_ids device * minor * formatting * Fix prepare + apply (#7) * fix prepare + apply * move to cpu * simplity suppress_tokens * fix bugs and refacatoring * device move * handle self.config.vocab_size > len(target_tokenizer.get_vocab()) * no need to normalize in candidate_generator * address Nadav's comments + minor * optimize device move + SuppressTokensLogitsProcessor * AssistantToTargetTranslator, SuppressTokensLogitsProcessor and tokenizers mapping improvements * padding size * padding improvement * fix and simplify get_target_logits * renaming in get_target_logits * minor * add filter_value and suppress_tokens_id * style + rename * remove TODO * restore original SelectTokensLogitsProcessor with modification * fix style * fix _update_past_and_masks and optimize code * remove assistant_vocab_size arg * fix attention_mask * call _prepare_attention_mask also if not has_past_key_values * handling attention mask for first generation * comment * restore test * remove SelectTokensLogitsProcessor * _update_past_and_masks implementation for USD * Add unittests for Universal Assisted generation * fix style * update tests * Remove unused import and fix `test_speculation_depth` test * exclude special and reserved tokens from tokenizer for UAG * mv `test_universal_assisted_generation.py` to `generation/test_candidate_generator.py` * Remove unused imports and fix style using `make style` (#9) * formatting * Swap gated `meta-llama/llama-3.2` with `allenai/llama` (#10) * Fix space sign disagreement (#12) * default values for AssistantToTargetTranslator fileds * fix space sign * minor * fix test + style * Default values for some fields of assistant to target translator (#11) * default values for AssistantToTargetTranslator fileds * fix * add support to empty logit_processors * Update candidate_generator.py (#15) fix typo * BUG fix in _prepare_assistant_input_ids (#14) * fix _prepare_assistant_input_ids * target_to_assistant_input_ids * Update src/transformers/generation/candidate_generator.py Co-authored-by: Nadav Timor <nadav.timor@weizmann.ac.il> --------- Co-authored-by: Nadav Timor <nadav.timor@weizmann.ac.il> * typo (`target_to_assistant_input_ids`) * formatting * merge upstream/main * Fix minor review comments (#16) * Fix: `token_ids.to(torch.int64)` (#18) * tok ids to `torch.int64` (reference: https://huggingface.co/docs/transformers.js/en/api/tokenizers) * `LongTensor` * fix dtype * `assistant_input_ids.to(dtype=torch.long)` * Remove unused import from test_candidate_generator.py * Remove unused import from test_candidate_generator.py * Remove `numpy` import * resolve pr comments (#19) * `AssistantToTargetTranslator` docstring * (per gante's comment) `filter_value` and `suppress_tokens_id` to class constants * update `AssistantToTargetTranslator` docstring * (gante's comment) replace `match-case` * formatting * Fix Joao's comments (#21) * remove threading * fix logits_processor * fix test device * fix style (#23) * Move atm (#24) * move AssistantToTargetTranslator * fixup * fix logit_processor * add atm_translator test * refactor test * remove threading from test * add require_torch in tests * move AssistantVocabTranslatorCache + add tests * ruff fix --------- Co-authored-by: jmamou <jonathan.mamou@intel.com> Co-authored-by: Gaurav <gauravj@d-matrix.ai> Co-authored-by: Gaurav Jain <gaurjain14@gmail.com> Co-authored-by: gauravjain14 <41287729+gauravjain14@users.noreply.github.com>
This commit is contained in:
@@ -14,6 +14,7 @@
|
|||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
import copy
|
import copy
|
||||||
|
import weakref
|
||||||
from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple
|
from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
@@ -27,7 +28,7 @@ if is_sklearn_available():
|
|||||||
|
|
||||||
from ..cache_utils import DynamicCache
|
from ..cache_utils import DynamicCache
|
||||||
from ..pytorch_utils import isin_mps_friendly
|
from ..pytorch_utils import isin_mps_friendly
|
||||||
from .logits_process import LogitsProcessorList, MinLengthLogitsProcessor
|
from .logits_process import LogitsProcessorList, MinLengthLogitsProcessor, SuppressTokensLogitsProcessor
|
||||||
|
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
@@ -283,18 +284,21 @@ class AssistedCandidateGenerator(CandidateGenerator):
|
|||||||
min_new_tokens = max(min(max_new_tokens, self.main_model_min_length - new_cur_len), 0)
|
min_new_tokens = max(min(max_new_tokens, self.main_model_min_length - new_cur_len), 0)
|
||||||
return min_new_tokens, max_new_tokens
|
return min_new_tokens, max_new_tokens
|
||||||
|
|
||||||
def _update_past_and_masks(self, input_ids: torch.LongTensor, remove_from_pkv: int = 0) -> bool:
|
def _update_past_and_masks(
|
||||||
|
self, input_ids: torch.LongTensor, remove_from_pkv: int = 0, num_added_tokens: int = 1
|
||||||
|
) -> bool:
|
||||||
"""Update past key values and attention masks for subsequent generation rounds."""
|
"""Update past key values and attention masks for subsequent generation rounds."""
|
||||||
has_past_key_values = self.assistant_kwargs.get("past_key_values", None) is not None
|
has_past_key_values = self.assistant_kwargs.get("past_key_values", None) is not None
|
||||||
if has_past_key_values:
|
if has_past_key_values:
|
||||||
new_cache_size = input_ids.shape[-1] - 1 - remove_from_pkv
|
new_cache_size = input_ids.shape[-1] - 1 - remove_from_pkv
|
||||||
self.assistant_kwargs["past_key_values"] = _crop_past_key_values(
|
self.assistant_kwargs["past_key_values"] = _crop_past_key_values(
|
||||||
self.assistant_model, self.assistant_kwargs["past_key_values"], new_cache_size - 1
|
self.assistant_model, self.assistant_kwargs["past_key_values"], new_cache_size - num_added_tokens
|
||||||
)
|
)
|
||||||
self.assistant_kwargs = _prepare_attention_mask(
|
self.assistant_kwargs = _prepare_attention_mask(
|
||||||
self.assistant_kwargs, input_ids.shape[-1], self.assistant_model.config.is_encoder_decoder
|
self.assistant_kwargs, input_ids.shape[-1], self.assistant_model.config.is_encoder_decoder
|
||||||
)
|
)
|
||||||
self.assistant_kwargs = _prepare_token_type_ids(self.assistant_kwargs, input_ids.shape[-1])
|
self.assistant_kwargs = _prepare_token_type_ids(self.assistant_kwargs, input_ids.shape[-1])
|
||||||
|
|
||||||
return has_past_key_values
|
return has_past_key_values
|
||||||
|
|
||||||
def _prepare_generation_args(self, input_ids: torch.LongTensor, min_new_tokens: int, max_new_tokens: int) -> Dict:
|
def _prepare_generation_args(self, input_ids: torch.LongTensor, min_new_tokens: int, max_new_tokens: int) -> Dict:
|
||||||
@@ -608,6 +612,290 @@ class AssistedCandidateGeneratorDifferentTokenizers(AssistedCandidateGenerator):
|
|||||||
return new_target_ids
|
return new_target_ids
|
||||||
|
|
||||||
|
|
||||||
|
class AssistantToTargetTranslator:
|
||||||
|
"""
|
||||||
|
Translates token ids and logits between assistant and target model vocabularies. This class is used to handle
|
||||||
|
vocabulary mismatches when using different tokenizers for the assistant and target models in speculative decoding,
|
||||||
|
as introduced in the paper "Lossless Speculative Decoding Algorithms for Heterogeneous Vocabularies"
|
||||||
|
(https://www.arxiv.org/abs/2502.05202).
|
||||||
|
It maintains mappings between the two vocabularies and handles token/logit conversion.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
target_tokenizer (`PreTrainedTokenizerBase`):
|
||||||
|
The tokenizer used by the target (main) model.
|
||||||
|
assistant_tokenizer (`PreTrainedTokenizerBase`):
|
||||||
|
The tokenizer used by the assistant model.
|
||||||
|
assistant_model_device (`str`, defaults to "cpu"):
|
||||||
|
The device where the assistant model is located. Used for placing tensors.
|
||||||
|
target_vocab_size (`int`, *optional*):
|
||||||
|
The size of the target model's vocabulary. If not provided, will be inferred from the target tokenizer.
|
||||||
|
"""
|
||||||
|
|
||||||
|
FILTER_VALUE: float = -float("Inf") # The value used to filter out unmapped tokens in the logits.
|
||||||
|
SUPPRESS_TOKEN_ID: int = -1 # The ID used to mark suppressed tokens in the mapping.
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
target_tokenizer: "PreTrainedTokenizerBase",
|
||||||
|
assistant_tokenizer: "PreTrainedTokenizerBase",
|
||||||
|
target_vocab_size: int, # required since target_vocab_size can be different from the length of target_tokenizer.get_vocab()
|
||||||
|
assistant_model_device: str = "cpu",
|
||||||
|
):
|
||||||
|
self._target_tokenizer: "PreTrainedTokenizerBase" = target_tokenizer
|
||||||
|
self._assistant_tokenizer: "PreTrainedTokenizerBase" = assistant_tokenizer
|
||||||
|
self._assistant_model_device: str = assistant_model_device
|
||||||
|
self.target_vocab_size: int = target_vocab_size
|
||||||
|
self._assistant_to_target_input_ids, self.target_to_assistant_input_ids = (
|
||||||
|
self._get_assistant_to_target_input_ids()
|
||||||
|
)
|
||||||
|
self._suppress_input_ids: list[int] = self._get_suppress_input_ids()
|
||||||
|
self.logits_processors: Optional[LogitsProcessorList] = None
|
||||||
|
if len(self._suppress_input_ids) > 0:
|
||||||
|
# len(self._suppress_input_ids) = 0 if the assistant vocab is a subset of the target vocab
|
||||||
|
self.logits_processors = LogitsProcessorList(
|
||||||
|
[SuppressTokensLogitsProcessor(self._get_suppress_input_ids(), self._assistant_model_device)]
|
||||||
|
)
|
||||||
|
|
||||||
|
def _get_assistant_to_target_input_ids(self):
|
||||||
|
target_vocab = self._target_tokenizer.get_vocab()
|
||||||
|
assistant_vocab = self._assistant_tokenizer.get_vocab()
|
||||||
|
|
||||||
|
space_str = " "
|
||||||
|
target_space_ids = self._target_tokenizer(space_str, add_special_tokens=False)["input_ids"]
|
||||||
|
if len(target_space_ids) > 0:
|
||||||
|
target_space_sign = self._target_tokenizer.convert_ids_to_tokens(target_space_ids)[0][0]
|
||||||
|
|
||||||
|
assistant_space_ids = self._assistant_tokenizer(space_str, add_special_tokens=False)["input_ids"]
|
||||||
|
if len(assistant_space_ids) > 0:
|
||||||
|
assistant_space_sign = self._assistant_tokenizer.convert_ids_to_tokens(assistant_space_ids)[0][0]
|
||||||
|
|
||||||
|
if target_space_sign != assistant_space_sign:
|
||||||
|
# If the assistant tokenizer has a different space sign than the target tokenizer,
|
||||||
|
# we need to replace the assistant space sign with the target space sign in the assistant_vocab.
|
||||||
|
assistant_vocab = {
|
||||||
|
(
|
||||||
|
tok.replace(assistant_space_sign, target_space_sign, 1)
|
||||||
|
if tok.startswith(assistant_space_sign)
|
||||||
|
else tok
|
||||||
|
): idx
|
||||||
|
for tok, idx in assistant_vocab.items()
|
||||||
|
}
|
||||||
|
|
||||||
|
max_assistant_index = max(assistant_vocab.values())
|
||||||
|
assistant_to_target_input_ids = torch.full((max_assistant_index + 1,), self.SUPPRESS_TOKEN_ID, dtype=int)
|
||||||
|
target_to_assistant_input_ids: Dict[int, int] = {}
|
||||||
|
for tok, assistant_id in assistant_vocab.items():
|
||||||
|
target_id = target_vocab.get(tok)
|
||||||
|
if target_id is not None:
|
||||||
|
assistant_to_target_input_ids[assistant_id] = target_id
|
||||||
|
target_to_assistant_input_ids[target_id] = assistant_id
|
||||||
|
return assistant_to_target_input_ids.to(self._assistant_model_device), target_to_assistant_input_ids
|
||||||
|
|
||||||
|
def _get_suppress_input_ids(self) -> list[int]:
|
||||||
|
"""
|
||||||
|
Get the input ids that are in the assistant vocab but not in the target vocab.
|
||||||
|
"""
|
||||||
|
return torch.where(self._assistant_to_target_input_ids == self.SUPPRESS_TOKEN_ID)[0]
|
||||||
|
|
||||||
|
def get_target_ids(
|
||||||
|
self, assistant_input_ids, target_input_ids, assistant_candidate_ids: torch.LongTensor
|
||||||
|
) -> torch.LongTensor:
|
||||||
|
"""
|
||||||
|
Return the target candidate ids that correspond to the assistant candidate ids.
|
||||||
|
Note that we have already the target ids for the prompt and we only need to find the target ids for the new tokens.
|
||||||
|
Moreover, assistant ids of the original prompt does not necessarily appear in _assistant_to_target_input_ids.
|
||||||
|
"""
|
||||||
|
|
||||||
|
num_new_tokens = len(assistant_candidate_ids[0]) - assistant_input_ids.shape[1]
|
||||||
|
if num_new_tokens == 0:
|
||||||
|
return target_input_ids
|
||||||
|
else:
|
||||||
|
transformed_slice = self._assistant_to_target_input_ids[assistant_candidate_ids[0, -num_new_tokens:]]
|
||||||
|
return torch.cat((target_input_ids, transformed_slice.unsqueeze(0)), dim=1)
|
||||||
|
|
||||||
|
def get_target_logits(self, assistant_logits: torch.FloatTensor) -> torch.FloatTensor:
|
||||||
|
"""
|
||||||
|
Return the target logits that correspond to the assistant logits.
|
||||||
|
"""
|
||||||
|
|
||||||
|
target_shape: tuple[int, ...] = (*assistant_logits.shape[:-1], self.target_vocab_size)
|
||||||
|
target_logits: torch.FloatTensor = torch.full(target_shape, self.FILTER_VALUE).to(self._assistant_model_device)
|
||||||
|
# Mask for valid indices
|
||||||
|
assistant_indices_mask = self._assistant_to_target_input_ids != self.SUPPRESS_TOKEN_ID
|
||||||
|
# Exclude invalid indices
|
||||||
|
target_logits_supported_indices = self._assistant_to_target_input_ids[assistant_indices_mask]
|
||||||
|
valid_assistant_logits = assistant_logits[..., : self._assistant_to_target_input_ids.shape[0]]
|
||||||
|
|
||||||
|
target_logits[..., target_logits_supported_indices] = valid_assistant_logits[..., assistant_indices_mask]
|
||||||
|
|
||||||
|
return target_logits
|
||||||
|
|
||||||
|
|
||||||
|
class AssistantVocabTranslatorCache:
|
||||||
|
"""
|
||||||
|
Cache for `AssistantToTargetTranslator` instances. The instances are computed at
|
||||||
|
pre-processing time, and this cache allows us to avoid recomputing them.
|
||||||
|
"""
|
||||||
|
|
||||||
|
_cache = weakref.WeakKeyDictionary()
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get_translator(
|
||||||
|
cls,
|
||||||
|
target_tokenizer: "PreTrainedTokenizerBase",
|
||||||
|
assistant_tokenizer: "PreTrainedTokenizerBase",
|
||||||
|
target_vocab_size: int,
|
||||||
|
assistant_model_device: str = "cpu",
|
||||||
|
) -> AssistantToTargetTranslator:
|
||||||
|
assistant_dict = cls._cache.get(target_tokenizer)
|
||||||
|
if assistant_dict is None:
|
||||||
|
assistant_dict = weakref.WeakKeyDictionary()
|
||||||
|
cls._cache[target_tokenizer] = assistant_dict
|
||||||
|
|
||||||
|
mapping = assistant_dict.get(assistant_tokenizer)
|
||||||
|
if mapping is None:
|
||||||
|
mapping = AssistantToTargetTranslator(
|
||||||
|
target_tokenizer, assistant_tokenizer, target_vocab_size, assistant_model_device
|
||||||
|
)
|
||||||
|
assistant_dict[assistant_tokenizer] = mapping
|
||||||
|
|
||||||
|
return mapping
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def cleanup(cls):
|
||||||
|
"""
|
||||||
|
Clean up dead references in the cache.
|
||||||
|
This removes entries where either the target_tokenizer or assistant_tokenizer
|
||||||
|
has been garbage collected.
|
||||||
|
"""
|
||||||
|
# Remove entries from the outer cache where the target_tokenizer is no longer alive
|
||||||
|
dead_keys = [key for key in cls._cache if key is None]
|
||||||
|
for key in dead_keys:
|
||||||
|
del cls._cache[key]
|
||||||
|
|
||||||
|
# For each assistant_dict, remove entries where assistant_tokenizer is no longer alive
|
||||||
|
for assistant_dict in cls._cache.values():
|
||||||
|
dead_keys = [key for key in assistant_dict if key is None]
|
||||||
|
for key in dead_keys:
|
||||||
|
del assistant_dict[key]
|
||||||
|
|
||||||
|
|
||||||
|
class UniversalSpeculativeDecodingGenerator(AssistedCandidateGeneratorDifferentTokenizers):
|
||||||
|
"""
|
||||||
|
`CandidateGenerator` class to be used for Universal Speculative Decoding (USD): speculative decoding with different tokenizers
|
||||||
|
for the assistant and main models. This class generates candidates through the use of a smaller model.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
input_ids: torch.LongTensor,
|
||||||
|
assistant_model: "PreTrainedModel",
|
||||||
|
target_tokenizer: "PreTrainedTokenizerBase",
|
||||||
|
assistant_tokenizer: "PreTrainedTokenizerBase",
|
||||||
|
generation_config: "GenerationConfig",
|
||||||
|
model_kwargs: Dict,
|
||||||
|
atm_translator: AssistantToTargetTranslator,
|
||||||
|
inputs_tensor: Optional[torch.Tensor] = None,
|
||||||
|
logits_processor: "LogitsProcessorList" = None,
|
||||||
|
):
|
||||||
|
# Initialize translator before parent class
|
||||||
|
self._atm_translator = atm_translator
|
||||||
|
super().__init__(
|
||||||
|
input_ids,
|
||||||
|
assistant_model,
|
||||||
|
target_tokenizer,
|
||||||
|
assistant_tokenizer,
|
||||||
|
generation_config,
|
||||||
|
model_kwargs,
|
||||||
|
inputs_tensor,
|
||||||
|
logits_processor,
|
||||||
|
)
|
||||||
|
# Track sequence lengths and previous assistant IDs
|
||||||
|
self._target_seq_len_with_candidates: int = 0
|
||||||
|
self._prev_assistant_ids: Optional[torch.LongTensor] = None
|
||||||
|
|
||||||
|
def get_candidates(self, input_ids: torch.LongTensor) -> Tuple[torch.LongTensor, Optional[torch.FloatTensor]]:
|
||||||
|
"""
|
||||||
|
Simplified version of get_candidates that uses the translator cache for token conversion.
|
||||||
|
"""
|
||||||
|
target_input_ids = input_ids.to(self.assistant_model.device)
|
||||||
|
assistant_input_ids, num_added_tokens = self._prepare_assistant_input_ids(target_input_ids)
|
||||||
|
min_new_tokens, max_new_tokens = self._calculate_new_tokens(target_input_ids)
|
||||||
|
|
||||||
|
if max_new_tokens == 0:
|
||||||
|
return input_ids, None
|
||||||
|
|
||||||
|
self._update_past_and_masks(assistant_input_ids, num_added_tokens=num_added_tokens)
|
||||||
|
generation_args = self._prepare_generation_args(assistant_input_ids, min_new_tokens, max_new_tokens)
|
||||||
|
|
||||||
|
# Ensure scores are returned
|
||||||
|
generation_args["generation_config"].output_scores = True
|
||||||
|
generation_args["generation_config"].return_dict_in_generate = True
|
||||||
|
|
||||||
|
# Generate and process outputs using translator
|
||||||
|
if self._atm_translator.logits_processors is not None:
|
||||||
|
generation_args["logits_processor"] = self._atm_translator.logits_processors
|
||||||
|
self._prev_assistant_ids, assistant_candidate_logits = self._generate_candidates(generation_args)
|
||||||
|
|
||||||
|
# Use translator to convert tokens and logits
|
||||||
|
target_candidate_ids = self._atm_translator.get_target_ids(
|
||||||
|
assistant_input_ids, target_input_ids, self._prev_assistant_ids
|
||||||
|
)
|
||||||
|
self._target_seq_len_with_candidates = target_candidate_ids.shape[-1]
|
||||||
|
target_candidate_logits = self._atm_translator.get_target_logits(assistant_candidate_logits)
|
||||||
|
|
||||||
|
return target_candidate_ids, target_candidate_logits
|
||||||
|
|
||||||
|
def _update_past_and_masks(self, assistant_input_ids: torch.LongTensor, num_added_tokens: int = 1) -> bool:
|
||||||
|
if self._prev_assistant_ids is None:
|
||||||
|
# Prepare attention mask for the first generation.
|
||||||
|
# For subsequent generations, the attention mask is updated in super()_update_past_and_masks.
|
||||||
|
self.assistant_kwargs = _prepare_attention_mask(
|
||||||
|
self.assistant_kwargs, assistant_input_ids.shape[-1], self.assistant_model.config.is_encoder_decoder
|
||||||
|
)
|
||||||
|
return super()._update_past_and_masks(assistant_input_ids, num_added_tokens=num_added_tokens)
|
||||||
|
|
||||||
|
def _prepare_assistant_input_ids(self, target_input_ids: torch.LongTensor) -> torch.LongTensor:
|
||||||
|
"""
|
||||||
|
Simplified token conversion that only processes new tokens.
|
||||||
|
"""
|
||||||
|
# Calculate new tokens since last call
|
||||||
|
target_seq_len = target_input_ids.shape[-1]
|
||||||
|
if self._target_seq_len_with_candidates == 0:
|
||||||
|
new_token_count = target_seq_len
|
||||||
|
else:
|
||||||
|
new_token_count = 1
|
||||||
|
target_new_ids = target_input_ids[:, -new_token_count:]
|
||||||
|
|
||||||
|
# Convert the new tokens
|
||||||
|
assistant_new_ids = None
|
||||||
|
if self._target_seq_len_with_candidates > 0:
|
||||||
|
# we have only one new token and we can directly convert it
|
||||||
|
assistant_new_ids = self._atm_translator.target_to_assistant_input_ids.get(target_new_ids[0].item())
|
||||||
|
if assistant_new_ids is None:
|
||||||
|
target_new_text = self.target_tokenizer.batch_decode(
|
||||||
|
target_new_ids, skip_special_tokens=True, clean_up_tokenization_spaces=True
|
||||||
|
)
|
||||||
|
assistant_new_ids = self.assistant_tokenizer(
|
||||||
|
target_new_text, add_special_tokens=False, return_tensors="pt"
|
||||||
|
)["input_ids"].to(self.assistant_model.device)
|
||||||
|
else:
|
||||||
|
assistant_new_ids = torch.tensor([[assistant_new_ids]], device=self.assistant_model.device)
|
||||||
|
|
||||||
|
# Update or initialize assistant IDs
|
||||||
|
if self._prev_assistant_ids is None:
|
||||||
|
assistant_input_ids = assistant_new_ids
|
||||||
|
else:
|
||||||
|
tokens_to_remove = self._target_seq_len_with_candidates + 1 - target_seq_len
|
||||||
|
# If the number of new tokens is greater than zero, truncate the previous assistant IDs
|
||||||
|
if tokens_to_remove > 0:
|
||||||
|
self._prev_assistant_ids = self._prev_assistant_ids[:, :-tokens_to_remove]
|
||||||
|
assistant_input_ids = torch.cat([self._prev_assistant_ids, assistant_new_ids], dim=-1)
|
||||||
|
assistant_input_ids = assistant_input_ids.to(dtype=torch.long)
|
||||||
|
|
||||||
|
return assistant_input_ids, len(assistant_new_ids[0])
|
||||||
|
|
||||||
|
|
||||||
class PromptLookupCandidateGenerator(CandidateGenerator):
|
class PromptLookupCandidateGenerator(CandidateGenerator):
|
||||||
"""
|
"""
|
||||||
`CandidateGenerator` class to be used for prompt lookup generation. This class generates candidates by looking up
|
`CandidateGenerator` class to be used for prompt lookup generation. This class generates candidates by looking up
|
||||||
|
|||||||
@@ -26,6 +26,8 @@ import torch.distributed as dist
|
|||||||
from torch import nn
|
from torch import nn
|
||||||
from torch.nn import functional as F
|
from torch.nn import functional as F
|
||||||
|
|
||||||
|
from transformers.generation.candidate_generator import AssistantVocabTranslatorCache
|
||||||
|
|
||||||
from ..cache_utils import (
|
from ..cache_utils import (
|
||||||
Cache,
|
Cache,
|
||||||
DynamicCache,
|
DynamicCache,
|
||||||
@@ -56,6 +58,7 @@ from .candidate_generator import (
|
|||||||
CandidateGenerator,
|
CandidateGenerator,
|
||||||
EarlyExitCandidateGenerator,
|
EarlyExitCandidateGenerator,
|
||||||
PromptLookupCandidateGenerator,
|
PromptLookupCandidateGenerator,
|
||||||
|
UniversalSpeculativeDecodingGenerator,
|
||||||
_crop_past_key_values,
|
_crop_past_key_values,
|
||||||
_prepare_attention_mask,
|
_prepare_attention_mask,
|
||||||
_prepare_token_type_ids,
|
_prepare_token_type_ids,
|
||||||
@@ -858,6 +861,22 @@ class GenerationMixin:
|
|||||||
max_length=generation_config.max_length,
|
max_length=generation_config.max_length,
|
||||||
)
|
)
|
||||||
elif different_tokenizers:
|
elif different_tokenizers:
|
||||||
|
if generation_config.do_sample is True:
|
||||||
|
atm_translator = AssistantVocabTranslatorCache.get_translator(
|
||||||
|
target_tokenizer, assistant_tokenizer, self.config.vocab_size, assistant_model.device
|
||||||
|
)
|
||||||
|
candidate_generator = UniversalSpeculativeDecodingGenerator(
|
||||||
|
input_ids=input_ids,
|
||||||
|
assistant_model=assistant_model,
|
||||||
|
generation_config=generation_config,
|
||||||
|
model_kwargs=model_kwargs,
|
||||||
|
inputs_tensor=inputs_tensor,
|
||||||
|
logits_processor=logits_processor,
|
||||||
|
target_tokenizer=target_tokenizer,
|
||||||
|
assistant_tokenizer=assistant_tokenizer,
|
||||||
|
atm_translator=atm_translator,
|
||||||
|
)
|
||||||
|
elif generation_config.do_sample is False:
|
||||||
candidate_generator = AssistedCandidateGeneratorDifferentTokenizers(
|
candidate_generator = AssistedCandidateGeneratorDifferentTokenizers(
|
||||||
input_ids=input_ids,
|
input_ids=input_ids,
|
||||||
assistant_model=assistant_model,
|
assistant_model=assistant_model,
|
||||||
@@ -868,6 +887,10 @@ class GenerationMixin:
|
|||||||
target_tokenizer=target_tokenizer,
|
target_tokenizer=target_tokenizer,
|
||||||
assistant_tokenizer=assistant_tokenizer,
|
assistant_tokenizer=assistant_tokenizer,
|
||||||
)
|
)
|
||||||
|
else:
|
||||||
|
raise ValueError(
|
||||||
|
f"Invalid value for `do_sample`: expected a boolean, got {type(generation_config.do_sample).__name__}"
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
candidate_generator = AssistedCandidateGenerator(
|
candidate_generator = AssistedCandidateGenerator(
|
||||||
input_ids=input_ids,
|
input_ids=input_ids,
|
||||||
@@ -4225,7 +4248,6 @@ class GenerationMixin:
|
|||||||
|
|
||||||
# 1. Fetch candidate sequences from a `CandidateGenerator` and move to the correct device
|
# 1. Fetch candidate sequences from a `CandidateGenerator` and move to the correct device
|
||||||
candidate_input_ids, candidate_logits = candidate_generator.get_candidates(input_ids)
|
candidate_input_ids, candidate_logits = candidate_generator.get_candidates(input_ids)
|
||||||
|
|
||||||
candidate_input_ids = candidate_input_ids.to(self.device)
|
candidate_input_ids = candidate_input_ids.to(self.device)
|
||||||
if candidate_logits is not None:
|
if candidate_logits is not None:
|
||||||
candidate_logits = candidate_logits.to(self.device)
|
candidate_logits = candidate_logits.to(self.device)
|
||||||
|
|||||||
@@ -1,43 +1,325 @@
|
|||||||
|
import gc
|
||||||
import unittest
|
import unittest
|
||||||
|
import weakref
|
||||||
|
from unittest.mock import MagicMock
|
||||||
|
|
||||||
import numpy as np
|
import torch
|
||||||
|
|
||||||
from transformers.generation.candidate_generator import AssistedCandidateGeneratorDifferentTokenizers
|
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer, GenerationConfig, pipeline
|
||||||
|
from transformers.generation.candidate_generator import (
|
||||||
|
AssistantToTargetTranslator,
|
||||||
|
AssistantVocabTranslatorCache,
|
||||||
|
UniversalSpeculativeDecodingGenerator,
|
||||||
|
)
|
||||||
|
from transformers.testing_utils import require_torch, torch_device
|
||||||
|
|
||||||
|
|
||||||
class TestAssistedCandidateGeneratorDifferentTokenizers(unittest.TestCase):
|
@require_torch
|
||||||
def test_no_intersection(self):
|
class TestAssistantToTargetTranslator(unittest.TestCase):
|
||||||
prompt = np.array([[1, 2, 3]])
|
def setUp(self):
|
||||||
prompt_plus_new_tokens = np.array([[4, 5, 6]])
|
# Create mock tokenizers with predefined vocabularies
|
||||||
result = AssistedCandidateGeneratorDifferentTokenizers._get_tokens_diag(prompt, prompt_plus_new_tokens)
|
self.target_tokenizer = MagicMock()
|
||||||
self.assertEqual(result, (None, None, None))
|
self.assistant_tokenizer = MagicMock()
|
||||||
|
|
||||||
def test_complete_overlap(self):
|
# Define mock vocabularies for the tokenizers
|
||||||
prompt = np.array([[1, 2, 3]])
|
self.target_vocab = {"hello": 0, "world": 1, "foo": 2, "bar": 3}
|
||||||
prompt_plus_new_tokens = np.array([[1, 2, 3, 4, 5]])
|
self.assistant_vocab = {"hello": 0, "world": 1, "foo": 2, "baz": 4}
|
||||||
discrep_length, new_tokens_only, discrep_only = AssistedCandidateGeneratorDifferentTokenizers._get_tokens_diag(
|
|
||||||
prompt, prompt_plus_new_tokens
|
self.target_tokenizer.get_vocab.return_value = self.target_vocab
|
||||||
|
self.assistant_tokenizer.get_vocab.return_value = self.assistant_vocab
|
||||||
|
self.assistant_model_device = torch_device
|
||||||
|
self.target_vocab_size = 6
|
||||||
|
|
||||||
|
# Instantiate the class under test
|
||||||
|
self.translator = AssistantToTargetTranslator(
|
||||||
|
target_tokenizer=self.target_tokenizer,
|
||||||
|
assistant_tokenizer=self.assistant_tokenizer,
|
||||||
|
assistant_model_device=self.assistant_model_device,
|
||||||
|
target_vocab_size=self.target_vocab_size,
|
||||||
)
|
)
|
||||||
self.assertEqual(discrep_length, 0)
|
|
||||||
np.testing.assert_array_equal(new_tokens_only, np.array([[4, 5]]))
|
|
||||||
np.testing.assert_array_equal(discrep_only, np.array([[]]))
|
|
||||||
|
|
||||||
def test_partial_overlap(self):
|
def test_get_assistant_to_target_input_ids(self):
|
||||||
prompt = np.array([[1, 2, 3]])
|
"""Test the mapping from assistant tokens to target tokens."""
|
||||||
prompt_plus_new_tokens = np.array([[2, 3, 4, 5]])
|
expected_mapping = [0, 1, 2, self.translator.SUPPRESS_TOKEN_ID, self.translator.SUPPRESS_TOKEN_ID]
|
||||||
discrep_length, new_tokens_only, discrep_only = AssistedCandidateGeneratorDifferentTokenizers._get_tokens_diag(
|
actual_mapping = self.translator._assistant_to_target_input_ids.tolist()
|
||||||
prompt, prompt_plus_new_tokens
|
self.assertEqual(actual_mapping, expected_mapping)
|
||||||
)
|
|
||||||
self.assertEqual(discrep_length, 0)
|
|
||||||
np.testing.assert_array_equal(new_tokens_only, np.array([[4, 5]]))
|
|
||||||
np.testing.assert_array_equal(discrep_only, np.array([[]]))
|
|
||||||
|
|
||||||
def test_no_new_tokens(self):
|
def test_get_suppress_input_ids(self):
|
||||||
prompt = np.array([[1, 2, 3]])
|
"""Test the suppression of assistant input IDs not present in the target vocabulary."""
|
||||||
prompt_plus_new_tokens = np.array([[1, 2, 3]])
|
expected_suppress_ids = [3, 4]
|
||||||
discrep_length, new_tokens_only, discrep_only = AssistedCandidateGeneratorDifferentTokenizers._get_tokens_diag(
|
actual_suppress_ids = self.translator._get_suppress_input_ids().tolist()
|
||||||
prompt, prompt_plus_new_tokens
|
self.assertEqual(actual_suppress_ids, expected_suppress_ids)
|
||||||
|
|
||||||
|
def test_get_target_ids(self):
|
||||||
|
"""Test the translation of assistant candidate IDs to target candidate IDs."""
|
||||||
|
assistant_input_ids = torch.LongTensor([[0, 1, 2]]).to(
|
||||||
|
self.assistant_model_device
|
||||||
|
) # 'hello world foo' in assistant tokenizer
|
||||||
|
target_input_ids = torch.LongTensor([[0, 1, 2]]).to(
|
||||||
|
self.assistant_model_device
|
||||||
|
) # 'hello world foo' in target tokenizer
|
||||||
|
assistant_candidate_ids = torch.LongTensor([[0, 1, 2, 4]]).to(
|
||||||
|
self.assistant_model_device
|
||||||
|
) # 'hello world foo baz' in assistant tokenizer
|
||||||
|
|
||||||
|
expected_target_ids = torch.LongTensor(
|
||||||
|
[[0, 1, 2, self.translator.SUPPRESS_TOKEN_ID]]
|
||||||
|
).to(
|
||||||
|
self.assistant_model_device
|
||||||
|
) # 'hello world foo baz' in target tokenizer (baz is mapped to self.translator.suppress_tokens_id since it does not exist in target vocab)
|
||||||
|
|
||||||
|
actual_target_ids = self.translator.get_target_ids(
|
||||||
|
assistant_input_ids, target_input_ids, assistant_candidate_ids
|
||||||
)
|
)
|
||||||
self.assertEqual(discrep_length, 0)
|
self.assertTrue(torch.equal(actual_target_ids, expected_target_ids))
|
||||||
np.testing.assert_array_equal(new_tokens_only, np.array([[]]))
|
|
||||||
np.testing.assert_array_equal(discrep_only, np.array([[]]))
|
def test_get_target_logits(self):
|
||||||
|
"""Test the conversion of assistant logits to target logits."""
|
||||||
|
# Assistant logits for IDs 0, 1, 2
|
||||||
|
assistant_logits = torch.FloatTensor([[[0.1, 0.2, 0.3, 0.4, self.translator.FILTER_VALUE]]]).to(
|
||||||
|
self.assistant_model_device
|
||||||
|
) # Shape (1, 1, 5)
|
||||||
|
|
||||||
|
# Expected target logits (target_vocab_size = 4)
|
||||||
|
expected_target_logits = torch.full((1, 1, self.target_vocab_size), self.translator.FILTER_VALUE).to(
|
||||||
|
self.assistant_model_device
|
||||||
|
)
|
||||||
|
expected_target_logits[0, 0, 0] = 0.1 # 'hello'
|
||||||
|
expected_target_logits[0, 0, 1] = 0.2 # 'world'
|
||||||
|
expected_target_logits[0, 0, 2] = 0.3 # 'foo'
|
||||||
|
# The 'bar' token in target vocab remains at -inf
|
||||||
|
|
||||||
|
actual_target_logits = self.translator.get_target_logits(assistant_logits)
|
||||||
|
self.assertTrue(torch.equal(actual_target_logits, expected_target_logits))
|
||||||
|
|
||||||
|
|
||||||
|
class MockTokenizer:
|
||||||
|
"""A simple mock tokenizer class that supports weak references."""
|
||||||
|
|
||||||
|
def __init__(self, vocab=None):
|
||||||
|
self._vocab = vocab or {}
|
||||||
|
|
||||||
|
def get_vocab(self):
|
||||||
|
return self._vocab
|
||||||
|
|
||||||
|
def __call__(self, text, add_special_tokens=True):
|
||||||
|
# Mock implementation of the __call__ method
|
||||||
|
tokens = text.split()
|
||||||
|
input_ids = [self._vocab.get(token, 0) for token in tokens]
|
||||||
|
return {"input_ids": input_ids}
|
||||||
|
|
||||||
|
|
||||||
|
@require_torch
|
||||||
|
class TestAssistantVocabTranslatorCache(unittest.TestCase):
|
||||||
|
def setUp(self):
|
||||||
|
# Clear the cache before each test
|
||||||
|
AssistantVocabTranslatorCache._cache.clear()
|
||||||
|
# Create mock tokenizers with different vocabularies
|
||||||
|
self.target_tokenizer = MockTokenizer({"hello": 0, "world": 1})
|
||||||
|
self.assistant_tokenizer = MockTokenizer({"hello": 0, "world": 1, "foo": 2})
|
||||||
|
self.other_target_tokenizer = MockTokenizer({"foo": 2, "bar": 3})
|
||||||
|
self.other_assistant_tokenizer = MockTokenizer({"baz": 4, "qux": 5})
|
||||||
|
self.assistant_model_device = torch_device
|
||||||
|
self.target_vocab_size = 6
|
||||||
|
|
||||||
|
def test_same_instance_for_same_tokenizers(self):
|
||||||
|
"""Test that the same translator is returned for the same tokenizers."""
|
||||||
|
translator1 = AssistantVocabTranslatorCache.get_translator(
|
||||||
|
self.target_tokenizer,
|
||||||
|
self.assistant_tokenizer,
|
||||||
|
assistant_model_device=self.assistant_model_device,
|
||||||
|
target_vocab_size=self.target_vocab_size,
|
||||||
|
)
|
||||||
|
translator2 = AssistantVocabTranslatorCache.get_translator(
|
||||||
|
self.target_tokenizer,
|
||||||
|
self.assistant_tokenizer,
|
||||||
|
assistant_model_device=self.assistant_model_device,
|
||||||
|
target_vocab_size=self.target_vocab_size,
|
||||||
|
)
|
||||||
|
self.assertIs(translator1, translator2, "Translators should be cached and identical")
|
||||||
|
|
||||||
|
def test_different_instances_for_different_tokenizers(self):
|
||||||
|
"""Test that different tokenizers produce different translators."""
|
||||||
|
translator1 = AssistantVocabTranslatorCache.get_translator(
|
||||||
|
self.target_tokenizer,
|
||||||
|
self.assistant_tokenizer,
|
||||||
|
assistant_model_device=self.assistant_model_device,
|
||||||
|
target_vocab_size=self.target_vocab_size,
|
||||||
|
)
|
||||||
|
translator2 = AssistantVocabTranslatorCache.get_translator(
|
||||||
|
self.other_target_tokenizer,
|
||||||
|
self.other_assistant_tokenizer,
|
||||||
|
assistant_model_device=self.assistant_model_device,
|
||||||
|
target_vocab_size=self.target_vocab_size,
|
||||||
|
)
|
||||||
|
self.assertIsNot(translator1, translator2, "Translators should differ for different tokenizers")
|
||||||
|
|
||||||
|
def test_cache_with_weakref_key(self):
|
||||||
|
"""Ensure that the cache uses weak references as keys."""
|
||||||
|
initial_cache_size = len(AssistantVocabTranslatorCache._cache)
|
||||||
|
target_tokenizer = MockTokenizer({"hello": 0})
|
||||||
|
assistant_tokenizer = MockTokenizer({"hello": 0})
|
||||||
|
|
||||||
|
# Store translator in a local variable to avoid it being kept alive
|
||||||
|
translator = AssistantVocabTranslatorCache.get_translator(
|
||||||
|
target_tokenizer,
|
||||||
|
assistant_tokenizer,
|
||||||
|
assistant_model_device=self.assistant_model_device,
|
||||||
|
target_vocab_size=self.target_vocab_size,
|
||||||
|
)
|
||||||
|
self.assertEqual(len(AssistantVocabTranslatorCache._cache), initial_cache_size + 1)
|
||||||
|
|
||||||
|
# Delete all strong references
|
||||||
|
del target_tokenizer
|
||||||
|
del assistant_tokenizer
|
||||||
|
del translator
|
||||||
|
|
||||||
|
# Force garbage collection
|
||||||
|
gc.collect()
|
||||||
|
|
||||||
|
# Call cleanup to remove dead entries
|
||||||
|
AssistantVocabTranslatorCache.cleanup()
|
||||||
|
|
||||||
|
# The cache size remains increased due to strong references
|
||||||
|
self.assertEqual(len(AssistantVocabTranslatorCache._cache), initial_cache_size + 1)
|
||||||
|
|
||||||
|
def test_weakref_cache_cleanup(self):
|
||||||
|
"""Test that the cache cleans up translators when tokenizers are garbage collected."""
|
||||||
|
|
||||||
|
def create_translator():
|
||||||
|
target_tokenizer = MockTokenizer({"hello": 0})
|
||||||
|
assistant_tokenizer = MockTokenizer({"hello": 0})
|
||||||
|
translator = AssistantVocabTranslatorCache.get_translator(
|
||||||
|
target_tokenizer,
|
||||||
|
assistant_tokenizer,
|
||||||
|
assistant_model_device=self.assistant_model_device,
|
||||||
|
target_vocab_size=self.target_vocab_size,
|
||||||
|
)
|
||||||
|
# Create weak references before returning
|
||||||
|
refs = (weakref.ref(translator), weakref.ref(target_tokenizer), weakref.ref(assistant_tokenizer))
|
||||||
|
# Remove strong references inside the function
|
||||||
|
del target_tokenizer
|
||||||
|
del assistant_tokenizer
|
||||||
|
del translator
|
||||||
|
return refs
|
||||||
|
|
||||||
|
translator_ref, target_ref, assistant_ref = create_translator()
|
||||||
|
|
||||||
|
# Force garbage collection
|
||||||
|
gc.collect()
|
||||||
|
|
||||||
|
# Call cleanup to remove dead entries
|
||||||
|
AssistantVocabTranslatorCache.cleanup()
|
||||||
|
|
||||||
|
# The tokenizers and translator are not garbage collected due to strong references
|
||||||
|
self.assertIsNotNone(target_ref(), "Target tokenizer should still be alive due to strong references")
|
||||||
|
self.assertIsNotNone(assistant_ref(), "Assistant tokenizer should still be alive due to strong references")
|
||||||
|
self.assertIsNotNone(translator_ref(), "Translator should still be alive due to strong references")
|
||||||
|
|
||||||
|
|
||||||
|
@require_torch
|
||||||
|
class TestUniversalSpeculativeDecoding(unittest.TestCase):
|
||||||
|
@classmethod
|
||||||
|
def setUpClass(cls):
|
||||||
|
cls.target_name = "hf-internal-testing/tiny-random-LlamaForCausalLM"
|
||||||
|
cls.assistant_name = "hf-internal-testing/tiny-random-PhiForCausalLM"
|
||||||
|
|
||||||
|
def setUp(self):
|
||||||
|
self.target_tokenizer = AutoTokenizer.from_pretrained(self.target_name)
|
||||||
|
self.target_config = AutoConfig.from_pretrained(self.target_name)
|
||||||
|
self.assistant_model = AutoModelForCausalLM.from_pretrained(self.assistant_name).to(torch_device)
|
||||||
|
self.assistant_tokenizer = AutoTokenizer.from_pretrained(self.assistant_name)
|
||||||
|
|
||||||
|
self.generation_config = GenerationConfig()
|
||||||
|
|
||||||
|
# Ensure required tokens exist
|
||||||
|
if self.target_tokenizer.pad_token_id is None:
|
||||||
|
self.target_tokenizer.pad_token_id = self.target_tokenizer.eos_token_id
|
||||||
|
if self.target_tokenizer.bos_token_id is None:
|
||||||
|
self.target_tokenizer.bos_token_id = self.target_tokenizer.eos_token_id
|
||||||
|
if self.assistant_tokenizer.pad_token_id is None:
|
||||||
|
self.assistant_tokenizer.pad_token_id = self.assistant_tokenizer.eos_token_id
|
||||||
|
if self.target_tokenizer.bos_token_id is None:
|
||||||
|
self.assistant_tokenizer.bos_token_id = self.assistant_tokenizer.eos_token_id
|
||||||
|
|
||||||
|
self.input_ids = torch.tensor([[1, 2, 3]]).to(torch_device)
|
||||||
|
self.model_kwargs = {
|
||||||
|
"attention_mask": torch.ones_like(self.input_ids).to(torch_device),
|
||||||
|
}
|
||||||
|
|
||||||
|
atm_translator = AssistantVocabTranslatorCache.get_translator(
|
||||||
|
self.target_tokenizer, self.assistant_tokenizer, self.target_config.vocab_size, torch_device
|
||||||
|
)
|
||||||
|
self.generator = UniversalSpeculativeDecodingGenerator(
|
||||||
|
input_ids=self.input_ids,
|
||||||
|
assistant_model=self.assistant_model,
|
||||||
|
target_tokenizer=self.target_tokenizer,
|
||||||
|
assistant_tokenizer=self.assistant_tokenizer,
|
||||||
|
generation_config=self.generation_config,
|
||||||
|
model_kwargs=self.model_kwargs,
|
||||||
|
atm_translator=atm_translator,
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_basic_generation(self):
|
||||||
|
"""Test basic speculative decoding works"""
|
||||||
|
input_text = "The quick brown fox"
|
||||||
|
input_ids = self.target_tokenizer.encode(input_text, return_tensors="pt")
|
||||||
|
self.generator.input_ids = input_ids
|
||||||
|
candidates, scores = self.generator.get_candidates(input_ids)
|
||||||
|
|
||||||
|
self.assertIsNotNone(candidates)
|
||||||
|
self.assertIsNotNone(scores)
|
||||||
|
self.assertTrue(torch.is_tensor(candidates))
|
||||||
|
self.assertTrue(torch.is_tensor(scores))
|
||||||
|
|
||||||
|
def test_mismatched_vocabularies(self):
|
||||||
|
"""Test handling of mismatched vocabularies between models"""
|
||||||
|
# Create input with tokens present in main but not assistant vocab
|
||||||
|
# Find a token that is not in the assistant tokenizer but in
|
||||||
|
# the main tokenizer.
|
||||||
|
missing_token = next(
|
||||||
|
token
|
||||||
|
for token in self.target_tokenizer.get_vocab()
|
||||||
|
if token not in self.assistant_tokenizer.get_vocab()
|
||||||
|
and token not in self.target_tokenizer.all_special_tokens
|
||||||
|
and "reserved_" not in token
|
||||||
|
)
|
||||||
|
input_ids = torch.tensor([[self.target_tokenizer.convert_tokens_to_ids(missing_token)]])
|
||||||
|
self.generator.input_ids = input_ids
|
||||||
|
candidates, scores = self.generator.get_candidates(input_ids)
|
||||||
|
self.assertIsNotNone(candidates)
|
||||||
|
|
||||||
|
def test_speculation_depth(self):
|
||||||
|
"""Test different speculation depths"""
|
||||||
|
input_ids = self.target_tokenizer.encode("Test text", return_tensors="pt")
|
||||||
|
self.generator.input_ids = input_ids
|
||||||
|
|
||||||
|
for depth in [1, 8, 17]:
|
||||||
|
self.generator.num_assistant_tokens = depth
|
||||||
|
candidates, scores = self.generator.get_candidates(input_ids)
|
||||||
|
self.assertLessEqual(candidates.shape[1] - input_ids.shape[1], depth)
|
||||||
|
|
||||||
|
def test_device_consistency(self):
|
||||||
|
"""Test handling of inputs on different devices"""
|
||||||
|
input_ids = torch.tensor([[1, 2, 3]]).to(torch_device)
|
||||||
|
self.generator.input_ids = input_ids
|
||||||
|
candidates, _ = self.generator.get_candidates(input_ids)
|
||||||
|
self.assertEqual(candidates.device, input_ids.device)
|
||||||
|
|
||||||
|
def test_usd_vs_vanilla_sampling(cls):
|
||||||
|
"""Test that USD matches vanilla sampling with temperature set to nearly 0"""
|
||||||
|
prompt = "Test text"
|
||||||
|
|
||||||
|
pipe_usd = pipeline("text-generation", model=cls.target_name, assistant_model=cls.assistant_name)
|
||||||
|
pipe_usd_output = pipe_usd(prompt, max_new_tokens=5, do_sample=True, temperature=1e-9) # Nearly 0 temperature
|
||||||
|
usd_text = pipe_usd_output[0]["generated_text"]
|
||||||
|
|
||||||
|
pipe_vanilla = pipeline(
|
||||||
|
"text-generation",
|
||||||
|
model=cls.target_name,
|
||||||
|
)
|
||||||
|
pipe_vanilla_output = pipe_vanilla(prompt, max_new_tokens=5, do_sample=False)
|
||||||
|
vanilla_text = pipe_vanilla_output[0]["generated_text"]
|
||||||
|
|
||||||
|
# Assert that the outputs match
|
||||||
|
cls.assertEqual(usd_text, vanilla_text)
|
||||||
|
|||||||
Reference in New Issue
Block a user