Fix _speculative_sampling implementation (#28508)
This commit is contained in:
@@ -88,6 +88,7 @@ if is_torch_available():
|
||||
TopKLogitsWarper,
|
||||
TopPLogitsWarper,
|
||||
)
|
||||
from transformers.generation.utils import _speculative_sampling
|
||||
|
||||
|
||||
class GenerationTesterMixin:
|
||||
@@ -2424,6 +2425,43 @@ class UtilsFunctionsTest(unittest.TestCase):
|
||||
|
||||
self.assertTrue(torch.allclose(expected_output, output, atol=1e-12))
|
||||
|
||||
def test_speculative_sampling(self):
    """Check `_speculative_sampling` on a hand-built accept/accept/reject case.

    The assistant proposes three candidate tokens (1, 4, 5). The main model's
    logits agree with the first two and put all of their mass on token 8 at
    the third position, so the expected result is two matches and the
    validated sequence `[1, 4, 8]`.
    """
    # assume vocab size 10, input length 5 + 3 generated candidates
    # (5 prompt tokens followed by the 3 candidate tokens 1, 4, 5)
    candidate_input_ids = torch.tensor([[8, 0, 3, 9, 8, 1, 4, 5]])
    candidate_logits = torch.tensor(
        [
            [
                [-10.0, 10.0, -10.0, -10.0, -10.0, -10.0, -10.0, -10.0, -10.0, -10.0],  # generated 1
                [-10.0, -10.0, -10.0, -10.0, 10.0, -10.0, -10.0, -10.0, -10.0, -10.0],  # generated 4
                [-10.0, -10.0, -10.0, -10.0, -10.0, 10.0, -10.0, -10.0, -10.0, -10.0],  # generated 5
            ]
        ]
    )
    candidate_length = 3
    inf = float("inf")
    # One row per candidate position plus one extra trailing row.
    new_logits = torch.tensor(
        [
            [
                [-10.0, 10.0, -10.0, -10.0, -10.0, -10.0, -10.0, -10.0, -10.0, -10.0],  # accepts 1
                [-10.0, -10.0, -10.0, -10.0, 10.0, -10.0, -10.0, -10.0, -10.0, -10.0],  # accepts 4
                # -inf everywhere except index 8, so token 5 gets no mass at all
                [-inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, 10.0, -inf],  # rejects 5, accepts 8
                [-10.0, -10.0, -10.0, -10.0, -10.0, -10.0, -10.0, -10.0, -10.0, -10.0],  # N/A
            ]
        ]
    )
    last_assistant_token_is_eos = False
    max_matches = 5
    validated_tokens, n_matches = _speculative_sampling(
        candidate_input_ids,
        candidate_logits,
        candidate_length,
        new_logits,
        last_assistant_token_is_eos,
        max_matches,
    )
    # assertEqual / assertListEqual report the actual values on failure,
    # unlike assertTrue(a == b), which only says "False is not true".
    self.assertEqual(n_matches.item(), 2)
    self.assertListEqual(validated_tokens.tolist()[0], [1, 4, 8])
|
||||
|
||||
|
||||
@require_torch
|
||||
class GenerationIntegrationTests(unittest.TestCase, GenerationIntegrationTestsMixin):
|
||||
|
||||
Reference in New Issue
Block a user