Fix StopStringCriteria to handle tokens above len(tokenizer) (#35797)

* Fix StopStringCriteria to handle tokens above len(tokenizer)

This fixes #35244 by clipping token IDs to be within the tokenizer's vocabulary size before performing the embedding lookup. This prevents index errors when model.config.vocab_size > len(tokenizer).

The fix:
1. Adds a clamp operation to ensure token IDs are within bounds
2. Adds a test case to verify the behavior

* Use self.stop_strings instead of stop_strings

* Handle clipping correctly

* make fixup

* Update test to the new embedding vecs

* Use much bigger values in the mismatch test

* Typo fix

* Slight simplification

---------

Co-authored-by: openhands <openhands@all-hands.dev>
This commit is contained in:
Matt
2025-02-06 16:53:28 +00:00
committed by GitHub
parent 28f73bc307
commit 4563ba2c6f
2 changed files with 27 additions and 10 deletions

View File

@@ -176,6 +176,18 @@ class StoppingCriteriaTestCase(unittest.TestCase):
for i in range(len(false_strings)):
self.assertFalse(criteria(false_input_ids["input_ids"][i : i + 1], scores))
def test_stop_string_criteria_vocab_size_mismatch(self):
"""Test that StopStringCriteria handles tokens above len(tokenizer) correctly."""
tokenizer = AutoTokenizer.from_pretrained("openai-community/gpt2")
# Create input_ids with tokens above len(tokenizer)
input_ids = torch.tensor([[len(tokenizer) + 1024, 1, 2]], device=torch_device)
scores = None
criteria = StopStringCriteria(tokenizer=tokenizer, stop_strings=["test"])
# This should not raise an error and should return False since no stop string is matched
self.assertFalse(criteria(input_ids, scores))
def test_stop_string_matching_positions(self):
stop_string = "stop"
token_list = ["last", "top", "topper", "s", "p"]
@@ -200,14 +212,14 @@ class StoppingCriteriaTestCase(unittest.TestCase):
# Positions inside the stop string where the token matches (excluding end overlaps)
valid_positions = embedding_vec[:, 0].tolist()
self.assertEqual(valid_positions, [2, -1, -1, 3, -1])
self.assertEqual(valid_positions, [2, -1, -1, 3, -1, -1])
# Overlap lengths between end of stop string and start of token
end_overlaps = embedding_vec[:, 1].tolist()
self.assertEqual(end_overlaps, [-1, 3, 3, -1, 1])
self.assertEqual(end_overlaps, [-1, 3, 3, -1, 1, -1])
# Length of each token
token_lengths = embedding_vec[:, 2].tolist()
token_lengths = embedding_vec[:-1, 2].tolist()
self.assertEqual(token_lengths, [len(token) for token in token_list])
def test_single_letter_stop_string(self):