Generate: get generation mode from the generation config instance 🧼 (#29441)

This commit is contained in:
Joao Gante
2024-03-06 11:18:35 +00:00
committed by GitHub
parent 41f7b7ae4b
commit 700d48fb2d
5 changed files with 103 additions and 65 deletions

View File

@@ -24,6 +24,7 @@ from parameterized import parameterized
from requests.exceptions import HTTPError
from transformers import AutoConfig, GenerationConfig
from transformers.generation import GenerationMode
from transformers.testing_utils import TOKEN, USER, is_staging_test
@@ -202,6 +203,23 @@ class GenerationConfigTest(unittest.TestCase):
self.assertEqual(len(captured_warnings), 0)
self.assertTrue(len(os.listdir(tmp_dir)) == 1)
def test_generation_mode(self):
"""Tests that the `get_generation_mode` method is working as expected."""
config = GenerationConfig()
self.assertEqual(config.get_generation_mode(), GenerationMode.GREEDY_SEARCH)
config = GenerationConfig(do_sample=True)
self.assertEqual(config.get_generation_mode(), GenerationMode.SAMPLE)
config = GenerationConfig(num_beams=2)
self.assertEqual(config.get_generation_mode(), GenerationMode.BEAM_SEARCH)
config = GenerationConfig(top_k=10, do_sample=False, penalty_alpha=0.6)
self.assertEqual(config.get_generation_mode(), GenerationMode.CONTRASTIVE_SEARCH)
config = GenerationConfig()
self.assertEqual(config.get_generation_mode(assistant_model="foo"), GenerationMode.ASSISTED_GENERATION)
@is_staging_test
class ConfigPushToHubTester(unittest.TestCase):