* Update test_utils.py * formatting * Update test_utils.py * formatting * formatting * Update test_utils.py * formatting * Update test_utils.py * formatting * format * comments at standard positions
This commit is contained in:
@@ -14,6 +14,7 @@
|
|||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
|
|
||||||
|
import collections
|
||||||
import copy
|
import copy
|
||||||
import gc
|
import gc
|
||||||
import inspect
|
import inspect
|
||||||
@@ -2450,6 +2451,58 @@ class UtilsFunctionsTest(unittest.TestCase):
|
|||||||
self.assertTrue(n_matches.item() == 2)
|
self.assertTrue(n_matches.item() == 2)
|
||||||
self.assertTrue(validated_tokens.tolist()[0] == [1, 4, 8])
|
self.assertTrue(validated_tokens.tolist()[0] == [1, 4, 8])
|
||||||
|
|
||||||
|
def test_speculative_sampling_target_distribution(self):
    """
    Asserts that the target distribution is preserved.
    Should help with catching issues like #32867.

    The first two candidate tokens (1 and 4) are set up to always be accepted, while the third
    candidate (5) has `-inf` target logit and must therefore always be rejected. The token that
    replaces it is collected over many runs and its empirical frequencies are compared against
    the ordering implied by the target logits.

    NOTE: the test is statistical and does not seed the RNG; the sample size is what keeps the
    ordering assertions stable.
    """
    # assume vocab size 10, input length 5 + 3 generated candidates
    candidate_input_ids = torch.tensor([[8, 0, 3, 9, 8, 1, 4, 5]])  # input tokens
    candidate_logits = torch.tensor(
        [
            [
                [-10.0, 10.0, -10.0, -10.0, -10.0, -10.0, -10.0, -10.0, -10.0, -10.0],  # generated 1
                [-10.0, -10.0, -10.0, -10.0, 10.0, -10.0, -10.0, -10.0, -10.0, -10.0],  # generated 4
                [-10.0, -10.0, -10.0, -10.0, -10.0, 10.0, -10.0, -10.0, -10.0, -10.0],  # generated 5
            ]
        ]
    )
    candidate_length = 3
    inf = float("inf")
    new_logits = torch.tensor(
        [
            [
                # accepts 1:
                [-inf, 10.0, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf],
                # accepts 4:
                [-inf, -inf, -inf, -inf, 10.0, -inf, -inf, -inf, -inf, -inf],
                # most likely to be 1 or 8, less likely to be 3, then 7, and should never be any other value:
                [-inf, 2.0, -inf, 1.0, -inf, -inf, -inf, -0.01, 2.0, -inf],
                # N/A:
                [-inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf],
            ]
        ]
    )
    last_assistant_token_is_eos = False
    last_validated_token = []
    for _ in range(10_000):
        validated_tokens, n_matches = _speculative_sampling(
            candidate_input_ids,
            candidate_logits,
            candidate_length,
            new_logits,
            last_assistant_token_is_eos,
        )
        # Dedicated assert methods report the offending values on failure,
        # unlike `assertTrue(a == b)` which only reports "False is not true".
        self.assertEqual(n_matches.item(), 2)
        # Convert once per iteration instead of once per assertion.
        tokens = validated_tokens.tolist()[0]
        self.assertEqual(tokens[0], 1)
        self.assertEqual(tokens[1], 4)
        self.assertIn(tokens[2], [1, 3, 7, 8])
        last_validated_token.append(tokens[2])
    # check that the most likely tokens are selected more often than the less likely ones
    last_token_counts = collections.Counter(last_validated_token)
    self.assertGreater(last_token_counts[1], last_token_counts[3])
    self.assertGreater(last_token_counts[3], last_token_counts[7])
    self.assertGreater(last_token_counts[7], 0)
    self.assertGreater(last_token_counts[8], last_token_counts[3])
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.generate
|
@pytest.mark.generate
|
||||||
@require_torch
|
@require_torch
|
||||||
|
|||||||
Reference in New Issue
Block a user