diff --git a/src/transformers/file_utils.py b/src/transformers/file_utils.py index ab5790da6d..434cf454b8 100644 --- a/src/transformers/file_utils.py +++ b/src/transformers/file_utils.py @@ -2238,3 +2238,13 @@ class PushToHubMixin: commit_message = "add model" return repo.push_to_hub(commit_message=commit_message) + + +def get_full_repo_name(model_id: str, organization: Optional[str] = None, token: Optional[str] = None): + if token is None: + token = HfFolder.get_token() + if organization is None: + username = HfApi().whoami(token)["name"] + return f"{username}/{model_id}" + else: + return f"{organization}/{model_id}" diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py index f59cec7765..77471ea137 100755 --- a/src/transformers/trainer.py +++ b/src/transformers/trainer.py @@ -51,6 +51,8 @@ from torch import nn from torch.utils.data import DataLoader, Dataset, IterableDataset, RandomSampler, SequentialSampler from torch.utils.data.distributed import DistributedSampler +from huggingface_hub import Repository + from . import __version__ from .configuration_utils import PretrainedConfig from .data.data_collator import DataCollator, DataCollatorWithPadding, default_data_collator @@ -60,7 +62,7 @@ from .dependency_versions_check import dep_version_check from .file_utils import ( CONFIG_NAME, WEIGHTS_NAME, - PushToHubMixin, + get_full_repo_name, is_apex_available, is_datasets_available, is_in_notebook, @@ -2478,15 +2480,17 @@ class Trainer: """ if not self.args.should_save: return - use_auth_token = True if self.args.push_to_hub_token is None else self.args.push_to_hub_token - repo_url = PushToHubMixin._get_repo_url_from_name( - self.args.push_to_hub_model_id, - organization=self.args.push_to_hub_organization, + use_auth_token = True if self.args.hub_token is None else self.args.hub_token + if self.args.hub_model_id is None: + repo_name = get_full_repo_name(Path(self.args.output_dir).name, token=self.args.hub_token) + else: + repo_name = self.args.hub_model_id + + self.repo = Repository( + self.args.output_dir, + clone_from=repo_name, use_auth_token=use_auth_token, ) - self.repo = PushToHubMixin._create_or_get_repo( - self.args.output_dir, repo_url=repo_url, use_auth_token=use_auth_token - ) # By default, ignore the checkpoint folders if not os.path.exists(os.path.join(self.args.output_dir, ".gitignore")): @@ -2523,7 +2527,7 @@ class Trainer: def push_to_hub(self, commit_message: Optional[str] = "add model", **kwargs) -> str: """ - Upload `self.model` and `self.tokenizer` to the 🤗 model hub on the repo `self.args.push_to_hub_model_id`. + Upload `self.model` and `self.tokenizer` to the 🤗 model hub on the repo `self.args.hub_model_id`. Parameters: commit_message (:obj:`str`, `optional`, defaults to :obj:`"add model"`): @@ -2536,7 +2540,11 @@ class Trainer: """ if self.args.should_save: - self.create_model_card(model_name=self.args.push_to_hub_model_id, **kwargs) + if self.args.hub_model_id is None: + model_name = Path(self.args.output_dir).name + else: + model_name = self.args.hub_model_id.split("/")[-1] + self.create_model_card(model_name=model_name, **kwargs) # Needs to be executed on all processes for TPU training, but will only save on the processed determined by # self.args.should_save. self.save_model() diff --git a/src/transformers/training_args.py b/src/transformers/training_args.py index 5f6b877bb0..9829272978 100644 --- a/src/transformers/training_args.py +++ b/src/transformers/training_args.py @@ -25,6 +25,7 @@ from typing import Any, Dict, List, Optional from .debug_utils import DebugOption from .file_utils import ( cached_property, + get_full_repo_name, is_sagemaker_dp_enabled, is_sagemaker_mp_enabled, is_torch_available, @@ -335,12 +336,14 @@ class TrainingArguments: :class:`~transformers.Trainer`, it's intended to be used by your training/evaluation scripts instead. See the `example scripts `__ for more details. - push_to_hub_model_id (:obj:`str`, `optional`): - The name of the repository to which push the :class:`~transformers.Trainer` when :obj:`push_to_hub=True`. - Will default to the name of :obj:`output_dir`. - push_to_hub_organization (:obj:`str`, `optional`): - The name of the organization in with to which push the :class:`~transformers.Trainer`. - push_to_hub_token (:obj:`str`, `optional`): + hub_model_id (:obj:`str`, `optional`): + The name of the repository to keep in sync with the local `output_dir`. Should be the whole repository + name, for instance :obj:`"user_name/model"`, which allows you to push to an organization you are a member + of with :obj:`"organization_name/model"`. + + Will default to :obj:`user_name/output_dir_name` with `output_dir_name` being the name of + :obj:`output_dir`. + hub_token (:obj:`str`, `optional`): The token to use to push the model to the Hub. Will default to the token in the cache folder obtained with :obj:`huggingface-cli login`. """ @@ -612,6 +615,11 @@ class TrainingArguments: default=None, metadata={"help": "The path to a folder with a valid checkpoint for your model."}, ) + hub_model_id: str = field( + default=None, metadata={"help": "The name of the repository to keep in sync with the local `output_dir`."} + ) + hub_token: str = field(default=None, metadata={"help": "The token to use to push to the Model Hub."}) + # Deprecated arguments push_to_hub_model_id: str = field( default=None, metadata={"help": "The name of the repository to which push the `Trainer`."} ) @@ -761,8 +769,40 @@ class TrainingArguments: self.hf_deepspeed_config = HfTrainerDeepSpeedConfig(self.deepspeed) self.hf_deepspeed_config.trainer_config_process(self) - if self.push_to_hub_model_id is None: - self.push_to_hub_model_id = Path(self.output_dir).name + if self.push_to_hub_token is not None: + warnings.warn( + "`--push_to_hub_token` is deprecated and will be removed in version 5 of 🤗 Transformers. Use " + "`--hub_token` instead.", + FutureWarning, + ) + self.hub_token = self.push_to_hub_token + + if self.push_to_hub_model_id is not None: + self.hub_model_id = get_full_repo_name( + self.push_to_hub_model_id, organization=self.push_to_hub_organization, token=self.hub_token + ) + if self.push_to_hub_organization is not None: + warnings.warn( + "`--push_to_hub_model_id` and `--push_to_hub_organization` are deprecated and will be removed in " + "version 5 of 🤗 Transformers. Use `--hub_model_id` instead and pass the full repo name to this " + f"argument (in this case {self.hub_model_id}).", + FutureWarning, + ) + else: + warnings.warn( + "`--push_to_hub_model_id` is deprecated and will be removed in version 5 of 🤗 Transformers. Use " + "`--hub_model_id` instead and pass the full repo name to this argument (in this case " + f"{self.hub_model_id}).", + FutureWarning, + ) + elif self.push_to_hub_organization is not None: + self.hub_model_id = f"{self.push_to_hub_organization}/{Path(self.output_dir).name}" + warnings.warn( + "`--push_to_hub_organization` is deprecated and will be removed in version 5 of 🤗 Transformers. Use " + "`--hub_model_id` instead and pass the full repo name to this argument (in this case " + f"{self.hub_model_id}).", + FutureWarning, + ) def __str__(self): self_as_dict = asdict(self) diff --git a/tests/test_trainer.py b/tests/test_trainer.py index 97ad249e6f..04abf1d6ce 100644 --- a/tests/test_trainer.py +++ b/tests/test_trainer.py @@ -1299,7 +1299,7 @@ class TrainerIntegrationWithHubTester(unittest.TestCase): trainer = get_regression_trainer( output_dir=os.path.join(tmp_dir, "test-trainer"), push_to_hub=True, - push_to_hub_token=self._token, + hub_token=self._token, ) url = trainer.push_to_hub() @@ -1321,8 +1321,8 @@ class TrainerIntegrationWithHubTester(unittest.TestCase): trainer = get_regression_trainer( output_dir=os.path.join(tmp_dir, "test-trainer-org"), push_to_hub=True, - push_to_hub_organization="valid_org", - push_to_hub_token=self._token, + hub_model_id="valid_org/test-trainer-org", + hub_token=self._token, ) url = trainer.push_to_hub()