prune LM Head for USD (#36695)
* initial commit * fix * fix style * set default to prune * add tests * comment * remove prune flag from generate * address Joao's comments * deprecate_kwarg * add doc * fix target_vocab_size * Update src/transformers/generation/candidate_generator.py Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com> * Update src/transformers/generation/candidate_generator.py Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com> * Update src/transformers/generation/candidate_generator.py Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com> * Update src/transformers/generation/candidate_generator.py Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com> * fix deprecated argument assistant_model_device --------- Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com>
This commit is contained in:
@@ -20,6 +20,7 @@ class TestAssistantToTargetTranslator(unittest.TestCase):
|
||||
# Create mock tokenizers with predefined vocabularies
|
||||
self.target_tokenizer = MagicMock()
|
||||
self.assistant_tokenizer = MagicMock()
|
||||
self.assistant_model = MagicMock(device=torch_device)
|
||||
|
||||
# Define mock vocabularies for the tokenizers
|
||||
self.target_vocab = {"hello": 0, "world": 1, "foo": 2, "bar": 3}
|
||||
@@ -27,15 +28,15 @@ class TestAssistantToTargetTranslator(unittest.TestCase):
|
||||
|
||||
self.target_tokenizer.get_vocab.return_value = self.target_vocab
|
||||
self.assistant_tokenizer.get_vocab.return_value = self.assistant_vocab
|
||||
self.assistant_model_device = torch_device
|
||||
self.target_vocab_size = 6
|
||||
|
||||
# Instantiate the class under test
|
||||
self.translator = AssistantToTargetTranslator(
|
||||
target_tokenizer=self.target_tokenizer,
|
||||
assistant_tokenizer=self.assistant_tokenizer,
|
||||
assistant_model_device=self.assistant_model_device,
|
||||
target_vocab_size=self.target_vocab_size,
|
||||
assistant_model=self.assistant_model,
|
||||
assistant_prune_lm_head=False,
|
||||
)
|
||||
|
||||
def test_get_assistant_to_target_input_ids(self):
|
||||
@@ -53,19 +54,19 @@ class TestAssistantToTargetTranslator(unittest.TestCase):
|
||||
def test_get_target_ids(self):
|
||||
"""Test the translation of assistant candidate IDs to target candidate IDs."""
|
||||
assistant_input_ids = torch.LongTensor([[0, 1, 2]]).to(
|
||||
self.assistant_model_device
|
||||
self.assistant_model.device
|
||||
) # 'hello world foo' in assistant tokenizer
|
||||
target_input_ids = torch.LongTensor([[0, 1, 2]]).to(
|
||||
self.assistant_model_device
|
||||
self.assistant_model.device
|
||||
) # 'hello world foo' in target tokenizer
|
||||
assistant_candidate_ids = torch.LongTensor([[0, 1, 2, 4]]).to(
|
||||
self.assistant_model_device
|
||||
self.assistant_model.device
|
||||
) # 'hello world foo baz' in assistant tokenizer
|
||||
|
||||
expected_target_ids = torch.LongTensor(
|
||||
[[0, 1, 2, self.translator.SUPPRESS_TOKEN_ID]]
|
||||
).to(
|
||||
self.assistant_model_device
|
||||
self.assistant_model.device
|
||||
) # 'hello world foo baz' in target tokenizer (baz is mapped to self.translator.suppress_tokens_id since it does not exist in target vocab)
|
||||
|
||||
actual_target_ids = self.translator.get_target_ids(
|
||||
@@ -77,12 +78,12 @@ class TestAssistantToTargetTranslator(unittest.TestCase):
|
||||
"""Test the conversion of assistant logits to target logits."""
|
||||
# Assistant logits for IDs 0, 1, 2
|
||||
assistant_logits = torch.FloatTensor([[[0.1, 0.2, 0.3, 0.4, self.translator.FILTER_VALUE]]]).to(
|
||||
self.assistant_model_device
|
||||
self.assistant_model.device
|
||||
) # Shape (1, 1, 5)
|
||||
|
||||
# Expected target logits (target_vocab_size = 4)
|
||||
expected_target_logits = torch.full((1, 1, self.target_vocab_size), self.translator.FILTER_VALUE).to(
|
||||
self.assistant_model_device
|
||||
self.assistant_model.device
|
||||
)
|
||||
expected_target_logits[0, 0, 0] = 0.1 # 'hello'
|
||||
expected_target_logits[0, 0, 1] = 0.2 # 'world'
|
||||
@@ -119,7 +120,8 @@ class TestAssistantVocabTranslatorCache(unittest.TestCase):
|
||||
self.assistant_tokenizer = MockTokenizer({"hello": 0, "world": 1, "foo": 2})
|
||||
self.other_target_tokenizer = MockTokenizer({"foo": 2, "bar": 3})
|
||||
self.other_assistant_tokenizer = MockTokenizer({"baz": 4, "qux": 5})
|
||||
self.assistant_model_device = torch_device
|
||||
self.assistant_model = MagicMock(device=torch_device)
|
||||
|
||||
self.target_vocab_size = 6
|
||||
|
||||
def test_same_instance_for_same_tokenizers(self):
|
||||
@@ -127,14 +129,16 @@ class TestAssistantVocabTranslatorCache(unittest.TestCase):
|
||||
translator1 = AssistantVocabTranslatorCache.get_translator(
|
||||
self.target_tokenizer,
|
||||
self.assistant_tokenizer,
|
||||
assistant_model_device=self.assistant_model_device,
|
||||
target_vocab_size=self.target_vocab_size,
|
||||
assistant_model=self.assistant_model,
|
||||
assistant_prune_lm_head=False,
|
||||
)
|
||||
translator2 = AssistantVocabTranslatorCache.get_translator(
|
||||
self.target_tokenizer,
|
||||
self.assistant_tokenizer,
|
||||
assistant_model_device=self.assistant_model_device,
|
||||
target_vocab_size=self.target_vocab_size,
|
||||
assistant_model=self.assistant_model,
|
||||
assistant_prune_lm_head=False,
|
||||
)
|
||||
self.assertIs(translator1, translator2, "Translators should be cached and identical")
|
||||
|
||||
@@ -143,14 +147,16 @@ class TestAssistantVocabTranslatorCache(unittest.TestCase):
|
||||
translator1 = AssistantVocabTranslatorCache.get_translator(
|
||||
self.target_tokenizer,
|
||||
self.assistant_tokenizer,
|
||||
assistant_model_device=self.assistant_model_device,
|
||||
target_vocab_size=self.target_vocab_size,
|
||||
assistant_model=self.assistant_model,
|
||||
assistant_prune_lm_head=False,
|
||||
)
|
||||
translator2 = AssistantVocabTranslatorCache.get_translator(
|
||||
self.other_target_tokenizer,
|
||||
self.other_assistant_tokenizer,
|
||||
assistant_model_device=self.assistant_model_device,
|
||||
target_vocab_size=self.target_vocab_size,
|
||||
assistant_model=self.assistant_model,
|
||||
assistant_prune_lm_head=False,
|
||||
)
|
||||
self.assertIsNot(translator1, translator2, "Translators should differ for different tokenizers")
|
||||
|
||||
@@ -164,8 +170,9 @@ class TestAssistantVocabTranslatorCache(unittest.TestCase):
|
||||
translator = AssistantVocabTranslatorCache.get_translator(
|
||||
target_tokenizer,
|
||||
assistant_tokenizer,
|
||||
assistant_model_device=self.assistant_model_device,
|
||||
target_vocab_size=self.target_vocab_size,
|
||||
assistant_model=self.assistant_model,
|
||||
assistant_prune_lm_head=False,
|
||||
)
|
||||
self.assertEqual(len(AssistantVocabTranslatorCache._cache), initial_cache_size + 1)
|
||||
|
||||
@@ -192,8 +199,9 @@ class TestAssistantVocabTranslatorCache(unittest.TestCase):
|
||||
translator = AssistantVocabTranslatorCache.get_translator(
|
||||
target_tokenizer,
|
||||
assistant_tokenizer,
|
||||
assistant_model_device=self.assistant_model_device,
|
||||
target_vocab_size=self.target_vocab_size,
|
||||
assistant_model=self.assistant_model,
|
||||
assistant_prune_lm_head=False,
|
||||
)
|
||||
# Create weak references before returning
|
||||
refs = (weakref.ref(translator), weakref.ref(target_tokenizer), weakref.ref(assistant_tokenizer))
|
||||
@@ -239,16 +247,18 @@ class TestUniversalSpeculativeDecoding(unittest.TestCase):
|
||||
self.target_tokenizer.bos_token_id = self.target_tokenizer.eos_token_id
|
||||
if self.assistant_tokenizer.pad_token_id is None:
|
||||
self.assistant_tokenizer.pad_token_id = self.assistant_tokenizer.eos_token_id
|
||||
if self.target_tokenizer.bos_token_id is None:
|
||||
if self.assistant_tokenizer.bos_token_id is None:
|
||||
self.assistant_tokenizer.bos_token_id = self.assistant_tokenizer.eos_token_id
|
||||
|
||||
self.input_ids = torch.tensor([[1, 2, 3]]).to(torch_device)
|
||||
self.model_kwargs = {
|
||||
"attention_mask": torch.ones_like(self.input_ids).to(torch_device),
|
||||
}
|
||||
|
||||
atm_translator = AssistantVocabTranslatorCache.get_translator(
|
||||
self.target_tokenizer, self.assistant_tokenizer, self.target_config.vocab_size, torch_device
|
||||
target_tokenizer=self.target_tokenizer,
|
||||
assistant_tokenizer=self.assistant_tokenizer,
|
||||
assistant_model=self.assistant_model,
|
||||
target_vocab_size=self.target_config.vocab_size,
|
||||
)
|
||||
self.generator = UniversalSpeculativeDecodingGenerator(
|
||||
input_ids=self.input_ids,
|
||||
@@ -286,7 +296,7 @@ class TestUniversalSpeculativeDecoding(unittest.TestCase):
|
||||
)
|
||||
input_ids = torch.tensor([[self.target_tokenizer.convert_tokens_to_ids(missing_token)]])
|
||||
self.generator.input_ids = input_ids
|
||||
candidates, scores = self.generator.get_candidates(input_ids)
|
||||
candidates, _ = self.generator.get_candidates(input_ids)
|
||||
self.assertIsNotNone(candidates)
|
||||
|
||||
def test_speculation_depth(self):
|
||||
@@ -296,7 +306,7 @@ class TestUniversalSpeculativeDecoding(unittest.TestCase):
|
||||
|
||||
for depth in [1, 8, 17]:
|
||||
self.generator.num_assistant_tokens = depth
|
||||
candidates, scores = self.generator.get_candidates(input_ids)
|
||||
candidates, _ = self.generator.get_candidates(input_ids)
|
||||
self.assertLessEqual(candidates.shape[1] - input_ids.shape[1], depth)
|
||||
|
||||
def test_device_consistency(self):
|
||||
@@ -310,10 +320,6 @@ class TestUniversalSpeculativeDecoding(unittest.TestCase):
|
||||
"""Test that USD matches vanilla sampling with temperature set to nearly 0"""
|
||||
prompt = "Test text"
|
||||
|
||||
pipe_usd = pipeline("text-generation", model=cls.target_name, assistant_model=cls.assistant_name)
|
||||
pipe_usd_output = pipe_usd(prompt, max_new_tokens=5, do_sample=True, temperature=1e-9) # Nearly 0 temperature
|
||||
usd_text = pipe_usd_output[0]["generated_text"]
|
||||
|
||||
pipe_vanilla = pipeline(
|
||||
"text-generation",
|
||||
model=cls.target_name,
|
||||
@@ -321,5 +327,13 @@ class TestUniversalSpeculativeDecoding(unittest.TestCase):
|
||||
pipe_vanilla_output = pipe_vanilla(prompt, max_new_tokens=5, do_sample=False)
|
||||
vanilla_text = pipe_vanilla_output[0]["generated_text"]
|
||||
|
||||
pipe_usd = pipeline(
|
||||
"text-generation",
|
||||
model=cls.target_name,
|
||||
assistant_model=cls.assistant_name,
|
||||
)
|
||||
pipe_usd_output = pipe_usd(prompt, max_new_tokens=5, do_sample=True, temperature=1e-9) # Nearly 0 temperature
|
||||
usd_text = pipe_usd_output[0]["generated_text"]
|
||||
|
||||
# Assert that the outputs match
|
||||
cls.assertEqual(usd_text, vanilla_text)
|
||||
|
||||
Reference in New Issue
Block a user