Fix _speculative_sampling implementation (#28508)

This commit is contained in:
Ofir Zafrir
2024-01-19 16:07:31 +02:00
committed by GitHub
parent d15781597a
commit 9efec11400
3 changed files with 70 additions and 22 deletions

View File

@@ -88,6 +88,7 @@ if is_torch_available():
TopKLogitsWarper,
TopPLogitsWarper,
)
from transformers.generation.utils import _speculative_sampling
class GenerationTesterMixin:
@@ -2424,6 +2425,43 @@ class UtilsFunctionsTest(unittest.TestCase):
self.assertTrue(torch.allclose(expected_output, output, atol=1e-12))
def test_speculative_sampling(self):
    """Check `_speculative_sampling` on a hand-built accept / accept / reject scenario.

    The candidate logits and the new logits agree on the first two generated tokens
    (1 and 4). At the third position the new logits are `-inf` everywhere except
    token 8, so candidate token 5 is expected to be rejected and token 8 returned in
    its place: 2 matches and validated tokens `[1, 4, 8]`.
    """
    # assume vocab size 10, input length 5 + 3 generated candidates
    candidate_input_ids = torch.tensor([[8, 0, 3, 9, 8, 1, 4, 5]])  # input tokens
    candidate_logits = torch.tensor(
        [
            [
                [-10.0, 10.0, -10.0, -10.0, -10.0, -10.0, -10.0, -10.0, -10.0, -10.0],  # generated 1
                [-10.0, -10.0, -10.0, -10.0, 10.0, -10.0, -10.0, -10.0, -10.0, -10.0],  # generated 4
                [-10.0, -10.0, -10.0, -10.0, -10.0, 10.0, -10.0, -10.0, -10.0, -10.0],  # generated 5
            ]
        ]
    )
    candidate_length = 3  # number of rows in `candidate_logits` above
    inf = float("inf")
    # One more row than `candidate_logits`; the last row is marked N/A for this scenario.
    new_logits = torch.tensor(
        [
            [
                [-10.0, 10.0, -10.0, -10.0, -10.0, -10.0, -10.0, -10.0, -10.0, -10.0],  # accepts 1
                [-10.0, -10.0, -10.0, -10.0, 10.0, -10.0, -10.0, -10.0, -10.0, -10.0],  # accepts 4
                [-inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, 10.0, -inf],  # rejects 5, accepts 8
                [-10.0, -10.0, -10.0, -10.0, -10.0, -10.0, -10.0, -10.0, -10.0, -10.0],  # N/A
            ]
        ]
    )
    last_assistant_token_is_eos = False
    max_matches = 5
    validated_tokens, n_matches = _speculative_sampling(
        candidate_input_ids,
        candidate_logits,
        candidate_length,
        new_logits,
        last_assistant_token_is_eos,
        max_matches,
    )
    # assertEqual (rather than assertTrue(a == b)) so a failure reports the actual values.
    self.assertEqual(n_matches.item(), 2)
    self.assertEqual(validated_tokens.tolist()[0], [1, 4, 8])
@require_torch
class GenerationIntegrationTests(unittest.TestCase, GenerationIntegrationTestsMixin):