Support for transformers explicit filename (#38152)

* Support for transformers explicit filename

* Tests

* Rerun tests
This commit is contained in:
Lysandre Debut
2025-05-19 14:33:47 +02:00
committed by GitHub
parent dbb9813dff
commit 003deb16f1
3 changed files with 101 additions and 2 deletions

View File

@@ -1958,6 +1958,80 @@ class ModelUtilsTest(TestCasePlus):
except subprocess.CalledProcessError as e:
raise Exception(f"The following error was captured: {e.stderr}")
def test_explicit_transformers_weights(self):
    """
    Check loading from a repo whose config explicitly names its weights file.

    A config may define `transformers_weights`; when it does, transformers is
    expected to load the weights from that file. The parameter count of the
    loaded model is used to confirm that the intended file was picked up.
    """
    repo_id = "hf-internal-testing/explicit_transformers_weight_in_config"
    expected_num_parameters = 87929

    loaded_model = BertModel.from_pretrained(repo_id)
    actual_num_parameters = loaded_model.num_parameters()

    self.assertEqual(actual_num_parameters, expected_num_parameters)
def test_explicit_transformers_weights_index(self):
    """
    Check loading from a repo whose config explicitly names a weights index.

    Same scenario as the single-file case (the config defines
    `transformers_weights`), except that the referenced file is an index
    pointing at multiple weight files. The parameter count of the loaded
    model is used to confirm that the intended weights were picked up.
    """
    repo_id = "hf-internal-testing/explicit_transformers_weight_in_config_sharded"
    expected_num_parameters = 87929

    loaded_model = BertModel.from_pretrained(repo_id)
    actual_num_parameters = loaded_model.num_parameters()

    self.assertEqual(actual_num_parameters, expected_num_parameters)
def test_explicit_transformers_weights_save_and_reload(self):
    """
    Transformers supports loading from repos where the weights file is explicitly set in the config.
    When loading a config file, transformers will see whether `transformers_weights` is defined in the config.
    If so, it will load from that file.

    When saving the model, we should be careful not to save the `transformers_weights` attribute in the config;
    otherwise, transformers will try to load from that file whereas it should simply load from the default file.

    We test that for a non-sharded repo, and check that the saved directory can be loaded again.
    """
    model = BertModel.from_pretrained("hf-internal-testing/explicit_transformers_weight_in_config")
    explicit_transformers_weights = model.config.transformers_weights

    with tempfile.TemporaryDirectory() as tmpdirname:
        model.save_pretrained(tmpdirname)

        # The config should not have a mention of transformers_weights
        with open(os.path.join(tmpdirname, "config.json"), encoding="utf-8") as f:
            config = json.load(f)
        self.assertNotIn("transformers_weights", config)

        # The serialized weights should be in model.safetensors and not the transformers_weights.
        # List the directory once so both assertions look at the same snapshot.
        saved_files = os.listdir(tmpdirname)
        self.assertNotIn(explicit_transformers_weights, saved_files)
        self.assertIn("model.safetensors", saved_files)

        # Reloading must work from the default file now that the explicit pointer is gone.
        reloaded_model = BertModel.from_pretrained(tmpdirname)
        self.assertEqual(reloaded_model.num_parameters(), model.num_parameters())
def test_explicit_transformers_weights_index_save_and_reload(self):
    """
    Transformers supports loading from repos where the weights file is explicitly set in the config.
    When loading a config file, transformers will see whether `transformers_weights` is defined in the config.
    If so, it will load from that file.

    When saving the model, we should be careful not to save the `transformers_weights` attribute in the config;
    otherwise, transformers will try to load from that file whereas it should simply load from the default file.

    We test that for a sharded repo, and check that the saved directory can be loaded again.
    """
    model = BertModel.from_pretrained("hf-internal-testing/explicit_transformers_weight_in_config_sharded")
    explicit_transformers_weights = model.config.transformers_weights

    with tempfile.TemporaryDirectory() as tmpdirname:
        # A small max_shard_size is passed so that the default index file is expected in the output.
        model.save_pretrained(tmpdirname, max_shard_size="100kb")

        # The config should not have a mention of transformers_weights
        with open(os.path.join(tmpdirname, "config.json"), encoding="utf-8") as f:
            config = json.load(f)
        self.assertNotIn("transformers_weights", config)

        # The serialized weights should be referenced by model.safetensors.index.json and not the
        # transformers_weights. List the directory once so both assertions look at the same snapshot.
        saved_files = os.listdir(tmpdirname)
        self.assertNotIn(explicit_transformers_weights, saved_files)
        self.assertIn("model.safetensors.index.json", saved_files)

        # Reloading must work from the default index now that the explicit pointer is gone.
        reloaded_model = BertModel.from_pretrained(tmpdirname)
        self.assertEqual(reloaded_model.num_parameters(), model.num_parameters())
@slow
@require_torch