Refactor internals for Trainer push_to_hub (#13486)

This commit is contained in:
Sylvain Gugger
2021-09-09 13:04:37 -04:00
committed by GitHub
parent 3dd538c4d3
commit e59d4d0147
4 changed files with 79 additions and 21 deletions

View File

@@ -2238,3 +2238,13 @@ class PushToHubMixin:
commit_message = "add model" commit_message = "add model"
return repo.push_to_hub(commit_message=commit_message) return repo.push_to_hub(commit_message=commit_message)
def get_full_repo_name(model_id: str, organization: Optional[str] = None, token: Optional[str] = None) -> str:
    """
    Build the full ``namespace/model_id`` repository name.

    Args:
        model_id: Name of the repository, without any namespace prefix.
        organization: Namespace to prefix ``model_id`` with. When ``None``, the ``"name"`` field returned
            by ``HfApi().whoami(token)`` is used instead.
        token: Token passed to ``HfApi().whoami``. Only consulted when ``organization`` is ``None``; when
            it is ``None`` as well, the value of ``HfFolder.get_token()`` is used.

    Returns:
        ``"{organization}/{model_id}"`` when an organization is given, ``"{username}/{model_id}"`` otherwise.
    """
    if organization is not None:
        # The namespace is explicit, so neither the token lookup nor the `whoami` call is needed.
        return f"{organization}/{model_id}"

    # Only resolve the token on the path that actually uses it.
    if token is None:
        token = HfFolder.get_token()
    username = HfApi().whoami(token)["name"]
    return f"{username}/{model_id}"

View File

@@ -51,6 +51,8 @@ from torch import nn
from torch.utils.data import DataLoader, Dataset, IterableDataset, RandomSampler, SequentialSampler from torch.utils.data import DataLoader, Dataset, IterableDataset, RandomSampler, SequentialSampler
from torch.utils.data.distributed import DistributedSampler from torch.utils.data.distributed import DistributedSampler
from huggingface_hub import Repository
from . import __version__ from . import __version__
from .configuration_utils import PretrainedConfig from .configuration_utils import PretrainedConfig
from .data.data_collator import DataCollator, DataCollatorWithPadding, default_data_collator from .data.data_collator import DataCollator, DataCollatorWithPadding, default_data_collator
@@ -60,7 +62,7 @@ from .dependency_versions_check import dep_version_check
from .file_utils import ( from .file_utils import (
CONFIG_NAME, CONFIG_NAME,
WEIGHTS_NAME, WEIGHTS_NAME,
PushToHubMixin, get_full_repo_name,
is_apex_available, is_apex_available,
is_datasets_available, is_datasets_available,
is_in_notebook, is_in_notebook,
@@ -2478,15 +2480,17 @@ class Trainer:
""" """
if not self.args.should_save: if not self.args.should_save:
return return
use_auth_token = True if self.args.push_to_hub_token is None else self.args.push_to_hub_token use_auth_token = True if self.args.hub_token is None else self.args.hub_token
repo_url = PushToHubMixin._get_repo_url_from_name( if self.args.hub_model_id is None:
self.args.push_to_hub_model_id, repo_name = get_full_repo_name(Path(self.args.output_dir).name, token=self.args.hub_token)
organization=self.args.push_to_hub_organization, else:
repo_name = self.args.hub_model_id
self.repo = Repository(
self.args.output_dir,
clone_from=repo_name,
use_auth_token=use_auth_token, use_auth_token=use_auth_token,
) )
self.repo = PushToHubMixin._create_or_get_repo(
self.args.output_dir, repo_url=repo_url, use_auth_token=use_auth_token
)
# By default, ignore the checkpoint folders # By default, ignore the checkpoint folders
if not os.path.exists(os.path.join(self.args.output_dir, ".gitignore")): if not os.path.exists(os.path.join(self.args.output_dir, ".gitignore")):
@@ -2523,7 +2527,7 @@ class Trainer:
def push_to_hub(self, commit_message: Optional[str] = "add model", **kwargs) -> str: def push_to_hub(self, commit_message: Optional[str] = "add model", **kwargs) -> str:
""" """
Upload `self.model` and `self.tokenizer` to the 🤗 model hub on the repo `self.args.push_to_hub_model_id`. Upload `self.model` and `self.tokenizer` to the 🤗 model hub on the repo `self.args.hub_model_id`.
Parameters: Parameters:
commit_message (:obj:`str`, `optional`, defaults to :obj:`"add model"`): commit_message (:obj:`str`, `optional`, defaults to :obj:`"add model"`):
@@ -2536,7 +2540,11 @@ class Trainer:
""" """
if self.args.should_save: if self.args.should_save:
self.create_model_card(model_name=self.args.push_to_hub_model_id, **kwargs) if self.args.hub_model_id is None:
model_name = Path(self.args.output_dir).name
else:
model_name = self.args.hub_model_id.split("/")[-1]
self.create_model_card(model_name=model_name, **kwargs)
# Needs to be executed on all processes for TPU training, but will only save on the process determined by # Needs to be executed on all processes for TPU training, but will only save on the process determined by
# self.args.should_save. # self.args.should_save.
self.save_model() self.save_model()

View File

@@ -25,6 +25,7 @@ from typing import Any, Dict, List, Optional
from .debug_utils import DebugOption from .debug_utils import DebugOption
from .file_utils import ( from .file_utils import (
cached_property, cached_property,
get_full_repo_name,
is_sagemaker_dp_enabled, is_sagemaker_dp_enabled,
is_sagemaker_mp_enabled, is_sagemaker_mp_enabled,
is_torch_available, is_torch_available,
@@ -335,12 +336,14 @@ class TrainingArguments:
:class:`~transformers.Trainer`, it's intended to be used by your training/evaluation scripts instead. See :class:`~transformers.Trainer`, it's intended to be used by your training/evaluation scripts instead. See
the `example scripts <https://github.com/huggingface/transformers/tree/master/examples>`__ for more the `example scripts <https://github.com/huggingface/transformers/tree/master/examples>`__ for more
details. details.
push_to_hub_model_id (:obj:`str`, `optional`): hub_model_id (:obj:`str`, `optional`):
The name of the repository to which push the :class:`~transformers.Trainer` when :obj:`push_to_hub=True`. The name of the repository to keep in sync with the local `output_dir`. Should be the whole repository
Will default to the name of :obj:`output_dir`. name, for instance :obj:`"user_name/model"`, which allows you to push to an organization you are a member
push_to_hub_organization (:obj:`str`, `optional`): of with :obj:`"organization_name/model"`.
The name of the organization in with to which push the :class:`~transformers.Trainer`.
push_to_hub_token (:obj:`str`, `optional`): Will default to :obj:`user_name/output_dir_name` with `output_dir_name` being the name of
:obj:`output_dir`.
hub_token (:obj:`str`, `optional`):
The token to use to push the model to the Hub. Will default to the token in the cache folder obtained with The token to use to push the model to the Hub. Will default to the token in the cache folder obtained with
:obj:`huggingface-cli login`. :obj:`huggingface-cli login`.
""" """
@@ -612,6 +615,11 @@ class TrainingArguments:
default=None, default=None,
metadata={"help": "The path to a folder with a valid checkpoint for your model."}, metadata={"help": "The path to a folder with a valid checkpoint for your model."},
) )
hub_model_id: str = field(
default=None, metadata={"help": "The name of the repository to keep in sync with the local `output_dir`."}
)
hub_token: str = field(default=None, metadata={"help": "The token to use to push to the Model Hub."})
# Deprecated arguments
push_to_hub_model_id: str = field( push_to_hub_model_id: str = field(
default=None, metadata={"help": "The name of the repository to which push the `Trainer`."} default=None, metadata={"help": "The name of the repository to which push the `Trainer`."}
) )
@@ -761,8 +769,40 @@ class TrainingArguments:
self.hf_deepspeed_config = HfTrainerDeepSpeedConfig(self.deepspeed) self.hf_deepspeed_config = HfTrainerDeepSpeedConfig(self.deepspeed)
self.hf_deepspeed_config.trainer_config_process(self) self.hf_deepspeed_config.trainer_config_process(self)
if self.push_to_hub_model_id is None: if self.push_to_hub_token is not None:
self.push_to_hub_model_id = Path(self.output_dir).name warnings.warn(
"`--push_to_hub_token` is deprecated and will be removed in version 5 of 🤗 Transformers. Use "
"`--hub_token` instead.",
FutureWarning,
)
self.hub_token = self.push_to_hub_token
if self.push_to_hub_model_id is not None:
self.hub_model_id = get_full_repo_name(
self.push_to_hub_model_id, organization=self.push_to_hub_organization, token=self.hub_token
)
if self.push_to_hub_organization is not None:
warnings.warn(
"`--push_to_hub_model_id` and `--push_to_hub_organization` are deprecated and will be removed in "
"version 5 of 🤗 Transformers. Use `--hub_model_id` instead and pass the full repo name to this "
f"argument (in this case {self.hub_model_id}).",
FutureWarning,
)
else:
warnings.warn(
"`--push_to_hub_model_id` is deprecated and will be removed in version 5 of 🤗 Transformers. Use "
"`--hub_model_id` instead and pass the full repo name to this argument (in this case "
f"{self.hub_model_id}).",
FutureWarning,
)
elif self.push_to_hub_organization is not None:
self.hub_model_id = f"{self.push_to_hub_organization}/{Path(self.output_dir).name}"
warnings.warn(
"`--push_to_hub_organization` is deprecated and will be removed in version 5 of 🤗 Transformers. Use "
"`--hub_model_id` instead and pass the full repo name to this argument (in this case "
f"{self.hub_model_id}).",
FutureWarning,
)
def __str__(self): def __str__(self):
self_as_dict = asdict(self) self_as_dict = asdict(self)

View File

@@ -1299,7 +1299,7 @@ class TrainerIntegrationWithHubTester(unittest.TestCase):
trainer = get_regression_trainer( trainer = get_regression_trainer(
output_dir=os.path.join(tmp_dir, "test-trainer"), output_dir=os.path.join(tmp_dir, "test-trainer"),
push_to_hub=True, push_to_hub=True,
push_to_hub_token=self._token, hub_token=self._token,
) )
url = trainer.push_to_hub() url = trainer.push_to_hub()
@@ -1321,8 +1321,8 @@ class TrainerIntegrationWithHubTester(unittest.TestCase):
trainer = get_regression_trainer( trainer = get_regression_trainer(
output_dir=os.path.join(tmp_dir, "test-trainer-org"), output_dir=os.path.join(tmp_dir, "test-trainer-org"),
push_to_hub=True, push_to_hub=True,
push_to_hub_organization="valid_org", hub_model_id="valid_org/test-trainer-org",
push_to_hub_token=self._token, hub_token=self._token,
) )
url = trainer.push_to_hub() url = trainer.push_to_hub()