🚨🚨 Setting default behavior of assisted decoding (#33657)
This commit is contained in:
@@ -2069,6 +2069,7 @@ class GenerationTesterMixin:
|
||||
"assistant_model": assistant_model,
|
||||
}
|
||||
|
||||
assistant_model.generation_config.assistant_confidence_threshold = None
|
||||
# Setting num_logits_to_keep at 0 keeps all logits (old behavior)
|
||||
with_all_logits = model.generate(
|
||||
input_ids, attention_mask=attention_mask, **generation_kwargs, **inputs_dict, num_logits_to_keep=0
|
||||
@@ -3098,6 +3099,16 @@ class GenerationIntegrationTests(unittest.TestCase, GenerationIntegrationTestsMi
|
||||
)
|
||||
self.assertEqual(len(warning_list), 0)
|
||||
|
||||
def test_default_assisted_generation(self):
    """Check the assisted-decoding defaults of a `GenerationConfig` built with no arguments."""
    default_config = GenerationConfig()

    # Expected out-of-the-box value for each assisted-generation attribute,
    # listed in the order in which they are verified.
    expected_defaults = (
        ("num_assistant_tokens", 20),
        ("num_assistant_tokens_schedule", "constant"),
        ("assistant_confidence_threshold", 0.4),
        ("is_assistant", False),
    )
    for attribute, expected in expected_defaults:
        self.assertEqual(getattr(default_config, attribute), expected)
|
||||
|
||||
def test_generated_length_assisted_generation(self):
|
||||
# PT-only test: TF doesn't support assisted decoding yet.
|
||||
model = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-gpt2").to(torch_device)
|
||||
|
||||
Reference in New Issue
Block a user