use a tinymodel to test generation config which aviod timeout (#34482)

* use a tinymodel to test generation config which aviod timeout

* remove tailing whitespace
This commit is contained in:
kang sheng
2024-10-29 16:39:06 +08:00
committed by GitHub
parent 63ca6d9771
commit 655bec2da7

View File

@@ -1544,15 +1544,16 @@ class ModelUtilsTest(TestCasePlus):
self.assertEqual(model.__class__.__name__, model_ref.__class__.__name__) self.assertEqual(model.__class__.__name__, model_ref.__class__.__name__)
def test_generation_config_is_loaded_with_model(self): def test_generation_config_is_loaded_with_model(self):
# Note: `TinyLlama/TinyLlama-1.1B-Chat-v1.0` has a `generation_config.json` containing `max_length: 2048` # Note: `hf-internal-testing/tiny-random-MistralForCausalLM` has a `generation_config.json`
# containing `bos_token_id: 1`
# 1. Load without further parameters # 1. Load without further parameters
model = AutoModelForCausalLM.from_pretrained("TinyLlama/TinyLlama-1.1B-Chat-v1.0") model = AutoModelForCausalLM.from_pretrained(TINY_MISTRAL)
self.assertEqual(model.generation_config.max_length, 2048) self.assertEqual(model.generation_config.bos_token_id, 1)
# 2. Load with `device_map` # 2. Load with `device_map`
model = AutoModelForCausalLM.from_pretrained("TinyLlama/TinyLlama-1.1B-Chat-v1.0", device_map="auto") model = AutoModelForCausalLM.from_pretrained(TINY_MISTRAL, device_map="auto")
self.assertEqual(model.generation_config.max_length, 2048) self.assertEqual(model.generation_config.bos_token_id, 1)
@require_safetensors @require_safetensors
def test_safetensors_torch_from_torch(self): def test_safetensors_torch_from_torch(self):