Refactor internals for Trainer push_to_hub (#13486)
This commit is contained in:
@@ -2238,3 +2238,13 @@ class PushToHubMixin:
|
|||||||
commit_message = "add model"
|
commit_message = "add model"
|
||||||
|
|
||||||
return repo.push_to_hub(commit_message=commit_message)
|
return repo.push_to_hub(commit_message=commit_message)
|
||||||
|
|
||||||
|
|
||||||
|
def get_full_repo_name(model_id: str, organization: Optional[str] = None, token: Optional[str] = None):
|
||||||
|
if token is None:
|
||||||
|
token = HfFolder.get_token()
|
||||||
|
if organization is None:
|
||||||
|
username = HfApi().whoami(token)["name"]
|
||||||
|
return f"{username}/{model_id}"
|
||||||
|
else:
|
||||||
|
return f"{organization}/{model_id}"
|
||||||
|
|||||||
@@ -51,6 +51,8 @@ from torch import nn
|
|||||||
from torch.utils.data import DataLoader, Dataset, IterableDataset, RandomSampler, SequentialSampler
|
from torch.utils.data import DataLoader, Dataset, IterableDataset, RandomSampler, SequentialSampler
|
||||||
from torch.utils.data.distributed import DistributedSampler
|
from torch.utils.data.distributed import DistributedSampler
|
||||||
|
|
||||||
|
from huggingface_hub import Repository
|
||||||
|
|
||||||
from . import __version__
|
from . import __version__
|
||||||
from .configuration_utils import PretrainedConfig
|
from .configuration_utils import PretrainedConfig
|
||||||
from .data.data_collator import DataCollator, DataCollatorWithPadding, default_data_collator
|
from .data.data_collator import DataCollator, DataCollatorWithPadding, default_data_collator
|
||||||
@@ -60,7 +62,7 @@ from .dependency_versions_check import dep_version_check
|
|||||||
from .file_utils import (
|
from .file_utils import (
|
||||||
CONFIG_NAME,
|
CONFIG_NAME,
|
||||||
WEIGHTS_NAME,
|
WEIGHTS_NAME,
|
||||||
PushToHubMixin,
|
get_full_repo_name,
|
||||||
is_apex_available,
|
is_apex_available,
|
||||||
is_datasets_available,
|
is_datasets_available,
|
||||||
is_in_notebook,
|
is_in_notebook,
|
||||||
@@ -2478,15 +2480,17 @@ class Trainer:
|
|||||||
"""
|
"""
|
||||||
if not self.args.should_save:
|
if not self.args.should_save:
|
||||||
return
|
return
|
||||||
use_auth_token = True if self.args.push_to_hub_token is None else self.args.push_to_hub_token
|
use_auth_token = True if self.args.hub_token is None else self.args.hub_token
|
||||||
repo_url = PushToHubMixin._get_repo_url_from_name(
|
if self.args.hub_model_id is None:
|
||||||
self.args.push_to_hub_model_id,
|
repo_name = get_full_repo_name(Path(self.args.output_dir).name, token=self.args.hub_token)
|
||||||
organization=self.args.push_to_hub_organization,
|
else:
|
||||||
|
repo_name = self.args.hub_model_id
|
||||||
|
|
||||||
|
self.repo = Repository(
|
||||||
|
self.args.output_dir,
|
||||||
|
clone_from=repo_name,
|
||||||
use_auth_token=use_auth_token,
|
use_auth_token=use_auth_token,
|
||||||
)
|
)
|
||||||
self.repo = PushToHubMixin._create_or_get_repo(
|
|
||||||
self.args.output_dir, repo_url=repo_url, use_auth_token=use_auth_token
|
|
||||||
)
|
|
||||||
|
|
||||||
# By default, ignore the checkpoint folders
|
# By default, ignore the checkpoint folders
|
||||||
if not os.path.exists(os.path.join(self.args.output_dir, ".gitignore")):
|
if not os.path.exists(os.path.join(self.args.output_dir, ".gitignore")):
|
||||||
@@ -2523,7 +2527,7 @@ class Trainer:
|
|||||||
|
|
||||||
def push_to_hub(self, commit_message: Optional[str] = "add model", **kwargs) -> str:
|
def push_to_hub(self, commit_message: Optional[str] = "add model", **kwargs) -> str:
|
||||||
"""
|
"""
|
||||||
Upload `self.model` and `self.tokenizer` to the 🤗 model hub on the repo `self.args.push_to_hub_model_id`.
|
Upload `self.model` and `self.tokenizer` to the 🤗 model hub on the repo `self.args.hub_model_id`.
|
||||||
|
|
||||||
Parameters:
|
Parameters:
|
||||||
commit_message (:obj:`str`, `optional`, defaults to :obj:`"add model"`):
|
commit_message (:obj:`str`, `optional`, defaults to :obj:`"add model"`):
|
||||||
@@ -2536,7 +2540,11 @@ class Trainer:
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
if self.args.should_save:
|
if self.args.should_save:
|
||||||
self.create_model_card(model_name=self.args.push_to_hub_model_id, **kwargs)
|
if self.args.hub_model_id is None:
|
||||||
|
model_name = Path(self.args.output_dir).name
|
||||||
|
else:
|
||||||
|
model_name = self.args.hub_model_id.split("/")[-1]
|
||||||
|
self.create_model_card(model_name=model_name, **kwargs)
|
||||||
# Needs to be executed on all processes for TPU training, but will only save on the processed determined by
|
# Needs to be executed on all processes for TPU training, but will only save on the processed determined by
|
||||||
# self.args.should_save.
|
# self.args.should_save.
|
||||||
self.save_model()
|
self.save_model()
|
||||||
|
|||||||
@@ -25,6 +25,7 @@ from typing import Any, Dict, List, Optional
|
|||||||
from .debug_utils import DebugOption
|
from .debug_utils import DebugOption
|
||||||
from .file_utils import (
|
from .file_utils import (
|
||||||
cached_property,
|
cached_property,
|
||||||
|
get_full_repo_name,
|
||||||
is_sagemaker_dp_enabled,
|
is_sagemaker_dp_enabled,
|
||||||
is_sagemaker_mp_enabled,
|
is_sagemaker_mp_enabled,
|
||||||
is_torch_available,
|
is_torch_available,
|
||||||
@@ -335,12 +336,14 @@ class TrainingArguments:
|
|||||||
:class:`~transformers.Trainer`, it's intended to be used by your training/evaluation scripts instead. See
|
:class:`~transformers.Trainer`, it's intended to be used by your training/evaluation scripts instead. See
|
||||||
the `example scripts <https://github.com/huggingface/transformers/tree/master/examples>`__ for more
|
the `example scripts <https://github.com/huggingface/transformers/tree/master/examples>`__ for more
|
||||||
details.
|
details.
|
||||||
push_to_hub_model_id (:obj:`str`, `optional`):
|
hub_model_id (:obj:`str`, `optional`):
|
||||||
The name of the repository to which push the :class:`~transformers.Trainer` when :obj:`push_to_hub=True`.
|
The name of the repository to keep in sync with the local `output_dir`. Should be the whole repository
|
||||||
Will default to the name of :obj:`output_dir`.
|
name, for instance :obj:`"user_name/model"`, which allows you to push to an organization you are a member
|
||||||
push_to_hub_organization (:obj:`str`, `optional`):
|
of with :obj:`"organization_name/model"`.
|
||||||
The name of the organization in with to which push the :class:`~transformers.Trainer`.
|
|
||||||
push_to_hub_token (:obj:`str`, `optional`):
|
Will default to :obj:`user_name/output_dir_name` with `output_dir_name` being the name of
|
||||||
|
:obj:`output_dir`.
|
||||||
|
hub_token (:obj:`str`, `optional`):
|
||||||
The token to use to push the model to the Hub. Will default to the token in the cache folder obtained with
|
The token to use to push the model to the Hub. Will default to the token in the cache folder obtained with
|
||||||
:obj:`huggingface-cli login`.
|
:obj:`huggingface-cli login`.
|
||||||
"""
|
"""
|
||||||
@@ -612,6 +615,11 @@ class TrainingArguments:
|
|||||||
default=None,
|
default=None,
|
||||||
metadata={"help": "The path to a folder with a valid checkpoint for your model."},
|
metadata={"help": "The path to a folder with a valid checkpoint for your model."},
|
||||||
)
|
)
|
||||||
|
hub_model_id: str = field(
|
||||||
|
default=None, metadata={"help": "The name of the repository to keep in sync with the local `output_dir`."}
|
||||||
|
)
|
||||||
|
hub_token: str = field(default=None, metadata={"help": "The token to use to push to the Model Hub."})
|
||||||
|
# Deprecated arguments
|
||||||
push_to_hub_model_id: str = field(
|
push_to_hub_model_id: str = field(
|
||||||
default=None, metadata={"help": "The name of the repository to which push the `Trainer`."}
|
default=None, metadata={"help": "The name of the repository to which push the `Trainer`."}
|
||||||
)
|
)
|
||||||
@@ -761,8 +769,40 @@ class TrainingArguments:
|
|||||||
self.hf_deepspeed_config = HfTrainerDeepSpeedConfig(self.deepspeed)
|
self.hf_deepspeed_config = HfTrainerDeepSpeedConfig(self.deepspeed)
|
||||||
self.hf_deepspeed_config.trainer_config_process(self)
|
self.hf_deepspeed_config.trainer_config_process(self)
|
||||||
|
|
||||||
if self.push_to_hub_model_id is None:
|
if self.push_to_hub_token is not None:
|
||||||
self.push_to_hub_model_id = Path(self.output_dir).name
|
warnings.warn(
|
||||||
|
"`--push_to_hub_token` is deprecated and will be removed in version 5 of 🤗 Transformers. Use "
|
||||||
|
"`--hub_token` instead.",
|
||||||
|
FutureWarning,
|
||||||
|
)
|
||||||
|
self.hub_token = self.push_to_hub_token
|
||||||
|
|
||||||
|
if self.push_to_hub_model_id is not None:
|
||||||
|
self.hub_model_id = get_full_repo_name(
|
||||||
|
self.push_to_hub_model_id, organization=self.push_to_hub_organization, token=self.hub_token
|
||||||
|
)
|
||||||
|
if self.push_to_hub_organization is not None:
|
||||||
|
warnings.warn(
|
||||||
|
"`--push_to_hub_model_id` and `--push_to_hub_organization` are deprecated and will be removed in "
|
||||||
|
"version 5 of 🤗 Transformers. Use `--hub_model_id` instead and pass the full repo name to this "
|
||||||
|
f"argument (in this case {self.hub_model_id}).",
|
||||||
|
FutureWarning,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
warnings.warn(
|
||||||
|
"`--push_to_hub_model_id` is deprecated and will be removed in version 5 of 🤗 Transformers. Use "
|
||||||
|
"`--hub_model_id` instead and pass the full repo name to this argument (in this case "
|
||||||
|
f"{self.hub_model_id}).",
|
||||||
|
FutureWarning,
|
||||||
|
)
|
||||||
|
elif self.push_to_hub_organization is not None:
|
||||||
|
self.hub_model_id = f"{self.push_to_hub_organization}/{Path(self.output_dir).name}"
|
||||||
|
warnings.warn(
|
||||||
|
"`--push_to_hub_organization` is deprecated and will be removed in version 5 of 🤗 Transformers. Use "
|
||||||
|
"`--hub_model_id` instead and pass the full repo name to this argument (in this case "
|
||||||
|
f"{self.hub_model_id}).",
|
||||||
|
FutureWarning,
|
||||||
|
)
|
||||||
|
|
||||||
def __str__(self):
|
def __str__(self):
|
||||||
self_as_dict = asdict(self)
|
self_as_dict = asdict(self)
|
||||||
|
|||||||
@@ -1299,7 +1299,7 @@ class TrainerIntegrationWithHubTester(unittest.TestCase):
|
|||||||
trainer = get_regression_trainer(
|
trainer = get_regression_trainer(
|
||||||
output_dir=os.path.join(tmp_dir, "test-trainer"),
|
output_dir=os.path.join(tmp_dir, "test-trainer"),
|
||||||
push_to_hub=True,
|
push_to_hub=True,
|
||||||
push_to_hub_token=self._token,
|
hub_token=self._token,
|
||||||
)
|
)
|
||||||
url = trainer.push_to_hub()
|
url = trainer.push_to_hub()
|
||||||
|
|
||||||
@@ -1321,8 +1321,8 @@ class TrainerIntegrationWithHubTester(unittest.TestCase):
|
|||||||
trainer = get_regression_trainer(
|
trainer = get_regression_trainer(
|
||||||
output_dir=os.path.join(tmp_dir, "test-trainer-org"),
|
output_dir=os.path.join(tmp_dir, "test-trainer-org"),
|
||||||
push_to_hub=True,
|
push_to_hub=True,
|
||||||
push_to_hub_organization="valid_org",
|
hub_model_id="valid_org/test-trainer-org",
|
||||||
push_to_hub_token=self._token,
|
hub_token=self._token,
|
||||||
)
|
)
|
||||||
url = trainer.push_to_hub()
|
url = trainer.push_to_hub()
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user