[core/ FEAT] Add the possibility to push custom tags using PreTrainedModel itself (#28405)
* v1 tags * remove unneeded conversion * v2 * rm unneeded warning * add more utility methods * Update src/transformers/utils/hub.py Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> * Update src/transformers/utils/hub.py Co-authored-by: Lucain <lucainp@gmail.com> * Update src/transformers/utils/hub.py Co-authored-by: Lucain <lucainp@gmail.com> * more enhancements * oops * merge tags * clean up * revert unneeded change * add extensive docs * more docs * more kwargs * add test * oops * fix test * Update src/transformers/modeling_utils.py Co-authored-by: Omar Sanseviero <osanseviero@gmail.com> * Update src/transformers/utils/hub.py Co-authored-by: Lucain <lucainp@gmail.com> * Update src/transformers/modeling_utils.py * Update src/transformers/trainer.py Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> * Update src/transformers/modeling_utils.py Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> * add more conditions * more logic --------- Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> Co-authored-by: Lucain <lucainp@gmail.com> Co-authored-by: Omar Sanseviero <osanseviero@gmail.com>
This commit is contained in:
@@ -1435,6 +1435,11 @@ class ModelPushToHubTester(unittest.TestCase):
|
||||
except HTTPError:
|
||||
pass
|
||||
|
||||
try:
|
||||
delete_repo(token=cls._token, repo_id="test-dynamic-model-with-tags")
|
||||
except HTTPError:
|
||||
pass
|
||||
|
||||
@unittest.skip("This test is flaky")
|
||||
def test_push_to_hub(self):
|
||||
config = BertConfig(
|
||||
@@ -1522,6 +1527,28 @@ The commit description supports markdown synthax see:
|
||||
new_model = AutoModel.from_config(config, trust_remote_code=True)
|
||||
self.assertEqual(new_model.__class__.__name__, "CustomModel")
|
||||
|
||||
def test_push_to_hub_with_tags(self):
|
||||
from huggingface_hub import ModelCard
|
||||
|
||||
new_tags = ["tag-1", "tag-2"]
|
||||
|
||||
CustomConfig.register_for_auto_class()
|
||||
CustomModel.register_for_auto_class()
|
||||
|
||||
config = CustomConfig(hidden_size=32)
|
||||
model = CustomModel(config)
|
||||
|
||||
self.assertTrue(model.model_tags is None)
|
||||
|
||||
model.add_model_tags(new_tags)
|
||||
|
||||
self.assertTrue(model.model_tags == new_tags)
|
||||
|
||||
model.push_to_hub("test-dynamic-model-with-tags", token=self._token)
|
||||
|
||||
loaded_model_card = ModelCard.load(f"{USER}/test-dynamic-model-with-tags")
|
||||
self.assertEqual(loaded_model_card.data.tags, new_tags)
|
||||
|
||||
|
||||
@require_torch
|
||||
class AttentionMaskTester(unittest.TestCase):
|
||||
|
||||
Reference in New Issue
Block a user