Templates overhaul 1 (#8993)

This commit is contained in:
Lysandre Debut
2020-12-08 18:00:07 -05:00
committed by GitHub
parent 447808c85f
commit 67ff1c314a
16 changed files with 759 additions and 82 deletions

View File

@@ -19,12 +19,18 @@ from argparse import ArgumentParser, Namespace
from pathlib import Path
from typing import List
from cookiecutter.main import cookiecutter
from transformers.commands import BaseTransformersCLICommand
from ..utils import logging
try:
from cookiecutter.main import cookiecutter
_has_cookiecutter = True
except ImportError:
_has_cookiecutter = False
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
@@ -49,6 +55,11 @@ class AddNewModelCommand(BaseTransformersCLICommand):
self._path = path
def run(self):
if not _has_cookiecutter:
raise ImportError(
"Model creation dependencies are required to use the `add_new_model` command. Install them by running "
"the folowing at the root of your `transformers` clone:\n\n\t$ pip install -e .[modelcreation]\n"
)
# Ensure that there is no other `cookiecutter-template-xxx` directory in the current working directory
directories = [directory for directory in os.listdir() if "cookiecutter-template-" == directory[:22]]
if len(directories) > 0:
@@ -153,6 +164,11 @@ class AddNewModelCommand(BaseTransformersCLICommand):
f"{model_dir}/tokenization_{lowercase_model_name}.py",
)
shutil.move(
f"{directory}/tokenization_fast_{lowercase_model_name}.py",
f"{model_dir}/tokenization_{lowercase_model_name}_fast.py",
)
from os import fdopen, remove
from shutil import copymode, move
from tempfile import mkstemp