Refactoring AssistedCandidateGenerator for Improved Modularity and Reusability (#35009)
* move `TestAssistedCandidateGeneratorDifferentTokenizers` into a new testing file * refactor * NOTHING. add space to rerun github actions tests * remove it... * NOTHING. add space to rerun github actions tests * remove it... * replace: `self.prev_tokens` -> `self.prev_assistant_ids` * NOTHING. rerun CI tests * remove it * introduce `self.prev_target_ids_len` * fix style * fix style --------- Co-authored-by: Jonathan Mamou <jonathan.mamou@intel.com>
This commit is contained in:
@@ -208,56 +208,15 @@ class AssistedCandidateGenerator(CandidateGenerator):
|
|||||||
vocabulary_size)` containing the logits associated to each candidate.
|
vocabulary_size)` containing the logits associated to each candidate.
|
||||||
"""
|
"""
|
||||||
input_ids = input_ids.to(self.assistant_model.device)
|
input_ids = input_ids.to(self.assistant_model.device)
|
||||||
|
# Calculate new tokens to generate
|
||||||
# Don't generate more than `max_length - 1` candidates since the target model generates one extra token.
|
min_new_tokens, max_new_tokens = self._calculate_new_tokens(input_ids)
|
||||||
new_cur_len = input_ids.shape[-1]
|
|
||||||
max_new_tokens = min(int(self.num_assistant_tokens), self.generation_config.max_length - new_cur_len - 1)
|
|
||||||
min_new_tokens = max(min(max_new_tokens, self.main_model_min_length - new_cur_len), 0)
|
|
||||||
if max_new_tokens == 0:
|
if max_new_tokens == 0:
|
||||||
return input_ids, None
|
return input_ids, None
|
||||||
|
# Update past key values and masks
|
||||||
# 1. If it is not the first round of candidate generation, prepare the inputs based on the input_ids length
|
self._update_past_and_masks(input_ids)
|
||||||
# (which implicitly contains the number of accepted candidates from the previous round)
|
# Generate candidates
|
||||||
has_past_key_values = self.assistant_kwargs.get("past_key_values", None) is not None
|
generation_args = self._prepare_generation_args(input_ids, min_new_tokens, max_new_tokens)
|
||||||
if has_past_key_values:
|
candidate_ids, candidate_logits = self._generate_candidates(generation_args)
|
||||||
new_cache_size = new_cur_len - 1
|
|
||||||
self.assistant_kwargs["past_key_values"] = _crop_past_key_values(
|
|
||||||
self.assistant_model, self.assistant_kwargs["past_key_values"], new_cache_size - 1
|
|
||||||
) # the assistant does not have the token after the last match, hence the -1
|
|
||||||
|
|
||||||
self.assistant_kwargs = _prepare_attention_mask(
|
|
||||||
self.assistant_kwargs, new_cur_len, self.assistant_model.config.is_encoder_decoder
|
|
||||||
)
|
|
||||||
self.assistant_kwargs = _prepare_token_type_ids(self.assistant_kwargs, new_cur_len)
|
|
||||||
|
|
||||||
# 2. Forecast next N tokens using the assistant model.
|
|
||||||
assistant_generation_kwargs = {
|
|
||||||
self.input_ids_key: input_ids,
|
|
||||||
"min_new_tokens": min_new_tokens,
|
|
||||||
"max_new_tokens": max_new_tokens,
|
|
||||||
"generation_config": self.generation_config,
|
|
||||||
"logits_processor": self.logits_processor,
|
|
||||||
}
|
|
||||||
|
|
||||||
assistant_output = self.assistant_model.generate(**assistant_generation_kwargs, **self.assistant_kwargs)
|
|
||||||
|
|
||||||
# 3. Update variables for the next round of candidate generation
|
|
||||||
self.assistant_kwargs["past_key_values"] = assistant_output.past_key_values
|
|
||||||
|
|
||||||
if (
|
|
||||||
is_sklearn_available()
|
|
||||||
and self.assistant_model.generation_config.assistant_confidence_threshold
|
|
||||||
and type(self) is AssistedCandidateGenerator
|
|
||||||
):
|
|
||||||
scores_tensor = torch.cat(assistant_output.scores, dim=0)
|
|
||||||
scores_softmax = torch.softmax(scores_tensor, dim=-1)
|
|
||||||
ids = assistant_output.sequences[-1, -len(assistant_output.scores) :]
|
|
||||||
p = scores_softmax[range(len(ids)), ids]
|
|
||||||
self.probs.extend(p.tolist())
|
|
||||||
|
|
||||||
# 4. Prepare variables for output
|
|
||||||
candidate_logits = torch.stack(assistant_output.scores, dim=1)
|
|
||||||
candidate_ids = assistant_output.sequences
|
|
||||||
return candidate_ids, candidate_logits
|
return candidate_ids, candidate_logits
|
||||||
|
|
||||||
def update_candidate_strategy(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, num_matches: int):
|
def update_candidate_strategy(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, num_matches: int):
|
||||||
@@ -318,6 +277,55 @@ class AssistedCandidateGenerator(CandidateGenerator):
|
|||||||
|
|
||||||
self.assistant_model.generation_config.assistant_confidence_threshold = best_threshold
|
self.assistant_model.generation_config.assistant_confidence_threshold = best_threshold
|
||||||
|
|
||||||
|
def _calculate_new_tokens(self, input_ids: torch.LongTensor) -> Tuple[int, int]:
|
||||||
|
"""Calculate the minimum and maximum number of new tokens to generate."""
|
||||||
|
new_cur_len = input_ids.shape[-1]
|
||||||
|
max_new_tokens = min(int(self.num_assistant_tokens), self.generation_config.max_length - new_cur_len - 1)
|
||||||
|
min_new_tokens = max(min(max_new_tokens, self.main_model_min_length - new_cur_len), 0)
|
||||||
|
return min_new_tokens, max_new_tokens
|
||||||
|
|
||||||
|
def _update_past_and_masks(self, input_ids: torch.LongTensor, remove_from_pkv: int = 0) -> bool:
|
||||||
|
"""Update past key values and attention masks for subsequent generation rounds."""
|
||||||
|
has_past_key_values = self.assistant_kwargs.get("past_key_values", None) is not None
|
||||||
|
if has_past_key_values:
|
||||||
|
new_cache_size = input_ids.shape[-1] - 1 - remove_from_pkv
|
||||||
|
self.assistant_kwargs["past_key_values"] = _crop_past_key_values(
|
||||||
|
self.assistant_model, self.assistant_kwargs["past_key_values"], new_cache_size - 1
|
||||||
|
)
|
||||||
|
self.assistant_kwargs = _prepare_attention_mask(
|
||||||
|
self.assistant_kwargs, input_ids.shape[-1], self.assistant_model.config.is_encoder_decoder
|
||||||
|
)
|
||||||
|
self.assistant_kwargs = _prepare_token_type_ids(self.assistant_kwargs, input_ids.shape[-1])
|
||||||
|
return has_past_key_values
|
||||||
|
|
||||||
|
def _prepare_generation_args(self, input_ids: torch.LongTensor, min_new_tokens: int, max_new_tokens: int) -> Dict:
|
||||||
|
"""Prepare arguments for the generation call."""
|
||||||
|
return {
|
||||||
|
self.input_ids_key: input_ids,
|
||||||
|
"min_new_tokens": min_new_tokens,
|
||||||
|
"max_new_tokens": max_new_tokens,
|
||||||
|
"generation_config": self.generation_config,
|
||||||
|
"logits_processor": self.logits_processor,
|
||||||
|
}
|
||||||
|
|
||||||
|
def _generate_candidates(self, generation_args: Dict) -> Tuple[torch.LongTensor, Optional[torch.FloatTensor]]:
|
||||||
|
"""Generate candidate sequences using the assistant model."""
|
||||||
|
assistant_output = self.assistant_model.generate(**generation_args, **self.assistant_kwargs)
|
||||||
|
self.assistant_kwargs["past_key_values"] = assistant_output.past_key_values
|
||||||
|
if (
|
||||||
|
is_sklearn_available()
|
||||||
|
and self.assistant_model.generation_config.assistant_confidence_threshold
|
||||||
|
and type(self) is AssistedCandidateGenerator
|
||||||
|
):
|
||||||
|
scores_tensor = torch.cat(assistant_output.scores, dim=0)
|
||||||
|
scores_softmax = torch.softmax(scores_tensor, dim=-1)
|
||||||
|
ids = assistant_output.sequences[-1, -len(assistant_output.scores) :]
|
||||||
|
p = scores_softmax[range(len(ids)), ids]
|
||||||
|
self.probs.extend(p.tolist())
|
||||||
|
candidate_logits = torch.stack(assistant_output.scores, dim=1)
|
||||||
|
candidate_ids = assistant_output.sequences
|
||||||
|
return candidate_ids, candidate_logits
|
||||||
|
|
||||||
|
|
||||||
class AssistedCandidateGeneratorDifferentTokenizers(AssistedCandidateGenerator):
|
class AssistedCandidateGeneratorDifferentTokenizers(AssistedCandidateGenerator):
|
||||||
"""
|
"""
|
||||||
@@ -367,6 +375,7 @@ class AssistedCandidateGeneratorDifferentTokenizers(AssistedCandidateGenerator):
|
|||||||
|
|
||||||
self.target_tokenizer = target_tokenizer
|
self.target_tokenizer = target_tokenizer
|
||||||
self.assistant_tokenizer = assistant_tokenizer
|
self.assistant_tokenizer = assistant_tokenizer
|
||||||
|
self.prev_target_ids_len: Optional[int] = None
|
||||||
self.prev_assistant_ids = None
|
self.prev_assistant_ids = None
|
||||||
self.target_lookbehind = assistant_model.generation_config.target_lookbehind
|
self.target_lookbehind = assistant_model.generation_config.target_lookbehind
|
||||||
self.assistant_lookbehind = assistant_model.generation_config.assistant_lookbehind
|
self.assistant_lookbehind = assistant_model.generation_config.assistant_lookbehind
|
||||||
@@ -497,18 +506,41 @@ class AssistedCandidateGeneratorDifferentTokenizers(AssistedCandidateGenerator):
|
|||||||
return input_ids, None
|
return input_ids, None
|
||||||
|
|
||||||
input_ids = input_ids.to(self.assistant_model.device)
|
input_ids = input_ids.to(self.assistant_model.device)
|
||||||
|
remove_from_pkv = 0
|
||||||
|
|
||||||
|
assistant_input_ids, remove_from_pkv = self._prepare_assistant_input_ids(input_ids)
|
||||||
|
self.prev_assistant_ids = assistant_input_ids
|
||||||
|
|
||||||
|
min_new_tokens = max(min(max_new_tokens, self.main_model_min_length - assistant_input_ids.shape[-1]), 0)
|
||||||
|
|
||||||
|
self._update_past_and_masks(assistant_input_ids, remove_from_pkv)
|
||||||
|
generation_args = self._prepare_generation_args(assistant_input_ids, min_new_tokens, max_new_tokens)
|
||||||
|
self.assistant_kwargs.pop("attention_mask", None)
|
||||||
|
|
||||||
|
assistant_output = self.assistant_model.generate(**generation_args, **self.assistant_kwargs)
|
||||||
|
new_target_ids = self._process_assistant_outputs(input_ids, assistant_output.sequences, assistant_input_ids)
|
||||||
|
|
||||||
|
# Update state
|
||||||
|
self.prev_target_ids_len = input_ids.shape[1]
|
||||||
|
self.assistant_kwargs["past_key_values"] = assistant_output.past_key_values
|
||||||
|
self.prev_assistant_ids = assistant_output.sequences
|
||||||
|
|
||||||
|
if self.prev_target_ids_len >= new_target_ids.shape[1]:
|
||||||
|
return input_ids, None
|
||||||
|
|
||||||
|
return new_target_ids, None
|
||||||
|
|
||||||
|
def _prepare_assistant_input_ids(self, input_ids: torch.LongTensor) -> Tuple[torch.LongTensor, int]:
|
||||||
|
"""Converts target input IDs to assistant input IDs, handling discrepancies."""
|
||||||
convert_kwargs = {
|
convert_kwargs = {
|
||||||
"source_tokenizer": self.target_tokenizer,
|
"source_tokenizer": self.target_tokenizer,
|
||||||
"destination_tokenizer": self.assistant_tokenizer,
|
"destination_tokenizer": self.assistant_tokenizer,
|
||||||
}
|
}
|
||||||
remove_from_pkv = 0
|
remove_from_pkv = 0
|
||||||
|
|
||||||
# Since re-encoding the tokens may result in tokenization discrepancies, we use 2 look behind values
|
if self.prev_assistant_ids is not None and self.prev_target_ids_len > self.target_lookbehind:
|
||||||
# (one for each conversion) which mark where to start looking for the overlap between the
|
|
||||||
# source and target encodings, to ensure the new tokens include the correct prompt suffix.
|
|
||||||
if self.prev_assistant_ids is not None and input_ids.shape[1] > self.target_lookbehind:
|
|
||||||
# input_ids contains all target prompt input ids and some new target input ids
|
# input_ids contains all target prompt input ids and some new target input ids
|
||||||
start_index_in_target_window = input_ids.shape[1] - self.target_lookbehind
|
start_index_in_target_window = self.prev_target_ids_len - self.target_lookbehind
|
||||||
|
|
||||||
new_assistant_ids = self.convert_source_tokens_to_target_tokens(
|
new_assistant_ids = self.convert_source_tokens_to_target_tokens(
|
||||||
input_ids[:, start_index_in_target_window:], **convert_kwargs
|
input_ids[:, start_index_in_target_window:], **convert_kwargs
|
||||||
@@ -516,8 +548,8 @@ class AssistedCandidateGeneratorDifferentTokenizers(AssistedCandidateGenerator):
|
|||||||
prompt_use_length = new_assistant_ids.shape[1]
|
prompt_use_length = new_assistant_ids.shape[1]
|
||||||
prompt_use = self.prev_assistant_ids[:, -prompt_use_length:]
|
prompt_use = self.prev_assistant_ids[:, -prompt_use_length:]
|
||||||
|
|
||||||
discrepancy_length, new_tokens_only, discrepancy_only = (
|
discrepancy_length, new_tokens_only, discrepancy_only = self._get_tokens_diag(
|
||||||
AssistedCandidateGeneratorDifferentTokenizers._get_tokens_diag(prompt_use, new_assistant_ids)
|
prompt_use, new_assistant_ids
|
||||||
)
|
)
|
||||||
assistant_input_ids = self.prev_assistant_ids
|
assistant_input_ids = self.prev_assistant_ids
|
||||||
|
|
||||||
@@ -538,48 +570,21 @@ class AssistedCandidateGeneratorDifferentTokenizers(AssistedCandidateGenerator):
|
|||||||
else:
|
else:
|
||||||
# edge case: in case of no intersection between prompt and new_assistant_ids
|
# edge case: in case of no intersection between prompt and new_assistant_ids
|
||||||
assistant_input_ids = torch.cat([assistant_input_ids, new_assistant_ids], dim=-1)
|
assistant_input_ids = torch.cat([assistant_input_ids, new_assistant_ids], dim=-1)
|
||||||
|
|
||||||
else:
|
else:
|
||||||
assistant_input_ids = self.convert_source_tokens_to_target_tokens(input_ids, **convert_kwargs)
|
assistant_input_ids = self.convert_source_tokens_to_target_tokens(input_ids, **convert_kwargs)
|
||||||
|
self.prev_target_ids_len = input_ids.shape[1]
|
||||||
|
|
||||||
self.prev_assistant_ids = assistant_input_ids
|
return assistant_input_ids, remove_from_pkv
|
||||||
new_cur_len = assistant_input_ids.shape[-1]
|
|
||||||
min_new_tokens = max(min(max_new_tokens, self.main_model_min_length - new_cur_len), 0)
|
|
||||||
|
|
||||||
# 1. If it is not the first round of candidate generation, prepare the inputs based on the input_ids length
|
|
||||||
# (which implicitly contains the number of accepted candidates from the previous round)
|
|
||||||
has_past_key_values = self.assistant_kwargs.get("past_key_values", None) is not None
|
|
||||||
if has_past_key_values:
|
|
||||||
new_cache_size = new_cur_len - 1 - remove_from_pkv
|
|
||||||
self.assistant_kwargs["past_key_values"] = _crop_past_key_values(
|
|
||||||
self.assistant_model, self.assistant_kwargs["past_key_values"], new_cache_size - 1
|
|
||||||
) # the assistant does not have the token after the last match, hence the -1
|
|
||||||
|
|
||||||
self.assistant_kwargs = _prepare_attention_mask(
|
|
||||||
self.assistant_kwargs, new_cur_len, self.assistant_model.config.is_encoder_decoder
|
|
||||||
)
|
|
||||||
self.assistant_kwargs = _prepare_token_type_ids(self.assistant_kwargs, new_cur_len)
|
|
||||||
|
|
||||||
# 2. Forecast next N tokens using the assistant model.
|
|
||||||
assistant_generation_kwargs = {
|
|
||||||
self.input_ids_key: assistant_input_ids,
|
|
||||||
"min_new_tokens": min_new_tokens,
|
|
||||||
"max_new_tokens": max_new_tokens,
|
|
||||||
"generation_config": self.generation_config,
|
|
||||||
"logits_processor": self.logits_processor,
|
|
||||||
}
|
|
||||||
|
|
||||||
self.assistant_kwargs.pop("attention_mask", None)
|
|
||||||
|
|
||||||
assistant_output = self.assistant_model.generate(**assistant_generation_kwargs, **self.assistant_kwargs)
|
|
||||||
|
|
||||||
|
def _process_assistant_outputs(
|
||||||
|
self, input_ids: torch.LongTensor, assistant_sequences: torch.LongTensor, assistant_input_ids: torch.LongTensor
|
||||||
|
) -> torch.LongTensor:
|
||||||
|
"""Processes assistant outputs to obtain target input IDs."""
|
||||||
num_prev_assistant = self.prev_assistant_ids.shape[1]
|
num_prev_assistant = self.prev_assistant_ids.shape[1]
|
||||||
start_assistant_look_index = num_prev_assistant - self.assistant_lookbehind
|
start_assistant_look_index = num_prev_assistant - self.assistant_lookbehind
|
||||||
if start_assistant_look_index < 0:
|
|
||||||
start_assistant_look_index = 0
|
|
||||||
|
|
||||||
new_target_ids_from_window = self.convert_source_tokens_to_target_tokens(
|
new_target_ids_from_window = self.convert_source_tokens_to_target_tokens(
|
||||||
assistant_output.sequences[:, start_assistant_look_index:],
|
assistant_sequences[:, start_assistant_look_index:],
|
||||||
source_tokenizer=self.assistant_tokenizer,
|
source_tokenizer=self.assistant_tokenizer,
|
||||||
destination_tokenizer=self.target_tokenizer,
|
destination_tokenizer=self.target_tokenizer,
|
||||||
)
|
)
|
||||||
@@ -587,9 +592,7 @@ class AssistedCandidateGeneratorDifferentTokenizers(AssistedCandidateGenerator):
|
|||||||
|
|
||||||
target_prompt_use = input_ids[:, -target_prompt_use_length:]
|
target_prompt_use = input_ids[:, -target_prompt_use_length:]
|
||||||
|
|
||||||
_, target_new_tokens_only, _ = AssistedCandidateGeneratorDifferentTokenizers._get_tokens_diag(
|
_, target_new_tokens_only, _ = self._get_tokens_diag(target_prompt_use, new_target_ids_from_window)
|
||||||
target_prompt_use, new_target_ids_from_window
|
|
||||||
)
|
|
||||||
|
|
||||||
new_target_ids = input_ids
|
new_target_ids = input_ids
|
||||||
|
|
||||||
@@ -603,14 +606,7 @@ class AssistedCandidateGeneratorDifferentTokenizers(AssistedCandidateGenerator):
|
|||||||
if hasattr(self.generation_config, "max_length"):
|
if hasattr(self.generation_config, "max_length"):
|
||||||
new_target_ids = new_target_ids[:, : self.generation_config.max_length]
|
new_target_ids = new_target_ids[:, : self.generation_config.max_length]
|
||||||
|
|
||||||
# 3. Update variables for the next round of candidate generation
|
return new_target_ids
|
||||||
self.assistant_kwargs["past_key_values"] = assistant_output.past_key_values
|
|
||||||
|
|
||||||
# 4. Prepare variables for output
|
|
||||||
if input_ids.shape[1] >= new_target_ids.shape[1]:
|
|
||||||
return input_ids, None
|
|
||||||
|
|
||||||
return new_target_ids, None
|
|
||||||
|
|
||||||
|
|
||||||
class PromptLookupCandidateGenerator(CandidateGenerator):
|
class PromptLookupCandidateGenerator(CandidateGenerator):
|
||||||
|
|||||||
43
tests/generation/test_candidate_generator.py
Normal file
43
tests/generation/test_candidate_generator.py
Normal file
@@ -0,0 +1,43 @@
|
|||||||
|
import unittest
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
from transformers.generation.candidate_generator import AssistedCandidateGeneratorDifferentTokenizers
|
||||||
|
|
||||||
|
|
||||||
|
class TestAssistedCandidateGeneratorDifferentTokenizers(unittest.TestCase):
|
||||||
|
def test_no_intersection(self):
|
||||||
|
prompt = np.array([[1, 2, 3]])
|
||||||
|
prompt_plus_new_tokens = np.array([[4, 5, 6]])
|
||||||
|
result = AssistedCandidateGeneratorDifferentTokenizers._get_tokens_diag(prompt, prompt_plus_new_tokens)
|
||||||
|
self.assertEqual(result, (None, None, None))
|
||||||
|
|
||||||
|
def test_complete_overlap(self):
|
||||||
|
prompt = np.array([[1, 2, 3]])
|
||||||
|
prompt_plus_new_tokens = np.array([[1, 2, 3, 4, 5]])
|
||||||
|
discrep_length, new_tokens_only, discrep_only = AssistedCandidateGeneratorDifferentTokenizers._get_tokens_diag(
|
||||||
|
prompt, prompt_plus_new_tokens
|
||||||
|
)
|
||||||
|
self.assertEqual(discrep_length, 0)
|
||||||
|
np.testing.assert_array_equal(new_tokens_only, np.array([[4, 5]]))
|
||||||
|
np.testing.assert_array_equal(discrep_only, np.array([[]]))
|
||||||
|
|
||||||
|
def test_partial_overlap(self):
|
||||||
|
prompt = np.array([[1, 2, 3]])
|
||||||
|
prompt_plus_new_tokens = np.array([[2, 3, 4, 5]])
|
||||||
|
discrep_length, new_tokens_only, discrep_only = AssistedCandidateGeneratorDifferentTokenizers._get_tokens_diag(
|
||||||
|
prompt, prompt_plus_new_tokens
|
||||||
|
)
|
||||||
|
self.assertEqual(discrep_length, 0)
|
||||||
|
np.testing.assert_array_equal(new_tokens_only, np.array([[4, 5]]))
|
||||||
|
np.testing.assert_array_equal(discrep_only, np.array([[]]))
|
||||||
|
|
||||||
|
def test_no_new_tokens(self):
|
||||||
|
prompt = np.array([[1, 2, 3]])
|
||||||
|
prompt_plus_new_tokens = np.array([[1, 2, 3]])
|
||||||
|
discrep_length, new_tokens_only, discrep_only = AssistedCandidateGeneratorDifferentTokenizers._get_tokens_diag(
|
||||||
|
prompt, prompt_plus_new_tokens
|
||||||
|
)
|
||||||
|
self.assertEqual(discrep_length, 0)
|
||||||
|
np.testing.assert_array_equal(new_tokens_only, np.array([[]]))
|
||||||
|
np.testing.assert_array_equal(discrep_only, np.array([[]]))
|
||||||
Reference in New Issue
Block a user