CLI: use hub's create_commit (#17755)
* use create_commit
* better commit message and description
* touch setup.py to trigger cache update
* add hub version gating
This commit is contained in:
2
.github/workflows/add-model-like.yml
vendored
2
.github/workflows/add-model-like.yml
vendored
@@ -27,7 +27,7 @@ jobs:
|
|||||||
id: cache
|
id: cache
|
||||||
with:
|
with:
|
||||||
path: ~/venv/
|
path: ~/venv/
|
||||||
key: v3-tests_model_like-${{ hashFiles('setup.py') }}
|
key: v4-tests_model_like-${{ hashFiles('setup.py') }}
|
||||||
|
|
||||||
- name: Create virtual environment on cache miss
|
- name: Create virtual environment on cache miss
|
||||||
if: steps.cache.outputs.cache-hit != 'true'
|
if: steps.cache.outputs.cache-hit != 'true'
|
||||||
|
|||||||
2
.github/workflows/model-templates.yml
vendored
2
.github/workflows/model-templates.yml
vendored
@@ -21,7 +21,7 @@ jobs:
|
|||||||
id: cache
|
id: cache
|
||||||
with:
|
with:
|
||||||
path: ~/venv/
|
path: ~/venv/
|
||||||
key: v3-tests_templates-${{ hashFiles('setup.py') }}
|
key: v4-tests_templates-${{ hashFiles('setup.py') }}
|
||||||
|
|
||||||
- name: Create virtual environment on cache miss
|
- name: Create virtual environment on cache miss
|
||||||
if: steps.cache.outputs.cache-hit != 'true'
|
if: steps.cache.outputs.cache-hit != 'true'
|
||||||
|
|||||||
2
.github/workflows/update_metdata.yml
vendored
2
.github/workflows/update_metdata.yml
vendored
@@ -21,7 +21,7 @@ jobs:
|
|||||||
id: cache
|
id: cache
|
||||||
with:
|
with:
|
||||||
path: ~/venv/
|
path: ~/venv/
|
||||||
key: v2-metadata-${{ hashFiles('setup.py') }}
|
key: v3-metadata-${{ hashFiles('setup.py') }}
|
||||||
|
|
||||||
- name: Create virtual environment on cache miss
|
- name: Create virtual environment on cache miss
|
||||||
if: steps.cache.outputs.cache-hit != 'true'
|
if: steps.cache.outputs.cache-hit != 'true'
|
||||||
|
|||||||
@@ -18,8 +18,9 @@ from importlib import import_module
|
|||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from datasets import load_dataset
|
from datasets import load_dataset
|
||||||
|
from packaging import version
|
||||||
|
|
||||||
from huggingface_hub import Repository, upload_file
|
import huggingface_hub
|
||||||
|
|
||||||
from .. import AutoConfig, AutoFeatureExtractor, AutoTokenizer, is_tf_available, is_torch_available
|
from .. import AutoConfig, AutoFeatureExtractor, AutoTokenizer, is_tf_available, is_torch_available
|
||||||
from ..utils import logging
|
from ..utils import logging
|
||||||
@@ -45,7 +46,9 @@ def convert_command_factory(args: Namespace):
|
|||||||
|
|
||||||
Returns: ServeCommand
|
Returns: ServeCommand
|
||||||
"""
|
"""
|
||||||
return PTtoTFCommand(args.model_name, args.local_dir, args.new_weights, args.no_pr, args.push)
|
return PTtoTFCommand(
|
||||||
|
args.model_name, args.local_dir, args.new_weights, args.no_pr, args.push, args.extra_commit_description
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class PTtoTFCommand(BaseTransformersCLICommand):
|
class PTtoTFCommand(BaseTransformersCLICommand):
|
||||||
@@ -89,6 +92,12 @@ class PTtoTFCommand(BaseTransformersCLICommand):
|
|||||||
action="store_true",
|
action="store_true",
|
||||||
help="Optional flag to push the weights directly to `main` (requires permissions)",
|
help="Optional flag to push the weights directly to `main` (requires permissions)",
|
||||||
)
|
)
|
||||||
|
train_parser.add_argument(
|
||||||
|
"--extra-commit-description",
|
||||||
|
type=str,
|
||||||
|
default="",
|
||||||
|
help="Optional additional commit description to use when opening a PR (e.g. to tag the owner).",
|
||||||
|
)
|
||||||
train_parser.set_defaults(func=convert_command_factory)
|
train_parser.set_defaults(func=convert_command_factory)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
@@ -134,13 +143,23 @@ class PTtoTFCommand(BaseTransformersCLICommand):
|
|||||||
|
|
||||||
return _find_pt_tf_differences(pt_outputs, tf_outputs, {})
|
return _find_pt_tf_differences(pt_outputs, tf_outputs, {})
|
||||||
|
|
||||||
def __init__(self, model_name: str, local_dir: str, new_weights: bool, no_pr: bool, push: bool, *args):
|
def __init__(
|
||||||
|
self,
|
||||||
|
model_name: str,
|
||||||
|
local_dir: str,
|
||||||
|
new_weights: bool,
|
||||||
|
no_pr: bool,
|
||||||
|
push: bool,
|
||||||
|
extra_commit_description: str,
|
||||||
|
*args
|
||||||
|
):
|
||||||
self._logger = logging.get_logger("transformers-cli/pt_to_tf")
|
self._logger = logging.get_logger("transformers-cli/pt_to_tf")
|
||||||
self._model_name = model_name
|
self._model_name = model_name
|
||||||
self._local_dir = local_dir if local_dir else os.path.join("/tmp", model_name)
|
self._local_dir = local_dir if local_dir else os.path.join("/tmp", model_name)
|
||||||
self._new_weights = new_weights
|
self._new_weights = new_weights
|
||||||
self._no_pr = no_pr
|
self._no_pr = no_pr
|
||||||
self._push = push
|
self._push = push
|
||||||
|
self._extra_commit_description = extra_commit_description
|
||||||
|
|
||||||
def get_text_inputs(self):
|
def get_text_inputs(self):
|
||||||
tokenizer = AutoTokenizer.from_pretrained(self._local_dir)
|
tokenizer = AutoTokenizer.from_pretrained(self._local_dir)
|
||||||
@@ -170,10 +189,17 @@ class PTtoTFCommand(BaseTransformersCLICommand):
|
|||||||
return pt_input, tf_input
|
return pt_input, tf_input
|
||||||
|
|
||||||
def run(self):
|
def run(self):
|
||||||
|
if version.parse(huggingface_hub.__version__) < version.parse("0.8.1"):
|
||||||
|
raise ImportError(
|
||||||
|
"The huggingface_hub version must be >= 0.8.1 to use this command. Please update your huggingface_hub"
|
||||||
|
" installation."
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
from huggingface_hub import Repository, create_commit
|
||||||
|
from huggingface_hub._commit_api import CommitOperationAdd
|
||||||
|
|
||||||
# Fetch remote data
|
# Fetch remote data
|
||||||
# TODO: implement a solution to pull a specific PR/commit, so we can use this CLI to validate pushes.
|
|
||||||
repo = Repository(local_dir=self._local_dir, clone_from=self._model_name)
|
repo = Repository(local_dir=self._local_dir, clone_from=self._model_name)
|
||||||
repo.git_pull() # in case the repo already exists locally, but with an older commit
|
|
||||||
|
|
||||||
# Load config and get the appropriate architecture -- the latter is needed to convert the head's weights
|
# Load config and get the appropriate architecture -- the latter is needed to convert the head's weights
|
||||||
config = AutoConfig.from_pretrained(self._local_dir)
|
config = AutoConfig.from_pretrained(self._local_dir)
|
||||||
@@ -240,32 +266,29 @@ class PTtoTFCommand(BaseTransformersCLICommand):
|
|||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
commit_message = "Update TF weights" if self._new_weights else "Add TF weights"
|
||||||
if self._push:
|
if self._push:
|
||||||
repo.git_add(auto_lfs_track=True)
|
repo.git_add(auto_lfs_track=True)
|
||||||
repo.git_commit("Add TF weights")
|
repo.git_commit(commit_message)
|
||||||
repo.git_push(blocking=True) # this prints a progress bar with the upload
|
repo.git_push(blocking=True) # this prints a progress bar with the upload
|
||||||
self._logger.warn(f"TF weights pushed into {self._model_name}")
|
self._logger.warn(f"TF weights pushed into {self._model_name}")
|
||||||
elif not self._no_pr:
|
elif not self._no_pr:
|
||||||
# TODO: remove try/except when the upload to PR feature is released
|
|
||||||
# (https://github.com/huggingface/huggingface_hub/pull/884)
|
|
||||||
try:
|
|
||||||
self._logger.warn("Uploading the weights into a new PR...")
|
self._logger.warn("Uploading the weights into a new PR...")
|
||||||
hub_pr_url = upload_file(
|
commit_descrition = (
|
||||||
path_or_fileobj=tf_weights_path,
|
"Model converted by the [`transformers`' `pt_to_tf`"
|
||||||
path_in_repo=TF_WEIGHTS_NAME,
|
" CLI](https://github.com/huggingface/transformers/blob/main/src/transformers/commands/pt_to_tf.py)."
|
||||||
repo_id=self._model_name,
|
"\n\nAll converted model outputs and hidden layers were validated against its Pytorch counterpart."
|
||||||
create_pr=True,
|
f" Maximum crossload output difference={max_crossload_diff:.3e}; Maximum converted output"
|
||||||
pr_commit_summary="Add TF weights",
|
|
||||||
pr_commit_description=(
|
|
||||||
"Model converted by the `transformers`' `pt_to_tf` CLI -- all converted model outputs and"
|
|
||||||
" hidden layers were validated against its Pytorch counterpart. Maximum crossload output"
|
|
||||||
f" difference={max_crossload_diff:.3e}; Maximum converted output"
|
|
||||||
f" difference={max_conversion_diff:.3e}."
|
f" difference={max_conversion_diff:.3e}."
|
||||||
),
|
)
|
||||||
|
if self._extra_commit_description:
|
||||||
|
commit_descrition += "\n\n" + self._extra_commit_description
|
||||||
|
hub_pr_url = create_commit(
|
||||||
|
repo_id=self._model_name,
|
||||||
|
operations=[CommitOperationAdd(path_in_repo=TF_WEIGHTS_NAME, path_or_fileobj=tf_weights_path)],
|
||||||
|
commit_message=commit_message,
|
||||||
|
commit_description=commit_descrition,
|
||||||
|
repo_type="model",
|
||||||
|
create_pr=True,
|
||||||
)
|
)
|
||||||
self._logger.warn(f"PR open in {hub_pr_url}")
|
self._logger.warn(f"PR open in {hub_pr_url}")
|
||||||
except TypeError:
|
|
||||||
self._logger.warn(
|
|
||||||
f"You can now open a PR in https://huggingface.co/{self._model_name}/discussions, manually"
|
|
||||||
f" uploading the file in {tf_weights_path}"
|
|
||||||
)
|
|
||||||
|
|||||||
Reference in New Issue
Block a user