Test loading generation config with safetensor weights (#31550)
fix test
This commit is contained in:
@@ -1424,20 +1424,15 @@ class ModelUtilsTest(TestCasePlus):
|
|||||||
self.assertEqual(model.__class__.__name__, model_ref.__class__.__name__)
|
self.assertEqual(model.__class__.__name__, model_ref.__class__.__name__)
|
||||||
|
|
||||||
def test_generation_config_is_loaded_with_model(self):
|
def test_generation_config_is_loaded_with_model(self):
|
||||||
# Note: `joaogante/tiny-random-gpt2-with-generation-config` has a `generation_config.json` containing a dummy
|
# Note: `TinyLlama/TinyLlama-1.1B-Chat-v1.0` has a `generation_config.json` containing `max_length: 2048`
|
||||||
# `transformers_version` field set to `foo`. If loading the file fails, this test also fails.
|
|
||||||
|
|
||||||
# 1. Load without further parameters
|
# 1. Load without further parameters
|
||||||
model = AutoModelForCausalLM.from_pretrained(
|
model = AutoModelForCausalLM.from_pretrained("TinyLlama/TinyLlama-1.1B-Chat-v1.0")
|
||||||
"joaogante/tiny-random-gpt2-with-generation-config", use_safetensors=False
|
self.assertEqual(model.generation_config.max_length, 2048)
|
||||||
)
|
|
||||||
self.assertEqual(model.generation_config.transformers_version, "foo")
|
|
||||||
|
|
||||||
# 2. Load with `device_map`
|
# 2. Load with `device_map`
|
||||||
model = AutoModelForCausalLM.from_pretrained(
|
model = AutoModelForCausalLM.from_pretrained("TinyLlama/TinyLlama-1.1B-Chat-v1.0", device_map="auto")
|
||||||
"joaogante/tiny-random-gpt2-with-generation-config", device_map="auto", use_safetensors=False
|
self.assertEqual(model.generation_config.max_length, 2048)
|
||||||
)
|
|
||||||
self.assertEqual(model.generation_config.transformers_version, "foo")
|
|
||||||
|
|
||||||
@require_safetensors
|
@require_safetensors
|
||||||
def test_safetensors_torch_from_torch(self):
|
def test_safetensors_torch_from_torch(self):
|
||||||
|
|||||||
Reference in New Issue
Block a user