@@ -1,137 +0,0 @@
|
|||||||
# Copyright 2022 The HuggingFace Team. All rights reserved.
|
|
||||||
#
|
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
# you may not use this file except in compliance with the License.
|
|
||||||
# You may obtain a copy of the License at
|
|
||||||
#
|
|
||||||
# http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
#
|
|
||||||
# Unless required by applicable law or agreed to in writing, software
|
|
||||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
# See the License for the specific language governing permissions and
|
|
||||||
# limitations under the License.
|
|
||||||
|
|
||||||
import os
|
|
||||||
from argparse import ArgumentParser, Namespace
|
|
||||||
|
|
||||||
from ..utils import logging
|
|
||||||
from . import BaseTransformersCLICommand
|
|
||||||
|
|
||||||
|
|
||||||
# Default value of the `--max-error` command-line option.
MAX_ERROR = 5e-5  # larger error tolerance than in our internal tests, to avoid flaky user-facing errors
|
|
||||||
|
|
||||||
|
|
||||||
def convert_command_factory(args: Namespace):
    """
    Build the `pt-to-tf` command object from the parsed command-line arguments.

    Args:
        args: Parsed arguments carrying the `model_name`, `local_dir`, `max_error`, `new_weights`, `no_pr`, `push`,
            `extra_commit_description` and `override_model_class` attributes.

    Returns: PTtoTFCommand
    """
    command = PTtoTFCommand(
        model_name=args.model_name,
        local_dir=args.local_dir,
        max_error=args.max_error,
        new_weights=args.new_weights,
        no_pr=args.no_pr,
        push=args.push,
        extra_commit_description=args.extra_commit_description,
        override_model_class=args.override_model_class,
    )
    return command
|
|
||||||
|
|
||||||
|
|
||||||
class PTtoTFCommand(BaseTransformersCLICommand):
    """
    Command object behind the `pt-to-tf` subcommand.

    The subcommand's options are still parsed and stored on the instance, but `run` always raises
    `NotImplementedError` and points users to the Safetensors conversion guide.
    """

    @staticmethod
    def register_subcommand(parser: ArgumentParser):
        """
        Attach the `pt-to-tf` subcommand and all of its options to the given parser.

        Args:
            parser: Root parser to register command-specific arguments
        """
        subcommand_help = (
            "CLI tool to run convert a transformers model from a PyTorch checkpoint to a TensorFlow checkpoint."
            " Can also be used to validate existing weights without opening PRs, with --no-pr."
        )
        pt_to_tf_parser = parser.add_parser("pt-to-tf", help=subcommand_help)

        # Each entry is (flag, keyword arguments for `add_argument`); options are registered in this order.
        option_specs = (
            (
                "--model-name",
                {
                    "type": str,
                    "required": True,
                    "help": "The model name, including owner/organization, as seen on the hub.",
                },
            ),
            (
                "--local-dir",
                {
                    "type": str,
                    "default": "",
                    "help": "Optional local directory of the model repository. Defaults to /tmp/{model_name}",
                },
            ),
            (
                "--max-error",
                {
                    "type": float,
                    "default": MAX_ERROR,
                    "help": (
                        f"Maximum error tolerance. Defaults to {MAX_ERROR}. This flag should be avoided, use at your own risk."
                    ),
                },
            ),
            (
                "--new-weights",
                {
                    "action": "store_true",
                    "help": "Optional flag to create new TensorFlow weights, even if they already exist.",
                },
            ),
            (
                "--no-pr",
                {
                    "action": "store_true",
                    "help": "Optional flag to NOT open a PR with converted weights.",
                },
            ),
            (
                "--push",
                {
                    "action": "store_true",
                    "help": "Optional flag to push the weights directly to `main` (requires permissions)",
                },
            ),
            (
                "--extra-commit-description",
                {
                    "type": str,
                    "default": "",
                    "help": "Optional additional commit description to use when opening a PR (e.g. to tag the owner).",
                },
            ),
            (
                "--override-model-class",
                {
                    "type": str,
                    "default": None,
                    "help": "If you think you know better than the auto-detector, you can specify the model class here. "
                    "Can be either an AutoModel class or a specific model class like BertForSequenceClassification.",
                },
            ),
        )
        for flag, settings in option_specs:
            pt_to_tf_parser.add_argument(flag, **settings)

        # The CLI entry point reads `func` from the parsed arguments to build the command.
        pt_to_tf_parser.set_defaults(func=convert_command_factory)

    def __init__(
        self,
        model_name: str,
        local_dir: str,
        max_error: float,
        new_weights: bool,
        no_pr: bool,
        push: bool,
        extra_commit_description: str,
        override_model_class: str,
        *args,
    ):
        """
        Store the parsed options on the instance.

        Args:
            model_name: Value of `--model-name`.
            local_dir: Value of `--local-dir`; when empty, `/tmp/<model_name>` is used instead.
            max_error: Value of `--max-error`.
            new_weights: Value of `--new-weights`.
            no_pr: Value of `--no-pr`.
            push: Value of `--push`.
            extra_commit_description: Value of `--extra-commit-description`.
            override_model_class: Value of `--override-model-class`.
            *args: Extra positional arguments; accepted and ignored.
        """
        self._logger = logging.get_logger("transformers-cli/pt_to_tf")
        self._model_name = model_name
        # An empty `local_dir` falls back to a directory named after the model under /tmp.
        self._local_dir = local_dir or os.path.join("/tmp", model_name)
        self._max_error = max_error
        self._new_weights = new_weights
        self._no_pr = no_pr
        self._push = push
        self._extra_commit_description = extra_commit_description
        self._override_model_class = override_model_class

    def run(self):
        """
        Always raise, because the conversion was removed.

        Raises:
            NotImplementedError: Unconditionally, with a message recommending Safetensors instead.
        """
        # TODO (joao): delete file in v4.47
        removal_notice = (
            "\n\nConverting PyTorch weights to TensorFlow weights was removed in v4.43. "
            "Instead, we recommend that you convert PyTorch weights to Safetensors, an improved "
            "format that can be loaded by any framework, including TensorFlow. For more information, "
            "please see the Safetensors conversion guide: "
            "https://huggingface.co/docs/safetensors/en/convert-weights\n\n"
        )
        raise NotImplementedError(removal_notice)
