Refactoring AssistedCandidateGenerator for Improved Modularity and Reusability (#35009)
* move `TestAssistedCandidateGeneratorDifferentTokenizers` into a new testing file
* refactor
* NOTHING. add space to rerun github actions tests
* remove it...
* NOTHING. add space to rerun github actions tests
* remove it...
* replace: `self.prev_tokens` -> `self.prev_assistant_ids`
* NOTHING. rerun CI tests
* remove it
* introduce `self.prev_target_ids_len`
* fix style
* fix style

---------

Co-authored-by: Jonathan Mamou <jonathan.mamou@intel.com>
This commit is contained in:
43
tests/generation/test_candidate_generator.py
Normal file
43
tests/generation/test_candidate_generator.py
Normal file
@@ -0,0 +1,43 @@
|
||||
import unittest
|
||||
|
||||
import numpy as np
|
||||
|
||||
from transformers.generation.candidate_generator import AssistedCandidateGeneratorDifferentTokenizers
|
||||
|
||||
|
||||
class TestAssistedCandidateGeneratorDifferentTokenizers(unittest.TestCase):
    """Tests for ``AssistedCandidateGeneratorDifferentTokenizers._get_tokens_diag``.

    Every case feeds the function two 2-D integer arrays (a prompt and a
    "prompt plus new tokens" sequence) and checks the three values it returns:
    ``discrep_length``, ``new_tokens_only`` and ``discrep_only``.
    """

    def test_no_intersection(self):
        """Sequences with no token in common produce ``(None, None, None)``."""
        get_diag = AssistedCandidateGeneratorDifferentTokenizers._get_tokens_diag
        old_ids = np.array([[1, 2, 3]])
        candidate_ids = np.array([[4, 5, 6]])

        outcome = get_diag(old_ids, candidate_ids)

        self.assertEqual(outcome, (None, None, None))

    def test_complete_overlap(self):
        """The second sequence starts with the whole prompt, then adds ``4, 5``."""
        get_diag = AssistedCandidateGeneratorDifferentTokenizers._get_tokens_diag
        old_ids = np.array([[1, 2, 3]])
        candidate_ids = np.array([[1, 2, 3, 4, 5]])
        expected_new = np.array([[4, 5]])
        expected_discrep = np.array([[]])

        n_discrep, fresh_tokens, discrep_tokens = get_diag(old_ids, candidate_ids)

        self.assertEqual(n_discrep, 0)
        np.testing.assert_array_equal(fresh_tokens, expected_new)
        np.testing.assert_array_equal(discrep_tokens, expected_discrep)

    def test_partial_overlap(self):
        """The second sequence starts with only the prompt's tail ``2, 3``."""
        get_diag = AssistedCandidateGeneratorDifferentTokenizers._get_tokens_diag
        old_ids = np.array([[1, 2, 3]])
        candidate_ids = np.array([[2, 3, 4, 5]])
        expected_new = np.array([[4, 5]])
        expected_discrep = np.array([[]])

        n_discrep, fresh_tokens, discrep_tokens = get_diag(old_ids, candidate_ids)

        self.assertEqual(n_discrep, 0)
        np.testing.assert_array_equal(fresh_tokens, expected_new)
        np.testing.assert_array_equal(discrep_tokens, expected_discrep)

    def test_no_new_tokens(self):
        """Identical sequences give zero discrepancy and two empty arrays."""
        get_diag = AssistedCandidateGeneratorDifferentTokenizers._get_tokens_diag
        old_ids = np.array([[1, 2, 3]])
        candidate_ids = np.array([[1, 2, 3]])
        expected_empty = np.array([[]])

        n_discrep, fresh_tokens, discrep_tokens = get_diag(old_ids, candidate_ids)

        self.assertEqual(n_discrep, 0)
        np.testing.assert_array_equal(fresh_tokens, expected_empty)
        np.testing.assert_array_equal(discrep_tokens, expected_empty)
|
||||
Reference in New Issue
Block a user