Model templates encoder only (#8509)
* Model templates * TensorFlow * Remove pooler * CI * Tokenizer + Refactoring * Encoder-Decoder * Let's go testing * Encoder-Decoder in TF * Let's go testing in TF * Documentation * README * Fixes * Better names * Style * Update docs * Choose to skip either TF or PT * Code quality fixes * Add to testing suite * Update file path * Cookiecutter path * Update `transformers` path * Handle rebasing * Remove seq2seq from model templates * Remove s2s config * Apply Sylvain and Patrick comments * Apply suggestions from code review Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> * Last fixes from code review Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
This commit is contained in:
192
src/transformers/commands/add_new_model.py
Normal file
192
src/transformers/commands/add_new_model.py
Normal file
@@ -0,0 +1,192 @@
|
||||
import json
|
||||
import os
|
||||
import shutil
|
||||
from argparse import ArgumentParser, Namespace
|
||||
from pathlib import Path
|
||||
from typing import List
|
||||
|
||||
from cookiecutter.main import cookiecutter
|
||||
from transformers.commands import BaseTransformersCLICommand
|
||||
|
||||
from ..utils import logging
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
|
||||
def add_new_model_command_factory(args: Namespace):
    """
    Build the `add-new-model` command from parsed command-line arguments.

    Args:
        args: Parsed arguments; the `testing`, `testing_file` and `path` attributes are read.

    Returns:
        An `AddNewModelCommand` configured from those three attributes.
    """
    testing_mode, configuration_file = args.testing, args.testing_file
    return AddNewModelCommand(testing_mode, configuration_file, path=args.path)
|
||||
|
||||
|
||||
class AddNewModelCommand(BaseTransformersCLICommand):
    """
    CLI command `add-new-model`.

    Runs the cookiecutter template found under `templates/cookiecutter`, then moves the generated files into the
    repository tree and patches existing files using the generated `to_replace_<model>.py` instruction file.
    """

    @staticmethod
    def register_subcommand(parser: ArgumentParser):
        """
        Register the `add-new-model` sub-command and its arguments.

        Args:
            parser: Object on which `add_parser` is called to create the sub-command parser.
        """
        add_new_model_parser = parser.add_parser("add-new-model")
        add_new_model_parser.add_argument("--testing", action="store_true", help="If in testing mode.")
        add_new_model_parser.add_argument("--testing_file", type=str, help="Configuration file on which to run.")
        add_new_model_parser.add_argument(
            "--path", type=str, help="Path to cookiecutter. Should only be used for testing purposes."
        )
        add_new_model_parser.set_defaults(func=add_new_model_command_factory)

    def __init__(self, testing: bool, testing_file: str, path=None, *args):
        """
        Args:
            testing: When true, cookiecutter runs without prompting, using the values from `testing_file`.
            testing_file: Path to a JSON file passed to cookiecutter as `extra_context` (testing mode only).
            path: Optional path to the cookiecutter template; the repository root is taken two levels above it.
            *args: Accepted and ignored.
        """
        self._testing = testing
        self._testing_file = testing_file
        self._path = path

    def run(self):
        """
        Generate the new model files and install them in the repository.

        Raises:
            ValueError: If the current working directory already contains an entry whose name starts with
                `cookiecutter-template-`, or if an anchor line listed in the `to_replace_<model>.py` file cannot be
                found in its target file.
        """
        # Local import kept function-scoped, as in the original implementation.
        from tempfile import mkstemp

        template_prefix = "cookiecutter-template-"

        # Ensure that there is no other `cookiecutter-template-xxx` directory in the current working directory,
        # otherwise the freshly generated one could not be identified unambiguously below.
        directories = [directory for directory in os.listdir() if directory.startswith(template_prefix)]
        if directories:
            raise ValueError(
                "Several directories starting with `cookiecutter-template-` in current working directory. "
                "Please clean your directory by removing all folders starting with `cookiecutter-template-` or "
                "change your working directory."
            )

        path_to_transformer_root = (
            Path(__file__).parent.parent.parent.parent if self._path is None else Path(self._path).parent.parent
        )
        path_to_cookiecutter = path_to_transformer_root / "templates" / "cookiecutter"

        # Execute cookiecutter
        if not self._testing:
            cookiecutter(str(path_to_cookiecutter))
        else:
            with open(self._testing_file, "r", encoding="utf-8") as configuration_file:
                testing_configuration = json.load(configuration_file)

            cookiecutter(
                str(path_to_cookiecutter if self._path is None else self._path),
                no_input=True,
                extra_context=testing_configuration,
            )

        # The only matching entry is the directory cookiecutter just created (checked above).
        directory = [directory for directory in os.listdir() if directory.startswith(template_prefix)][0]

        # Retrieve configuration
        with open(directory + "/configuration.json", "r", encoding="utf-8") as configuration_file:
            configuration = json.load(configuration_file)

        lowercase_model_name = configuration["lowercase_modelname"]
        pytorch_or_tensorflow = configuration["generate_tensorflow_and_pytorch"]
        os.remove(f"{directory}/configuration.json")

        output_pytorch = "PyTorch" in pytorch_or_tensorflow
        output_tensorflow = "TensorFlow" in pytorch_or_tensorflow

        shutil.move(
            f"{directory}/configuration_{lowercase_model_name}.py",
            f"{path_to_transformer_root}/src/transformers/configuration_{lowercase_model_name}.py",
        )

        def remove_copy_lines(path):
            """Rewrite the file at `path` without the lines containing `# Copied from transformers.`."""
            with open(path, "r", encoding="utf-8") as f:
                lines = f.readlines()
            with open(path, "w", encoding="utf-8") as f:
                f.writelines(line for line in lines if "# Copied from transformers." not in line)

        if output_pytorch:
            if not self._testing:
                remove_copy_lines(f"{directory}/modeling_{lowercase_model_name}.py")

            shutil.move(
                f"{directory}/modeling_{lowercase_model_name}.py",
                f"{path_to_transformer_root}/src/transformers/modeling_{lowercase_model_name}.py",
            )

            shutil.move(
                f"{directory}/test_modeling_{lowercase_model_name}.py",
                f"{path_to_transformer_root}/tests/test_modeling_{lowercase_model_name}.py",
            )
        else:
            os.remove(f"{directory}/modeling_{lowercase_model_name}.py")
            os.remove(f"{directory}/test_modeling_{lowercase_model_name}.py")

        if output_tensorflow:
            if not self._testing:
                remove_copy_lines(f"{directory}/modeling_tf_{lowercase_model_name}.py")

            shutil.move(
                f"{directory}/modeling_tf_{lowercase_model_name}.py",
                f"{path_to_transformer_root}/src/transformers/modeling_tf_{lowercase_model_name}.py",
            )

            shutil.move(
                f"{directory}/test_modeling_tf_{lowercase_model_name}.py",
                f"{path_to_transformer_root}/tests/test_modeling_tf_{lowercase_model_name}.py",
            )
        else:
            os.remove(f"{directory}/modeling_tf_{lowercase_model_name}.py")
            os.remove(f"{directory}/test_modeling_tf_{lowercase_model_name}.py")

        shutil.move(
            f"{directory}/{lowercase_model_name}.rst",
            f"{path_to_transformer_root}/docs/source/model_doc/{lowercase_model_name}.rst",
        )

        shutil.move(
            f"{directory}/tokenization_{lowercase_model_name}.py",
            f"{path_to_transformer_root}/src/transformers/tokenization_{lowercase_model_name}.py",
        )

        def replace(original_file: str, line_to_copy_below: str, lines_to_copy: List[str]):
            """
            Insert `lines_to_copy` after every line of `original_file` that contains `line_to_copy_below`.

            The new content is written to a temporary file which then replaces `original_file`.

            Raises:
                ValueError: If no line of `original_file` contains `line_to_copy_below`. The original file is left
                    untouched and the temporary file is deleted.
            """
            # Create temp file
            fh, abs_path = mkstemp()
            line_found = False
            try:
                with os.fdopen(fh, "w", encoding="utf-8") as new_file:
                    with open(original_file, encoding="utf-8") as old_file:
                        for line in old_file:
                            new_file.write(line)
                            if line_to_copy_below in line:
                                line_found = True
                                new_file.writelines(lines_to_copy)

                if not line_found:
                    raise ValueError(f"Line {line_to_copy_below} was not found in file.")

                # Copy the file permissions from the old file to the new file
                shutil.copymode(original_file, abs_path)
            except Exception:
                # `mkstemp` never deletes its file: clean it up ourselves when the replacement is abandoned.
                # This happens before the original is removed, so no data can be lost here.
                if os.path.exists(abs_path):
                    os.remove(abs_path)
                raise

            # Remove original file
            os.remove(original_file)
            # Move new file
            shutil.move(abs_path, original_file)

        def skip_units(line):
            """Return True when `line` is tagged for a framework (PyTorch/TensorFlow) that is not being generated."""
            return ("generating PyTorch" in line and not output_pytorch) or (
                "generating TensorFlow" in line and not output_tensorflow
            )

        def replace_in_files(path_to_datafile):
            """
            Apply the insertion instructions stored in `path_to_datafile`, then delete that file.

            Recognised marker lines (any line containing `##` is ignored altogether):
              - `# To replace in: "<file>"` selects the file to patch;
              - `# Below: "<anchor>"` selects the line below which the snippet is inserted;
              - `# Replace with` resets the snippet accumulated so far;
              - `# End.` inserts the accumulated snippet unless the file or snippet marker was skipped.
            Every other line is accumulated as part of the current snippet.
            """
            with open(path_to_datafile, encoding="utf-8") as datafile:
                lines_to_copy = []
                skip_file = False
                skip_snippet = False
                for line in datafile:
                    if "# To replace in: " in line and "##" not in line:
                        file_to_replace_in = line.split('"')[1]
                        skip_file = skip_units(line)
                    elif "# Below: " in line and "##" not in line:
                        line_to_copy_below = line.split('"')[1]
                        skip_snippet = skip_units(line)
                    elif "# End." in line and "##" not in line:
                        if not skip_file and not skip_snippet:
                            replace(file_to_replace_in, line_to_copy_below, lines_to_copy)

                        lines_to_copy = []
                    elif "# Replace with" in line and "##" not in line:
                        lines_to_copy = []
                    elif "##" not in line:
                        lines_to_copy.append(line)

            os.remove(path_to_datafile)

        replace_in_files(f"{directory}/to_replace_{lowercase_model_name}.py")
        # `os.rmdir` only succeeds on an empty directory, i.e. once every generated file has been moved or removed.
        os.rmdir(directory)
|
||||
@@ -1,6 +1,7 @@
|
||||
#!/usr/bin/env python
|
||||
from argparse import ArgumentParser
|
||||
|
||||
from transformers.commands.add_new_model import AddNewModelCommand
|
||||
from transformers.commands.convert import ConvertCommand
|
||||
from transformers.commands.download import DownloadCommand
|
||||
from transformers.commands.env import EnvironmentCommand
|
||||
@@ -20,6 +21,7 @@ def main():
|
||||
RunCommand.register_subcommand(commands_parser)
|
||||
ServeCommand.register_subcommand(commands_parser)
|
||||
UserCommands.register_subcommand(commands_parser)
|
||||
AddNewModelCommand.register_subcommand(commands_parser)
|
||||
|
||||
# Let's go
|
||||
args = parser.parse_args()
|
||||
|
||||
Reference in New Issue
Block a user