🚨🚨 Setting default behavior of assisted decoding (#33657)

This commit is contained in:
Jonathan Mamou
2024-09-25 11:39:09 +03:00
committed by GitHub
parent 5f0c181f4e
commit 52daf4ec76
4 changed files with 24 additions and 8 deletions

View File

@@ -2069,6 +2069,7 @@ class GenerationTesterMixin:
"assistant_model": assistant_model,
}
assistant_model.generation_config.assistant_confidence_threshold = None
# Setting num_logits_to_keep at 0 keeps all logits (old behavior)
with_all_logits = model.generate(
input_ids, attention_mask=attention_mask, **generation_kwargs, **inputs_dict, num_logits_to_keep=0
@@ -3098,6 +3099,16 @@ class GenerationIntegrationTests(unittest.TestCase, GenerationIntegrationTestsMi
)
self.assertEqual(len(warning_list), 0)
def test_default_assisted_generation(self):
    """Verify the assisted-generation defaults of a freshly built `GenerationConfig`."""
    # Built with no arguments, so every field holds its default value.
    default_config = GenerationConfig()
    # Expected defaults, listed in the order in which they are checked.
    expected_defaults = (
        ("num_assistant_tokens", 20),
        ("num_assistant_tokens_schedule", "constant"),
        ("assistant_confidence_threshold", 0.4),
        ("is_assistant", False),
    )
    for attribute_name, expected_value in expected_defaults:
        self.assertEqual(getattr(default_config, attribute_name), expected_value)
def test_generated_length_assisted_generation(self):
# PT-only test: TF doesn't support assisted decoding yet.
model = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-gpt2").to(torch_device)