[core/ FEAT] Add the possibility to push custom tags using PreTrainedModel itself (#28405)
* v1 tags * remove unneeded conversion * v2 * rm unneeded warning * add more utility methods * Update src/transformers/utils/hub.py Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> * Update src/transformers/utils/hub.py Co-authored-by: Lucain <lucainp@gmail.com> * Update src/transformers/utils/hub.py Co-authored-by: Lucain <lucainp@gmail.com> * more enhancements * oops * merge tags * clean up * revert unneeded change * add extensive docs * more docs * more kwargs * add test * oops * fix test * Update src/transformers/modeling_utils.py Co-authored-by: Omar Sanseviero <osanseviero@gmail.com> * Update src/transformers/utils/hub.py Co-authored-by: Lucain <lucainp@gmail.com> * Update src/transformers/modeling_utils.py * Update src/transformers/trainer.py Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> * Update src/transformers/modeling_utils.py Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> * add more conditions * more logic --------- Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> Co-authored-by: Lucain <lucainp@gmail.com> Co-authored-by: Omar Sanseviero <osanseviero@gmail.com>
This commit is contained in:
@@ -89,7 +89,7 @@ from .utils import (
|
|||||||
replace_return_docstrings,
|
replace_return_docstrings,
|
||||||
strtobool,
|
strtobool,
|
||||||
)
|
)
|
||||||
from .utils.hub import convert_file_size_to_int, get_checkpoint_shard_files
|
from .utils.hub import convert_file_size_to_int, create_and_tag_model_card, get_checkpoint_shard_files
|
||||||
from .utils.import_utils import (
|
from .utils.import_utils import (
|
||||||
ENV_VARS_TRUE_VALUES,
|
ENV_VARS_TRUE_VALUES,
|
||||||
is_sagemaker_mp_enabled,
|
is_sagemaker_mp_enabled,
|
||||||
@@ -1172,6 +1172,8 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
|||||||
config_class = None
|
config_class = None
|
||||||
base_model_prefix = ""
|
base_model_prefix = ""
|
||||||
main_input_name = "input_ids"
|
main_input_name = "input_ids"
|
||||||
|
model_tags = None
|
||||||
|
|
||||||
_auto_class = None
|
_auto_class = None
|
||||||
_no_split_modules = None
|
_no_split_modules = None
|
||||||
_skip_keys_device_placement = None
|
_skip_keys_device_placement = None
|
||||||
@@ -1252,6 +1254,38 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
|||||||
# Remove the attribute now that it has been consumed, so it's not saved in the config.
|
# Remove the attribute now that it has been consumed, so it's not saved in the config.
|
||||||
delattr(self.config, "gradient_checkpointing")
|
delattr(self.config, "gradient_checkpointing")
|
||||||
|
|
||||||
|
def add_model_tags(self, tags: Union[List[str], str]) -> None:
    r"""
    Add custom tags into the model that gets pushed to the Hugging Face Hub. Will
    not overwrite existing tags in the model.

    Args:
        tags (`Union[List[str], str]`):
            The desired tags to inject in the model. A single string is treated as one tag.
            Tags already present on the model are skipped, so no duplicates are created.

    Examples:

    ```python
    from transformers import AutoModel

    model = AutoModel.from_pretrained("bert-base-cased")

    model.add_model_tags(["custom", "custom-bert"])

    # Push the model to your namespace with the name "my-custom-bert".
    model.push_to_hub("my-custom-bert")
    ```
    """
    if isinstance(tags, str):
        tags = [tags]

    # Build a fresh list rather than appending in place: `model_tags` may still be a list
    # declared at class level, and mutating it would leak the new tags to every other
    # instance sharing that list.
    merged_tags = [] if self.model_tags is None else list(self.model_tags)

    for tag in tags:
        if tag not in merged_tags:
            merged_tags.append(tag)

    self.model_tags = merged_tags
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def _from_config(cls, config, **kwargs):
|
def _from_config(cls, config, **kwargs):
|
||||||
"""
|
"""
|
||||||
@@ -2212,6 +2246,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
|||||||
Additional key word arguments passed along to the [`~utils.PushToHubMixin.push_to_hub`] method.
|
Additional key word arguments passed along to the [`~utils.PushToHubMixin.push_to_hub`] method.
|
||||||
"""
|
"""
|
||||||
use_auth_token = kwargs.pop("use_auth_token", None)
|
use_auth_token = kwargs.pop("use_auth_token", None)
|
||||||
|
ignore_metadata_errors = kwargs.pop("ignore_metadata_errors", False)
|
||||||
|
|
||||||
if use_auth_token is not None:
|
if use_auth_token is not None:
|
||||||
warnings.warn(
|
warnings.warn(
|
||||||
@@ -2438,6 +2473,14 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
|||||||
)
|
)
|
||||||
|
|
||||||
if push_to_hub:
|
if push_to_hub:
|
||||||
|
# Eventually create an empty model card
|
||||||
|
model_card = create_and_tag_model_card(
|
||||||
|
repo_id, self.model_tags, token=token, ignore_metadata_errors=ignore_metadata_errors
|
||||||
|
)
|
||||||
|
|
||||||
|
# Update model card if needed:
|
||||||
|
model_card.save(os.path.join(save_directory, "README.md"))
|
||||||
|
|
||||||
self._upload_modified_files(
|
self._upload_modified_files(
|
||||||
save_directory,
|
save_directory,
|
||||||
repo_id,
|
repo_id,
|
||||||
@@ -2446,6 +2489,22 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
|||||||
token=token,
|
token=token,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@wraps(PushToHubMixin.push_to_hub)
def push_to_hub(self, *args, **kwargs):
    # NOTE: no docstring here on purpose -- `wraps` replaces it with the one from
    # `PushToHubMixin.push_to_hub`.
    #
    # Merge the tags stored on the model (`self.model_tags`) with the ones passed through
    # the `tags` keyword argument, then delegate to the parent implementation.

    # Copy the model tags so that tags given for this single push are not permanently
    # appended to `self.model_tags`.
    tags = list(self.model_tags) if self.model_tags is not None else []

    # `tags=None` may be passed explicitly; treat it the same as "no extra tags".
    tags_kwargs = kwargs.get("tags")
    if tags_kwargs is None:
        tags_kwargs = []
    elif isinstance(tags_kwargs, str):
        tags_kwargs = [tags_kwargs]

    for tag in tags_kwargs:
        if tag not in tags:
            tags.append(tag)

    # Only override the keyword argument when there is something to push.
    if tags:
        kwargs["tags"] = tags
    return super().push_to_hub(*args, **kwargs)
|
||||||
|
|
||||||
def get_memory_footprint(self, return_buffers=True):
|
def get_memory_footprint(self, return_buffers=True):
|
||||||
r"""
|
r"""
|
||||||
Get the memory footprint of a model. This will return the memory footprint of the current model in bytes.
|
Get the memory footprint of a model. This will return the memory footprint of the current model in bytes.
|
||||||
|
|||||||
@@ -3581,6 +3581,15 @@ class Trainer:
|
|||||||
library_name = ModelCard.load(model_card_filepath).data.get("library_name")
|
library_name = ModelCard.load(model_card_filepath).data.get("library_name")
|
||||||
is_peft_library = library_name == "peft"
|
is_peft_library = library_name == "peft"
|
||||||
|
|
||||||
|
# Append existing tags in `tags`
|
||||||
|
existing_tags = ModelCard.load(model_card_filepath).data.tags
|
||||||
|
if tags is not None and existing_tags is not None:
|
||||||
|
if isinstance(tags, str):
|
||||||
|
tags = [tags]
|
||||||
|
for tag in existing_tags:
|
||||||
|
if tag not in tags:
|
||||||
|
tags.append(tag)
|
||||||
|
|
||||||
training_summary = TrainingSummary.from_trainer(
|
training_summary = TrainingSummary.from_trainer(
|
||||||
self,
|
self,
|
||||||
language=language,
|
language=language,
|
||||||
@@ -3699,6 +3708,18 @@ class Trainer:
|
|||||||
if not self.is_world_process_zero():
|
if not self.is_world_process_zero():
|
||||||
return
|
return
|
||||||
|
|
||||||
|
# Add additional tags in the case the model has already some tags and users pass
|
||||||
|
# "tags" argument to `push_to_hub` so that trainer automatically handles internal tags
|
||||||
|
# from all models since Trainer does not call `model.push_to_hub`.
|
||||||
|
if "tags" in kwargs and getattr(self.model, "model_tags", None) is not None:
|
||||||
|
# If it is a string, convert it to a list
|
||||||
|
if isinstance(kwargs["tags"], str):
|
||||||
|
kwargs["tags"] = [kwargs["tags"]]
|
||||||
|
|
||||||
|
for model_tag in self.model.model_tags:
|
||||||
|
if model_tag not in kwargs["tags"]:
|
||||||
|
kwargs["tags"].append(model_tag)
|
||||||
|
|
||||||
self.create_model_card(model_name=model_name, **kwargs)
|
self.create_model_card(model_name=model_name, **kwargs)
|
||||||
|
|
||||||
# Wait for the current upload to be finished.
|
# Wait for the current upload to be finished.
|
||||||
|
|||||||
@@ -33,6 +33,8 @@ import requests
|
|||||||
from huggingface_hub import (
|
from huggingface_hub import (
|
||||||
_CACHED_NO_EXIST,
|
_CACHED_NO_EXIST,
|
||||||
CommitOperationAdd,
|
CommitOperationAdd,
|
||||||
|
ModelCard,
|
||||||
|
ModelCardData,
|
||||||
constants,
|
constants,
|
||||||
create_branch,
|
create_branch,
|
||||||
create_commit,
|
create_commit,
|
||||||
@@ -762,6 +764,7 @@ class PushToHubMixin:
|
|||||||
safe_serialization: bool = True,
|
safe_serialization: bool = True,
|
||||||
revision: str = None,
|
revision: str = None,
|
||||||
commit_description: str = None,
|
commit_description: str = None,
|
||||||
|
tags: Optional[List[str]] = None,
|
||||||
**deprecated_kwargs,
|
**deprecated_kwargs,
|
||||||
) -> str:
|
) -> str:
|
||||||
"""
|
"""
|
||||||
@@ -795,6 +798,8 @@ class PushToHubMixin:
|
|||||||
Branch to push the uploaded files to.
|
Branch to push the uploaded files to.
|
||||||
commit_description (`str`, *optional*):
|
commit_description (`str`, *optional*):
|
||||||
The description of the commit that will be created
|
The description of the commit that will be created
|
||||||
|
tags (`List[str]`, *optional*):
|
||||||
|
List of tags to push on the Hub.
|
||||||
|
|
||||||
Examples:
|
Examples:
|
||||||
|
|
||||||
@@ -811,6 +816,7 @@ class PushToHubMixin:
|
|||||||
```
|
```
|
||||||
"""
|
"""
|
||||||
use_auth_token = deprecated_kwargs.pop("use_auth_token", None)
|
use_auth_token = deprecated_kwargs.pop("use_auth_token", None)
|
||||||
|
ignore_metadata_errors = deprecated_kwargs.pop("ignore_metadata_errors", False)
|
||||||
if use_auth_token is not None:
|
if use_auth_token is not None:
|
||||||
warnings.warn(
|
warnings.warn(
|
||||||
"The `use_auth_token` argument is deprecated and will be removed in v5 of Transformers. Please use `token` instead.",
|
"The `use_auth_token` argument is deprecated and will be removed in v5 of Transformers. Please use `token` instead.",
|
||||||
@@ -855,6 +861,11 @@ class PushToHubMixin:
|
|||||||
repo_id, private=private, token=token, repo_url=repo_url, organization=organization
|
repo_id, private=private, token=token, repo_url=repo_url, organization=organization
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Create a new empty model card and eventually tag it
|
||||||
|
model_card = create_and_tag_model_card(
|
||||||
|
repo_id, tags, token=token, ignore_metadata_errors=ignore_metadata_errors
|
||||||
|
)
|
||||||
|
|
||||||
if use_temp_dir is None:
|
if use_temp_dir is None:
|
||||||
use_temp_dir = not os.path.isdir(working_dir)
|
use_temp_dir = not os.path.isdir(working_dir)
|
||||||
|
|
||||||
@@ -864,6 +875,9 @@ class PushToHubMixin:
|
|||||||
# Save all files.
|
# Save all files.
|
||||||
self.save_pretrained(work_dir, max_shard_size=max_shard_size, safe_serialization=safe_serialization)
|
self.save_pretrained(work_dir, max_shard_size=max_shard_size, safe_serialization=safe_serialization)
|
||||||
|
|
||||||
|
# Update model card if needed:
|
||||||
|
model_card.save(os.path.join(work_dir, "README.md"))
|
||||||
|
|
||||||
return self._upload_modified_files(
|
return self._upload_modified_files(
|
||||||
work_dir,
|
work_dir,
|
||||||
repo_id,
|
repo_id,
|
||||||
@@ -1081,6 +1095,43 @@ def extract_info_from_url(url):
|
|||||||
return {"repo": cache_repo, "revision": revision, "filename": filename}
|
return {"repo": cache_repo, "revision": revision, "filename": filename}
|
||||||
|
|
||||||
|
|
||||||
|
def create_and_tag_model_card(
    repo_id: str,
    tags: Optional[List[str]] = None,
    token: Optional[str] = None,
    ignore_metadata_errors: bool = False,
):
    """
    Creates or loads an existing model card and tags it.

    Args:
        repo_id (`str`):
            The repo_id where to look for the model card.
        tags (`List[str]`, *optional*):
            The list of tags to add in the model card
        token (`str`, *optional*):
            Authentication token, obtained with `huggingface_hub.HfApi.login` method. Will default to the stored token.
        ignore_metadata_errors (`bool`, *optional*, defaults to `False`):
            If True, errors while parsing the metadata section will be ignored. Some information might be lost during
            the process. Use it at your own risk.

    Returns:
        The model card loaded from `repo_id`, or a newly created one when the repo has no card yet, with every
        entry of `tags` present in its `data.tags`.
    """
    try:
        # Check if the model card is present on the remote repo
        model_card = ModelCard.load(repo_id, token=token, ignore_metadata_errors=ignore_metadata_errors)
    except EntryNotFoundError:
        # Otherwise create a simple model card from template
        model_description = "This is the model card of a 🤗 transformers model that has been pushed on the Hub. This model card has been automatically generated."
        card_data = ModelCardData(tags=[] if tags is None else tags, library_name="transformers")
        model_card = ModelCard.from_template(card_data, model_description=model_description)

    if tags is not None:
        # An existing card may carry no `tags` metadata at all, in which case `data.tags` can be
        # `None`; normalise it to a list so the membership test and `append` below do not fail.
        if model_card.data.tags is None:
            model_card.data.tags = []

        for model_tag in tags:
            if model_tag not in model_card.data.tags:
                model_card.data.tags.append(model_tag)

    return model_card
|
||||||
|
|
||||||
|
|
||||||
def clean_files_for(file):
|
def clean_files_for(file):
|
||||||
"""
|
"""
|
||||||
Remove, if they exist, file, file.json and file.lock
|
Remove, if they exist, file, file.json and file.lock
|
||||||
|
|||||||
@@ -1435,6 +1435,11 @@ class ModelPushToHubTester(unittest.TestCase):
|
|||||||
except HTTPError:
|
except HTTPError:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
try:
|
||||||
|
delete_repo(token=cls._token, repo_id="test-dynamic-model-with-tags")
|
||||||
|
except HTTPError:
|
||||||
|
pass
|
||||||
|
|
||||||
@unittest.skip("This test is flaky")
|
@unittest.skip("This test is flaky")
|
||||||
def test_push_to_hub(self):
|
def test_push_to_hub(self):
|
||||||
config = BertConfig(
|
config = BertConfig(
|
||||||
@@ -1522,6 +1527,28 @@ The commit description supports markdown synthax see:
|
|||||||
new_model = AutoModel.from_config(config, trust_remote_code=True)
|
new_model = AutoModel.from_config(config, trust_remote_code=True)
|
||||||
self.assertEqual(new_model.__class__.__name__, "CustomModel")
|
self.assertEqual(new_model.__class__.__name__, "CustomModel")
|
||||||
|
|
||||||
|
def test_push_to_hub_with_tags(self):
    """Tags added with `add_model_tags` must end up in the model card pushed to the Hub."""
    from huggingface_hub import ModelCard

    expected_tags = ["tag-1", "tag-2"]

    CustomConfig.register_for_auto_class()
    CustomModel.register_for_auto_class()

    model = CustomModel(CustomConfig(hidden_size=32))

    # A freshly built model carries no tags.
    self.assertIsNone(model.model_tags)

    model.add_model_tags(expected_tags)
    self.assertEqual(model.model_tags, expected_tags)

    model.push_to_hub("test-dynamic-model-with-tags", token=self._token)

    # The card stored on the Hub must expose exactly the tags set on the model.
    pushed_card = ModelCard.load(f"{USER}/test-dynamic-model-with-tags")
    self.assertEqual(pushed_card.data.tags, expected_tags)
|
||||||
|
|
||||||
|
|
||||||
@require_torch
|
@require_torch
|
||||||
class AttentionMaskTester(unittest.TestCase):
|
class AttentionMaskTester(unittest.TestCase):
|
||||||
|
|||||||
Reference in New Issue
Block a user