Support for transformers explicit filename (#38152)
* Support for transformers explicit filename * Tests * Rerun tests
This commit is contained in:
@@ -1958,6 +1958,80 @@ class ModelUtilsTest(TestCasePlus):
|
||||
except subprocess.CalledProcessError as e:
|
||||
raise Exception(f"The following error was captured: {e.stderr}")
|
||||
|
||||
def test_explicit_transformers_weights(self):
    """
    Check loading from a repo whose config explicitly names its weights file.

    The scenario under test: when a config is loaded, transformers looks for a `transformers_weights`
    entry in it and, if the entry is defined, loads the weights from that file.

    The load is validated through the parameter count of the resulting model.
    """
    repo_id = "hf-internal-testing/explicit_transformers_weight_in_config"
    loaded_model = BertModel.from_pretrained(repo_id)

    # Parameter count expected for the checkpoint stored in the explicitly named file.
    expected_num_parameters = 87929
    self.assertEqual(loaded_model.num_parameters(), expected_num_parameters)
|
||||
|
||||
def test_explicit_transformers_weights_index(self):
    """
    Check loading from a sharded repo whose config explicitly names its weights index.

    The scenario under test: when a config is loaded, transformers looks for a `transformers_weights`
    entry in it and, if the entry is defined, loads the weights from that file.

    Here the named file is an index over multiple weight files; the load is validated through the
    parameter count of the resulting model.
    """
    repo_id = "hf-internal-testing/explicit_transformers_weight_in_config_sharded"
    loaded_model = BertModel.from_pretrained(repo_id)

    # Parameter count expected once the shards referenced by the index are loaded.
    expected_num_parameters = 87929
    self.assertEqual(loaded_model.num_parameters(), expected_num_parameters)
|
||||
|
||||
def test_explicit_transformers_weights_save_and_reload(self):
    """
    Transformers supports loading from repos where the weights file is explicitly set in the config.
    When loading a config file, transformers will see whether `transformers_weights` is defined in the config.
    If so, it will load from that file.

    When saving the model, we should be careful not to save the `transformers_weights` attribute in the config;
    otherwise, transformers will try to load from that file whereas it should simply load from the default file.

    We test that for a non-sharded repo, and check that the saved checkpoint can be reloaded.
    """
    model = BertModel.from_pretrained("hf-internal-testing/explicit_transformers_weight_in_config")
    explicit_transformers_weights = model.config.transformers_weights

    with tempfile.TemporaryDirectory() as tmpdirname:
        model.save_pretrained(tmpdirname)

        # The config should not have a mention of transformers_weights
        with open(os.path.join(tmpdirname, "config.json"), encoding="utf-8") as f:
            config = json.load(f)
        self.assertNotIn("transformers_weights", config)

        # The serialized weights should be in model.safetensors and not the transformers_weights
        saved_files = os.listdir(tmpdirname)
        self.assertNotIn(explicit_transformers_weights, saved_files)
        self.assertIn("model.safetensors", saved_files)

        # Reload from the saved directory: with no `transformers_weights` entry in the saved config,
        # this exercises the default weights file and must yield the same model size.
        reloaded_model = BertModel.from_pretrained(tmpdirname)
        self.assertEqual(reloaded_model.num_parameters(), model.num_parameters())
|
||||
|
||||
def test_explicit_transformers_weights_index_save_and_reload(self):
    """
    Transformers supports loading from repos where the weights file is explicitly set in the config.
    When loading a config file, transformers will see whether `transformers_weights` is defined in the config.
    If so, it will load from that file.

    When saving the model, we should be careful not to save the `transformers_weights` attribute in the config;
    otherwise, transformers will try to load from that file whereas it should simply load from the default file.

    We test that for a sharded repo, and check that the saved checkpoint can be reloaded.
    """
    model = BertModel.from_pretrained("hf-internal-testing/explicit_transformers_weight_in_config_sharded")
    explicit_transformers_weights = model.config.transformers_weights

    with tempfile.TemporaryDirectory() as tmpdirname:
        # A small shard size is requested so that the save goes through the sharded (index) path.
        model.save_pretrained(tmpdirname, max_shard_size="100kb")

        # The config should not have a mention of transformers_weights
        with open(os.path.join(tmpdirname, "config.json"), encoding="utf-8") as f:
            config = json.load(f)
        self.assertNotIn("transformers_weights", config)

        # The serialized weights should be indexed by model.safetensors.index.json and not the
        # transformers_weights
        saved_files = os.listdir(tmpdirname)
        self.assertNotIn(explicit_transformers_weights, saved_files)
        self.assertIn("model.safetensors.index.json", saved_files)

        # Reload from the saved directory: with no `transformers_weights` entry in the saved config,
        # this exercises the default index file and must yield the same model size.
        reloaded_model = BertModel.from_pretrained(tmpdirname)
        self.assertEqual(reloaded_model.num_parameters(), model.num_parameters())
|
||||
|
||||
|
||||
@slow
|
||||
@require_torch
|
||||
|
||||
Reference in New Issue
Block a user