|
|
||||||
@@ -20,7 +20,6 @@ from .convert import ConvertCommand
|
|||||||
from .download import DownloadCommand
|
from .download import DownloadCommand
|
||||||
from .env import EnvironmentCommand
|
from .env import EnvironmentCommand
|
||||||
from .lfs import LfsCommands
|
from .lfs import LfsCommands
|
||||||
from .pt_to_tf import PTtoTFCommand
|
|
||||||
from .run import RunCommand
|
from .run import RunCommand
|
||||||
from .serving import ServeCommand
|
from .serving import ServeCommand
|
||||||
from .user import UserCommands
|
from .user import UserCommands
|
||||||
@@ -39,7 +38,6 @@ def main():
|
|||||||
UserCommands.register_subcommand(commands_parser)
|
UserCommands.register_subcommand(commands_parser)
|
||||||
AddNewModelLikeCommand.register_subcommand(commands_parser)
|
AddNewModelLikeCommand.register_subcommand(commands_parser)
|
||||||
LfsCommands.register_subcommand(commands_parser)
|
LfsCommands.register_subcommand(commands_parser)
|
||||||
PTtoTFCommand.register_subcommand(commands_parser)
|
|
||||||
|
|
||||||
# Let's go
|
# Let's go
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|||||||
@@ -351,7 +351,6 @@ src/transformers/commands/convert.py
|
|||||||
src/transformers/commands/download.py
|
src/transformers/commands/download.py
|
||||||
src/transformers/commands/env.py
|
src/transformers/commands/env.py
|
||||||
src/transformers/commands/lfs.py
|
src/transformers/commands/lfs.py
|
||||||
src/transformers/commands/pt_to_tf.py
|
|
||||||
src/transformers/commands/run.py
|
src/transformers/commands/run.py
|
||||||
src/transformers/commands/serving.py
|
src/transformers/commands/serving.py
|
||||||
src/transformers/commands/train.py
|
src/transformers/commands/train.py
|
||||||
|
|||||||
Reference in New Issue
Block a user