Fix _speculative_sampling implementation (#28508)
This commit is contained in:
@@ -171,12 +171,16 @@ class AssistedCandidateGenerator(CandidateGenerator):
|
||||
"""
|
||||
input_ids = input_ids.to(self.assistant_model.device)
|
||||
|
||||
# Don't generate more than `max_length - 1` candidates since the target model generates one extra token.
|
||||
new_cur_len = input_ids.shape[-1]
|
||||
max_new_tokens = min(int(self.num_assistant_tokens), self.generation_config.max_length - new_cur_len - 1)
|
||||
if max_new_tokens == 0:
|
||||
return input_ids, None
|
||||
|
||||
# 1. If it is not the first round of candidate generation, prepare the inputs based on the input_ids length
|
||||
# (which implicitly contains the number of accepted candidates from the previous round)
|
||||
has_past_key_values = self.assistant_kwargs.get("past_key_values", None) is not None
|
||||
if has_past_key_values:
|
||||
new_cur_len = input_ids.shape[-1]
|
||||
|
||||
new_cache_size = new_cur_len - 1
|
||||
self.assistant_kwargs["past_key_values"] = _crop_past_key_values(
|
||||
self.assistant_model, self.assistant_kwargs["past_key_values"], new_cache_size - 1
|
||||
@@ -190,7 +194,7 @@ class AssistedCandidateGenerator(CandidateGenerator):
|
||||
# 2. Forecast next N tokens using the assistant model.
|
||||
assistant_generation_kwargs = {
|
||||
self.input_ids_key: input_ids,
|
||||
"max_new_tokens": int(self.num_assistant_tokens),
|
||||
"max_new_tokens": max_new_tokens,
|
||||
"generation_config": self.generation_config,
|
||||
"logits_processor": self.logits_processor,
|
||||
}
|
||||
|
||||
@@ -4404,7 +4404,7 @@ class GenerationMixin:
|
||||
else:
|
||||
selected_tokens = new_logits.argmax(dim=-1)
|
||||
|
||||
candidate_new_tokens = candidate_input_ids[:, -candidate_length:]
|
||||
candidate_new_tokens = candidate_input_ids[:, cur_len:]
|
||||
n_matches = ((~(candidate_new_tokens == selected_tokens[:, :-1])).cumsum(dim=-1) < 1).sum()
|
||||
|
||||
# Ensure we don't generate beyond max_len or an EOS token
|
||||
@@ -4540,12 +4540,13 @@ def _speculative_sampling(
|
||||
|
||||
NOTE: Unless otherwise stated, the variable names match those in the paper.
|
||||
"""
|
||||
new_candidate_input_ids = candidate_input_ids[:, -candidate_length:]
|
||||
# Gets the probabilities from the logits. q_i and p_i denote the assistant and model probabilities of the tokens
|
||||
# selected by the assistant, respectively.
|
||||
q = candidate_logits.softmax(dim=-1)
|
||||
q_i = q[:, torch.arange(candidate_length), candidate_input_ids[:, -candidate_length:]].squeeze(0, 1)
|
||||
q_i = q[:, torch.arange(candidate_length), new_candidate_input_ids].squeeze(0, 1)
|
||||
p = new_logits.softmax(dim=-1)
|
||||
p_i = p[:, torch.arange(candidate_length), candidate_input_ids[:, -candidate_length:]].squeeze(0, 1)
|
||||
p_i = p[:, torch.arange(candidate_length), new_candidate_input_ids].squeeze(0, 1)
|
||||
probability_ratio = p_i / q_i
|
||||
|
||||
# When probability_ratio > 1 (i.e. q_i(x) < p_i(x), or "assistant probability of the candidate token is smaller
|
||||
@@ -4553,28 +4554,33 @@ def _speculative_sampling(
|
||||
# (= keep with p = probability_ratio). Keep all the tokens until the first rejection
|
||||
r_i = torch.rand_like(probability_ratio)
|
||||
is_accepted = r_i <= probability_ratio
|
||||
n_matches = (~is_accepted.cumsum(dim=-1) < 1).sum() # this is `n` in algorithm 1
|
||||
n_matches = ((~is_accepted).cumsum(dim=-1) < 1).sum() # this is `n` in algorithm 1
|
||||
|
||||
# Ensure we don't generate beyond max_len or an EOS token (not in algorithm 1, but needed for correct behavior)
|
||||
if last_assistant_token_is_eos and n_matches == candidate_length:
|
||||
# Output length is assumed to be `n_matches + 1`. Since we won't generate another token with the target model
|
||||
# due to acceptance on EOS we fix `n_matches`
|
||||
n_matches -= 1
|
||||
n_matches = min(n_matches, max_matches)
|
||||
|
||||
# Next token selection: if there is a rejection, adjust the distribution from the main model before sampling.
|
||||
gamma = candidate_logits.shape[1]
|
||||
p_n_plus_1 = p[:, n_matches, :]
|
||||
if n_matches < gamma:
|
||||
q_n_plus_1 = q[:, n_matches, :]
|
||||
p_prime = torch.clamp((p_n_plus_1 - q_n_plus_1), min=0).softmax(dim=-1)
|
||||
valid_tokens = new_candidate_input_ids[:, : n_matches + 1]
|
||||
else:
|
||||
p_prime = p_n_plus_1
|
||||
t = torch.multinomial(p_prime, num_samples=1).squeeze(1)[None, :]
|
||||
n_matches = min(n_matches, max_matches)
|
||||
|
||||
# The selected tokens include the matches (if any) plus the next sampled tokens
|
||||
if n_matches > 0:
|
||||
valid_tokens = torch.cat((candidate_input_ids[:, -n_matches:], t), dim=-1)
|
||||
else:
|
||||
valid_tokens = t
|
||||
# Next token selection: if there is a rejection, adjust the distribution from the main model before sampling.
|
||||
gamma = min(candidate_logits.shape[1], max_matches)
|
||||
p_n_plus_1 = p[:, n_matches, :]
|
||||
if n_matches < gamma:
|
||||
q_n_plus_1 = q[:, n_matches, :]
|
||||
p_prime = torch.clamp((p_n_plus_1 - q_n_plus_1), min=0)
|
||||
p_prime.div_(p_prime.sum())
|
||||
else:
|
||||
p_prime = p_n_plus_1
|
||||
t = torch.multinomial(p_prime, num_samples=1).squeeze(1)[None, :]
|
||||
|
||||
# The selected tokens include the matches (if any) plus the next sampled tokens
|
||||
if n_matches > 0:
|
||||
valid_tokens = torch.cat((new_candidate_input_ids[:, :n_matches], t), dim=-1)
|
||||
else:
|
||||
valid_tokens = t
|
||||
|
||||
return valid_tokens, n_matches
|
||||
|
||||
|
||||
@@ -88,6 +88,7 @@ if is_torch_available():
|
||||
TopKLogitsWarper,
|
||||
TopPLogitsWarper,
|
||||
)
|
||||
from transformers.generation.utils import _speculative_sampling
|
||||
|
||||
|
||||
class GenerationTesterMixin:
|
||||
@@ -2424,6 +2425,43 @@ class UtilsFunctionsTest(unittest.TestCase):
|
||||
|
||||
self.assertTrue(torch.allclose(expected_output, output, atol=1e-12))
|
||||
|
||||
def test_speculative_sampling(self):
    """Check `_speculative_sampling` on a hand-built case with a deterministic outcome.

    Setup: vocabulary of 10 tokens, a 5-token prompt and 3 candidate tokens
    (1, 4, 5) proposed by the assistant.

    * For the first two candidates the assistant and target logits are
      identical, so the acceptance ratio p_i / q_i is 1 and both tokens are
      expected to be accepted.
    * For the third candidate the target model puts all of its probability
      mass on token 8 (every other logit is -inf), so token 5 has target
      probability 0 and is expected to be rejected; the replacement token is
      then expected to be 8.

    Expected result: 2 matches and validated tokens ``[1, 4, 8]``.
    """
    # Prompt tokens (first 5) followed by the 3 candidate tokens (last 3).
    candidate_input_ids = torch.tensor([[8, 0, 3, 9, 8, 1, 4, 5]])
    # Assistant logits, one row per candidate position.
    candidate_logits = torch.tensor(
        [
            [
                [-10.0, 10.0, -10.0, -10.0, -10.0, -10.0, -10.0, -10.0, -10.0, -10.0],  # generated 1
                [-10.0, -10.0, -10.0, -10.0, 10.0, -10.0, -10.0, -10.0, -10.0, -10.0],  # generated 4
                [-10.0, -10.0, -10.0, -10.0, -10.0, 10.0, -10.0, -10.0, -10.0, -10.0],  # generated 5
            ]
        ]
    )
    candidate_length = 3
    inf = float("inf")
    # Target-model logits: one row per candidate position plus one extra row,
    # which is not expected to be used here because a rejection happens first.
    new_logits = torch.tensor(
        [
            [
                [-10.0, 10.0, -10.0, -10.0, -10.0, -10.0, -10.0, -10.0, -10.0, -10.0],  # accepts 1
                [-10.0, -10.0, -10.0, -10.0, 10.0, -10.0, -10.0, -10.0, -10.0, -10.0],  # accepts 4
                [-inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, 10.0, -inf],  # rejects 5, accepts 8
                [-10.0, -10.0, -10.0, -10.0, -10.0, -10.0, -10.0, -10.0, -10.0, -10.0],  # N/A
            ]
        ]
    )
    last_assistant_token_is_eos = False
    max_matches = 5
    validated_tokens, n_matches = _speculative_sampling(
        candidate_input_ids,
        candidate_logits,
        candidate_length,
        new_logits,
        last_assistant_token_is_eos,
        max_matches,
    )
    # assertEqual / assertListEqual report the actual values on failure,
    # unlike assertTrue(a == b), which only reports "False is not true".
    self.assertEqual(n_matches.item(), 2)
    self.assertListEqual(validated_tokens.tolist()[0], [1, 4, 8])
|
||||
|
||||
|
||||
@require_torch
|
||||
class GenerationIntegrationTests(unittest.TestCase, GenerationIntegrationTestsMixin):
|
||||
|
||||
Reference in New Issue
Block a user