Generate: get generation mode from the generation config instance 🧼 (#29441)
This commit is contained in:
@@ -24,6 +24,7 @@ from parameterized import parameterized
|
||||
from requests.exceptions import HTTPError
|
||||
|
||||
from transformers import AutoConfig, GenerationConfig
|
||||
from transformers.generation import GenerationMode
|
||||
from transformers.testing_utils import TOKEN, USER, is_staging_test
|
||||
|
||||
|
||||
@@ -202,6 +203,23 @@ class GenerationConfigTest(unittest.TestCase):
|
||||
self.assertEqual(len(captured_warnings), 0)
|
||||
self.assertTrue(len(os.listdir(tmp_dir)) == 1)
|
||||
|
||||
def test_generation_mode(self):
|
||||
"""Tests that the `get_generation_mode` method is working as expected."""
|
||||
config = GenerationConfig()
|
||||
self.assertEqual(config.get_generation_mode(), GenerationMode.GREEDY_SEARCH)
|
||||
|
||||
config = GenerationConfig(do_sample=True)
|
||||
self.assertEqual(config.get_generation_mode(), GenerationMode.SAMPLE)
|
||||
|
||||
config = GenerationConfig(num_beams=2)
|
||||
self.assertEqual(config.get_generation_mode(), GenerationMode.BEAM_SEARCH)
|
||||
|
||||
config = GenerationConfig(top_k=10, do_sample=False, penalty_alpha=0.6)
|
||||
self.assertEqual(config.get_generation_mode(), GenerationMode.CONTRASTIVE_SEARCH)
|
||||
|
||||
config = GenerationConfig()
|
||||
self.assertEqual(config.get_generation_mode(assistant_model="foo"), GenerationMode.ASSISTED_GENERATION)
|
||||
|
||||
|
||||
@is_staging_test
|
||||
class ConfigPushToHubTester(unittest.TestCase):
|
||||
|
||||
Reference in New Issue
Block a user