Fix StopStringCriteria to handle tokens above len(tokenizer) (#35797)
* Fix StopStringCriteria to handle tokens above len(tokenizer)

  This fixes #35244 by clipping token IDs to be within the tokenizer's vocabulary size before performing the embedding lookup. This prevents index errors when model.config.vocab_size > len(tokenizer).

  The fix:
  1. Adds a clamp operation to ensure token IDs are within bounds
  2. Adds a test case to verify the behavior

* Use self.stop_strings instead of stop_strings
* Handle clipping correctly
* make fixup
* Update test to the new embedding vecs
* Use much bigger values in the mismatch test
* Typo fix
* Slight simplification

---------

Co-authored-by: openhands <openhands@all-hands.dev>
This commit is contained in:
@@ -176,6 +176,18 @@ class StoppingCriteriaTestCase(unittest.TestCase):
|
||||
for i in range(len(false_strings)):
|
||||
self.assertFalse(criteria(false_input_ids["input_ids"][i : i + 1], scores))
|
||||
|
||||
def test_stop_string_criteria_vocab_size_mismatch(self):
|
||||
"""Test that StopStringCriteria handles tokens above len(tokenizer) correctly."""
|
||||
tokenizer = AutoTokenizer.from_pretrained("openai-community/gpt2")
|
||||
|
||||
# Create input_ids with tokens above len(tokenizer)
|
||||
input_ids = torch.tensor([[len(tokenizer) + 1024, 1, 2]], device=torch_device)
|
||||
scores = None
|
||||
criteria = StopStringCriteria(tokenizer=tokenizer, stop_strings=["test"])
|
||||
|
||||
# This should not raise an error and should return False since no stop string is matched
|
||||
self.assertFalse(criteria(input_ids, scores))
|
||||
|
||||
def test_stop_string_matching_positions(self):
|
||||
stop_string = "stop"
|
||||
token_list = ["last", "top", "topper", "s", "p"]
|
||||
@@ -200,14 +212,14 @@ class StoppingCriteriaTestCase(unittest.TestCase):
|
||||
|
||||
# Positions inside the stop string where the token matches (excluding end overlaps)
|
||||
valid_positions = embedding_vec[:, 0].tolist()
|
||||
self.assertEqual(valid_positions, [2, -1, -1, 3, -1])
|
||||
self.assertEqual(valid_positions, [2, -1, -1, 3, -1, -1])
|
||||
|
||||
# Overlap lengths between end of stop string and start of token
|
||||
end_overlaps = embedding_vec[:, 1].tolist()
|
||||
self.assertEqual(end_overlaps, [-1, 3, 3, -1, 1])
|
||||
self.assertEqual(end_overlaps, [-1, 3, 3, -1, 1, -1])
|
||||
|
||||
# Length of each token
|
||||
token_lengths = embedding_vec[:, 2].tolist()
|
||||
token_lengths = embedding_vec[:-1, 2].tolist()
|
||||
self.assertEqual(token_lengths, [len(token) for token in token_list])
|
||||
|
||||
def test_single_letter_stop_string(self):
|
||||
|
||||
Reference in New Issue
Block a user