From 655bec2da7120a8681acc2ce951f8d58c6f0e6ef Mon Sep 17 00:00:00 2001 From: kang sheng Date: Tue, 29 Oct 2024 16:39:06 +0800 Subject: [PATCH] use a tinymodel to test generation config which aviod timeout (#34482) * use a tinymodel to test generation config which aviod timeout * remove tailing whitespace --- tests/utils/test_modeling_utils.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/tests/utils/test_modeling_utils.py b/tests/utils/test_modeling_utils.py index 8af47cde8e..0452a10d5d 100644 --- a/tests/utils/test_modeling_utils.py +++ b/tests/utils/test_modeling_utils.py @@ -1544,15 +1544,16 @@ class ModelUtilsTest(TestCasePlus): self.assertEqual(model.__class__.__name__, model_ref.__class__.__name__) def test_generation_config_is_loaded_with_model(self): - # Note: `TinyLlama/TinyLlama-1.1B-Chat-v1.0` has a `generation_config.json` containing `max_length: 2048` + # Note: `hf-internal-testing/tiny-random-MistralForCausalLM` has a `generation_config.json` + # containing `bos_token_id: 1` # 1. Load without further parameters - model = AutoModelForCausalLM.from_pretrained("TinyLlama/TinyLlama-1.1B-Chat-v1.0") - self.assertEqual(model.generation_config.max_length, 2048) + model = AutoModelForCausalLM.from_pretrained(TINY_MISTRAL) + self.assertEqual(model.generation_config.bos_token_id, 1) # 2. Load with `device_map` - model = AutoModelForCausalLM.from_pretrained("TinyLlama/TinyLlama-1.1B-Chat-v1.0", device_map="auto") - self.assertEqual(model.generation_config.max_length, 2048) + model = AutoModelForCausalLM.from_pretrained(TINY_MISTRAL, device_map="auto") + self.assertEqual(model.generation_config.bos_token_id, 1) @require_safetensors def test_safetensors_torch_from_torch(self):